From 2f4fd18151520dd67325c4667cc93a1191e925e1 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Mon, 26 Aug 2024 19:30:37 -0400 Subject: [PATCH 01/20] Extracted common interface between V1 and V2 --- polaris/benchmark/_base.py | 12 +- polaris/dataset/__init__.py | 8 +- polaris/dataset/_base.py | 402 ++++++++++++++++++++++++ polaris/dataset/_competition_dataset.py | 7 +- polaris/dataset/_dataset.py | 297 ++--------------- polaris/dataset/_factory.py | 10 +- polaris/dataset/_subset.py | 4 +- polaris/experimental/__init__.py | 0 polaris/experimental/_dataset_v2.py | 5 + polaris/hub/client.py | 21 +- polaris/loader/load.py | 6 +- polaris/mixins/_checksum.py | 5 +- polaris/utils/misc.py | 4 +- tests/conftest.py | 4 +- tests/test_dataset.py | 20 +- tests/test_dataset_v2.py | 0 tests/test_evaluate.py | 7 +- 17 files changed, 487 insertions(+), 325 deletions(-) create mode 100644 polaris/dataset/_base.py create mode 100644 polaris/experimental/__init__.py create mode 100644 polaris/experimental/_dataset_v2.py create mode 100644 tests/test_dataset_v2.py diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index 45f9cc15..e2ec4902 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -1,6 +1,6 @@ -from itertools import chain import json from hashlib import md5 +from itertools import chain from typing import Any, Callable, Optional, Union import fsspec @@ -18,11 +18,11 @@ from sklearn.utils.multiclass import type_of_target from polaris._artifact import BaseArtifactModel -from polaris.mixins import ChecksumMixin -from polaris.dataset import Dataset, Subset, CompetitionDataset +from polaris.dataset import CompetitionDataset, DatasetV1, Subset from polaris.evaluate import BenchmarkResults, Metric from polaris.evaluate.utils import evaluate_benchmark from polaris.hub.settings import PolarisHubSettings +from polaris.mixins import ChecksumMixin from polaris.utils.dict2html import dict2html from polaris.utils.errors import InvalidBenchmarkError 
from polaris.utils.misc import listit @@ -96,7 +96,7 @@ class BenchmarkSpecification(BaseArtifactModel, ChecksumMixin): # Public attributes # Data - dataset: Union[Dataset, CompetitionDataset, str, dict[str, Any]] + dataset: Union[DatasetV1, CompetitionDataset, str, dict[str, Any]] target_cols: ColumnsType input_cols: ColumnsType split: SplitType @@ -114,9 +114,9 @@ def _validate_dataset(cls, v): TODO (cwognum): Allow multiple datasets to be used as part of a benchmark """ if isinstance(v, dict): - v = Dataset(**v) + v = DatasetV1(**v) elif isinstance(v, str): - v = Dataset.from_json(v) + v = DatasetV1.from_json(v) return v @field_validator("target_cols", "input_cols") diff --git a/polaris/dataset/__init__.py b/polaris/dataset/__init__.py index 3f861d54..ab059de8 100644 --- a/polaris/dataset/__init__.py +++ b/polaris/dataset/__init__.py @@ -1,8 +1,10 @@ -from polaris.dataset._column import ColumnAnnotation, Modality, KnownContentType -from polaris.dataset._dataset import Dataset +from polaris.dataset._column import ColumnAnnotation, KnownContentType, Modality +from polaris.dataset._competition_dataset import CompetitionDataset +from polaris.dataset._dataset import DatasetV1 from polaris.dataset._factory import DatasetFactory, create_dataset_from_file, create_dataset_from_files from polaris.dataset._subset import Subset -from polaris.dataset._competition_dataset import CompetitionDataset + +Dataset = DatasetV1 __all__ = [ "ColumnAnnotation", diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py new file mode 100644 index 00000000..c13c9655 --- /dev/null +++ b/polaris/dataset/_base.py @@ -0,0 +1,402 @@ +import abc +import json +import uuid +from pathlib import Path +from typing import Dict, List, MutableMapping, Optional, Union + +import fsspec +import numpy as np +import pandas as pd +import zarr +from datamol.utils import fs as dmfs +from loguru import logger +from pydantic import ( + Field, + PrivateAttr, + computed_field, + field_serializer, + 
field_validator, + model_validator, +) + +from polaris._artifact import BaseArtifactModel +from polaris.dataset._adapters import Adapter +from polaris.dataset._column import ColumnAnnotation +from polaris.dataset.zarr import MemoryMappedDirectoryStore, ZarrFileChecksum +from polaris.dataset.zarr._utils import load_zarr_group_to_memory +from polaris.hub.polarisfs import PolarisFileSystem +from polaris.mixins import ChecksumMixin +from polaris.utils.constants import DEFAULT_CACHE_DIR +from polaris.utils.dict2html import dict2html +from polaris.utils.errors import InvalidDatasetError +from polaris.utils.types import ( + AccessType, + HttpUrlString, + HubOwner, + SupportedLicenseType, + ZarrConflictResolution, +) + +# Constants +_CACHE_SUBDIR = "datasets" + + +class BaseDataset(BaseArtifactModel, ChecksumMixin, abc.ABC): + """Base data-model for a Polaris dataset, implemented as a [Pydantic](https://docs.pydantic.dev/latest/) model. + + At its core, a dataset in Polaris is a tabular data structure that stores data-points in a row-wise manner. + A Dataset can have multiple modalities or targets, can be sparse and can be part of one or multiple + [`BenchmarkSpecification`][polaris.benchmark.BenchmarkSpecification] objects. + + Info: Pointer columns + Whereas a `Dataset` contains all information required to construct a dataset, it is not ready yet. + For complex data, such as images, we support storing the content in external blobs of data. + In that case, the table contains _pointers_ to these blobs that are dynamically loaded when needed. + + Attributes: + default_adapters: The adapters that the Dataset recommends to use by default to change the format of the data + for specific columns. + zarr_root_path: The data for any pointer column should be saved in the Zarr archive this path points to. + readme: Markdown text that can be used to provide a formatted description of the dataset. 
+ If using the Polaris Hub, it is worth noting that this field is more easily edited through the Hub UI + as it provides a rich text editor for writing markdown. + annotations: Each column _can be_ annotated with a [`ColumnAnnotation`][polaris.dataset.ColumnAnnotation] object. + Importantly, this is used to annotate whether a column is a pointer column. + source: The data source, e.g. a DOI, Github repo or URI. + license: The dataset license. Polaris only supports some Creative Commons licenses. See [`SupportedLicenseType`][polaris.utils.types.SupportedLicenseType] for accepted ID values. + curation_reference: A reference to the curation process, e.g. a DOI, Github repo or URI. + For additional meta-data attributes, see the [`BaseArtifactModel`][polaris._artifact.BaseArtifactModel] class. + + Raises: + InvalidDatasetError: If the dataset does not conform to the Pydantic data-model specification. + """ + + # Public attributes + # Data + default_adapters: Dict[str, Adapter] = Field(default_factory=dict) + zarr_root_path: Optional[str] = None + + # Additional meta-data + readme: str = "" + annotations: Dict[str, ColumnAnnotation] = Field(default_factory=dict) + source: Optional[HttpUrlString] = None + license: Optional[SupportedLicenseType] = None + curation_reference: Optional[HttpUrlString] = None + + # Config + cache_dir: Optional[Path] = None # Where to cache the data to if cache() is called. 
+ + # Private attributes + _zarr_root: Optional[zarr.Group] = PrivateAttr(None) + _zarr_data: Optional[MutableMapping[str, np.ndarray]] = PrivateAttr(None) + _zarr_md5sum_manifest: List[ZarrFileChecksum] = PrivateAttr(default_factory=list) + _client = PrivateAttr(None) # Optional[PolarisHubClient] + _warn_about_remote_zarr: bool = PrivateAttr(True) + + @model_validator(mode="after") + @classmethod + def _validate_model(cls, m: "BaseDataset"): + """Verifies some dependencies between properties""" + + # Set the default cache dir if none and make sure it exists + if m.cache_dir is None: + dataset_id = m._md5sum if m.has_md5sum else str(uuid.uuid4()) + m.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id + + m.cache_dir.mkdir(parents=True, exist_ok=True) + return m + + @field_validator("default_adapters", mode="before") + def _validate_adapters(cls, value): + """Validate the adapters""" + return {k: Adapter[v] if isinstance(v, str) else v for k, v in value.items()} + + @field_serializer("default_adapters") + def _serialize_adapters(self, value: List[Adapter]): + """Serializes the adapters""" + return {k: v.name for k, v in value.items()} + + @field_serializer("cache_dir") + def _serialize_cache_dir(value): + """Serialize the cache_dir""" + return str(value) + + @computed_field + @property + @abc.abstractmethod + def zarr_md5sum_manifest(self) -> List[ZarrFileChecksum]: + """ + The Zarr Checksum manifest stores the checksums of all files in a Zarr archive. + If the dataset doesn't use Zarr, this will simply return an empty list.
+ """ + raise NotImplementedError + + @property + def client(self): + """The Polaris Hub client used to interact with the Polaris Hub.""" + + # Import it here to prevent circular imports + from polaris.hub.client import PolarisHubClient + + if self._client is None: + self._client = PolarisHubClient() + return self._client + + @property + def uses_zarr(self) -> bool: + """Whether any of the data in this dataset is stored in a Zarr Archive.""" + return self.zarr_root_path is not None + + @property + def zarr_data(self): + """Get the Zarr data. + + This is different from the Zarr Root, because to optimize the efficiency of + data loading, a user can choose to load the data into memory as a numpy array + + Note: General purpose dataloader. + The goal with Polaris is to provide general purpose datasets that serve as good + options for a _wide variety of use cases_. This also implies you should be able to + optimize things further for a specific use case if needed. + """ + if self._zarr_data is not None: + return self._zarr_data + return self.zarr_root + + @property + def zarr_root(self): + """Get the zarr Group object corresponding to the root. + + Opens the zarr archive in read-write mode if it is not already open. + + Note: Different to `zarr_data` + The `zarr_data` attribute references either the Zarr archive or an in-memory copy of the data. + See also [`Dataset.load_to_memory`][polaris.dataset.Dataset.load_to_memory]. + """ + + if self._zarr_root is not None: + return self._zarr_root + + if self.zarr_root_path is None: + return None + + # Archives saved on the Hub are opened through the Hub client instead of from a local store + saved_on_hub = PolarisFileSystem.is_polarisfs_path(self.zarr_root_path) + + if self._warn_about_remote_zarr: + saved_remote = saved_on_hub or not Path(self.zarr_root_path).exists() + + if saved_remote: + logger.warning( + f"You're loading data from a remote location.
" + f"To speed up this process, consider caching the dataset first " + f"using {self.__class__.__name__}.cache()" + ) + self._warn_about_remote_zarr = False + + try: + if saved_on_hub: + self._zarr_root = self.client.open_zarr_file(self.owner, self.name, self.zarr_root_path, "r+") + else: + # We use memory mapping by default because our experiments show that it's consistently faster + store = MemoryMappedDirectoryStore(self.zarr_root_path) + self._zarr_root = zarr.open_consolidated(store, mode="r+") + except KeyError as error: + raise InvalidDatasetError( + "A Zarr archive associated with a Polaris dataset has to be consolidated." + ) from error + return self._zarr_root + + @computed_field + @property + def n_rows(self) -> int: + """The number of rows in the dataset.""" + return len(self.rows) + + @computed_field + @property + def n_columns(self) -> int: + """The number of columns in the dataset.""" + return len(self.columns) + + @property + @abc.abstractmethod + def rows(self) -> list: + """Return all row indices for the dataset""" + raise NotImplementedError + + @property + @abc.abstractmethod + def columns(self) -> list: + """Return all columns for the dataset""" + raise NotImplementedError + + @property + @abc.abstractmethod + def dtypes(self) -> dict[str, np.dtype]: + """Return the dtype for each of the columns for the dataset""" + raise NotImplementedError + + def load_to_memory(self): + """ + Load data from zarr files to memeory + + Warning: Make sure the uncompressed dataset fits in-memory. + This method will load the **uncompressed** dataset into memory. Make + sure you actually have enough memory to store the dataset. + """ + data = self.zarr_data + + if not isinstance(data, zarr.Group): + raise TypeError( + "The dataset zarr_root is not a valid Zarr archive. " + "Did you call Dataset.load_to_memory() twice?" + ) + + # NOTE (cwognum): If the dataset fits in memory, the most performant is to use plain NumPy arrays. 
+ # Even if we disable chunking and compression in Zarr. + # For more information, see https://github.com/zarr-developers/zarr-python/issues/1395 + self._zarr_data = load_zarr_group_to_memory(data) + + @abc.abstractmethod + def get_data(self, row: str | int, col: str, adapters: Optional[List[Adapter]] = None) -> np.ndarray: + """Since the dataset might contain pointers to external files, data retrieval is more complicated + than just indexing the `table` attribute. This method provides an end-point for seamlessly + accessing the underlying data. + + Args: + row: The row index in the `Dataset.table` attribute + col: The column index in the `Dataset.table` attribute + adapters: The adapters to apply to the data before returning it. + If None, will use the default adapters specified for the dataset. + + Returns: + A numpy array with the data at the specified indices. If the column is a pointer column, + the content of the referenced file is loaded to memory. + """ + raise NotImplementedError + + @abc.abstractmethod + def upload_to_hub( + self, access: Optional[AccessType] = "private", owner: Union[HubOwner, str, None] = None + ): + """Uploads the dataset to the Polaris Hub.""" + raise NotImplementedError + + @classmethod + def from_json(cls, path: str): + """Loads a dataset from a JSON file. + Overrides the method from the base class to remove the caching dir from the file to load from, + as that should be user dependent. + + Args: + path: The path to the JSON file to load the dataset from. + """ + with fsspec.open(path, "r") as f: + data = json.load(f) + data.pop("cache_dir", None) + return cls.model_validate(data) + + @abc.abstractmethod + def to_json( + self, + destination: str, + if_exists: ZarrConflictResolution = "replace", + ) -> str: + """ + Save the dataset to a destination directory as a JSON file. + + Args: + destination: The _directory_ to save the associated data to. + if_exists: Action for handling existing files in the Zarr archive.
Options are 'raise' to throw + an error, 'replace' to overwrite, or 'skip' to proceed without altering the existing files. + + Returns: + The path to the JSON file. + """ + raise NotImplementedError + + def cache(self, cache_dir: Optional[str] = None, verify_checksum: bool = True) -> str: + """Caches the dataset by downloading all additional data for pointer columns to a local directory. + + Args: + cache_dir: The directory to cache the data to. If not provided, + this will fall back to the `Dataset.cache_dir` attribute + verify_checksum: Whether to verify the checksum of the dataset after caching. + + Returns: + The path to the cache directory. + """ + + if cache_dir is not None: + self.cache_dir = cache_dir + + self.to_json(self.cache_dir) + + if self.uses_zarr: + self.zarr_root_path = dmfs.join(self.cache_dir, "data.zarr") + self._zarr_root = None + + if verify_checksum: + self.verify_checksum() + + return self.cache_dir + + def size(self): + return self.rows, self.n_columns + + def __getitem__(self, item): + """Allows for indexing the dataset directly""" + ret = self.table.loc[item] + if isinstance(ret, pd.Series): + # Load the data from the pointer columns + + if ret.name in self.table.columns: + # Returning a column, the indices are rows + if self.annotations[ret.name].is_pointer: + ret = np.array([self.get_data(k, ret.name) for k in ret.index]) + + elif len(ret) == self.n_rows: + # Returning a row, the indices are columns + ret = { + k: self.get_data(k, ret.name) if self.annotations[ret.name].is_pointer else ret[k] + for k in ret.index + } + + # Returning a dataframe + if isinstance(ret, pd.DataFrame): + for c in ret.columns: + if self.annotations[c].is_pointer: + ret[c] = [self.get_data(item, c) for item in ret.index] + return ret + + return ret + + @abc.abstractmethod + def _repr_dict_(self) -> dict: + """Utility function for pretty-printing to the command line and jupyter notebooks""" + raise NotImplementedError + + def _repr_html_(self): + """For 
pretty-printing in Jupyter Notebooks""" + return dict2html(self._repr_dict_()) + + def __len__(self): + return self.n_rows + + def __repr__(self): + return json.dumps(self._repr_dict_(), indent=2) + + def __str__(self): + return self.__repr__() + + def __eq__(self, other): + """Whether two datasets are equal is solely determined by the checksum""" + if not isinstance(other, BaseDataset): + return False + return self.md5sum == other.md5sum + + def __del__(self): + """Close the connection of the client""" + if self._client is not None: + self._client.close() diff --git a/polaris/dataset/_competition_dataset.py b/polaris/dataset/_competition_dataset.py index 2f224c22..7c642e90 100644 --- a/polaris/dataset/_competition_dataset.py +++ b/polaris/dataset/_competition_dataset.py @@ -1,11 +1,10 @@ from pydantic import model_validator -from polaris.dataset import Dataset -from polaris.utils.errors import InvalidCompetitionError -_CACHE_SUBDIR = "datasets" +from polaris.dataset._dataset import DatasetV1 +from polaris.utils.errors import InvalidCompetitionError -class CompetitionDataset(Dataset): +class CompetitionDataset(DatasetV1): """Dataset subclass for Polaris competitions. 
In addition to the data model and logic of the base Dataset class, diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index f7177349..c06145b0 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -1,8 +1,6 @@ import json -import uuid from hashlib import md5 -from pathlib import Path -from typing import Dict, List, MutableMapping, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import fsspec import numpy as np @@ -10,41 +8,22 @@ import zarr from datamol.utils import fs as dmfs from loguru import logger -from pydantic import ( - Field, - PrivateAttr, - computed_field, - field_serializer, - field_validator, - model_validator, -) - -from polaris._artifact import BaseArtifactModel +from pydantic import computed_field, field_validator, model_validator + from polaris.dataset._adapters import Adapter +from polaris.dataset._base import BaseDataset from polaris.dataset._column import ColumnAnnotation -from polaris.dataset.zarr import MemoryMappedDirectoryStore, ZarrFileChecksum, compute_zarr_checksum -from polaris.dataset.zarr._utils import load_zarr_group_to_memory -from polaris.hub.polarisfs import PolarisFileSystem -from polaris.mixins import ChecksumMixin -from polaris.utils.constants import DEFAULT_CACHE_DIR -from polaris.utils.dict2html import dict2html +from polaris.dataset.zarr import ZarrFileChecksum, compute_zarr_checksum from polaris.utils.errors import InvalidDatasetError -from polaris.utils.types import ( - AccessType, - HttpUrlString, - HubOwner, - SupportedLicenseType, - ZarrConflictResolution, -) +from polaris.utils.types import AccessType, HubOwner, ZarrConflictResolution # Constants _SUPPORTED_TABLE_EXTENSIONS = ["parquet"] -_CACHE_SUBDIR = "datasets" _INDEX_SEP = "#" -class Dataset(BaseArtifactModel, ChecksumMixin): - """Basic data-model for a Polaris dataset, implemented as a [Pydantic](https://docs.pydantic.dev/latest/) model. 
+class DatasetV1(BaseDataset): + """A Polaris Dataset, implemented as a [Pydantic](https://docs.pydantic.dev/latest/) model. At its core, a dataset in Polaris is a tabular data structure that stores data-points in a row-wise manner. A Dataset can have multiple modalities or targets, can be sparse and can be part of one or multiple @@ -58,18 +37,8 @@ class Dataset(BaseArtifactModel, ChecksumMixin): Attributes: table: The core data-structure, storing data-points in a row-wise manner. Can be specified as either a path to a `.parquet` file or a `pandas.DataFrame`. - default_adapters: The adapters that the Dataset recommends to use by default to change the format of the data - for specific columns. - zarr_root_path: The data for any pointer column should be saved in the Zarr archive this path points to. - readme: Markdown text that can be used to provide a formatted description of the dataset. - If using the Polaris Hub, it is worth noting that this field is more easily edited through the Hub UI - as it provides a rich text editor for writing markdown. - annotations: Each column _can be_ annotated with a [`ColumnAnnotation`][polaris.dataset.ColumnAnnotation] object. - Importantly, this is used to annotate whether a column is a pointer column. - source: The data source, e.g. a DOI, Github repo or URI. - license: The dataset license. Polaris only supports some Creative Commons licenses. See [`SupportedLicenseType`][polaris.utils.types.SupportedLicenseType] for accepted ID values. - curation_reference: A reference to the curation process, e.g. a DOI, Github repo or URI. - For additional meta-data attributes, see the [`BaseArtifactModel`][polaris._artifact.BaseArtifactModel] class. + + For additional meta-data attributes, see the [`BaseDataset`][polaris.dataset._base.BaseDataset] class. Raises: InvalidDatasetError: If the dataset does not conform to the Pydantic data-model specification.
@@ -77,29 +46,9 @@ class Dataset(BaseArtifactModel, ChecksumMixin): # Public attributes # Data - table: Union[pd.DataFrame, str] - default_adapters: Dict[str, Adapter] = Field(default_factory=dict) - zarr_root_path: Optional[str] = None - - # Additional meta-data - readme: str = "" - annotations: Dict[str, ColumnAnnotation] = Field(default_factory=dict) - source: Optional[HttpUrlString] = None - license: Optional[SupportedLicenseType] = None - curation_reference: Optional[HttpUrlString] = None - - # Config - cache_dir: Optional[Path] = None # Where to cache the data to if cache() is called. - - # Private attributes - _zarr_root: Optional[zarr.Group] = PrivateAttr(None) - _zarr_data: Optional[MutableMapping[str, np.ndarray]] = PrivateAttr(None) - _md5sum: Optional[str] = PrivateAttr(None) - _zarr_md5sum_manifest: List[ZarrFileChecksum] = PrivateAttr(default_factory=list) - _client = PrivateAttr(None) # Optional[PolarisHubClient] - _warn_about_remote_zarr: bool = PrivateAttr(True) - - @field_validator("table") + table: pd.DataFrame + + @field_validator("table", mode="before") def _validate_table(cls, v): """ If the table is not a dataframe yet, assume it's a path and try load it. 
@@ -121,17 +70,23 @@ def _validate_table(cls, v): @model_validator(mode="after") @classmethod - def _validate_model(cls, m: "Dataset"): + def _validate_model(cls, m: "DatasetV1"): """Verifies some dependencies between properties""" # Verify that all annotations are for columns that exist - if any(k not in m.table.columns for k in m.annotations): + if any(k not in m.columns for k in m.annotations): raise InvalidDatasetError("There are annotations for columns that do not exist") # Verify that all adapters are for columns that exist - if any(k not in m.table.columns for k in m.default_adapters.keys()): + if any(k not in m.columns for k in m.default_adapters.keys()): raise InvalidDatasetError("There are default adapters for columns that do not exist") + # Set a default for missing annotations and convert strings to Modality + for c in m.columns: + if c not in m.annotations: + m.annotations[c] = ColumnAnnotation() + m.annotations[c].dtype = m.dtypes[c] + has_pointers = any(anno.is_pointer for anno in m.annotations.values()) if has_pointers and m.zarr_root_path is None: raise InvalidDatasetError("A zarr_root_path needs to be specified when there are pointer columns") @@ -139,36 +94,8 @@ def _validate_model(cls, m: "Dataset"): raise InvalidDatasetError( "The zarr_root_path should only be specified when there are pointer columns" ) - - # Set a default for missing annotations and convert strings to Modality - for c in m.table.columns: - if c not in m.annotations: - m.annotations[c] = ColumnAnnotation() - m.annotations[c].dtype = m.table[c].dtype - - # Set the default cache dir if none and make sure it exists - if m.cache_dir is None: - dataset_id = m._md5sum if m.has_md5sum else str(uuid.uuid4()) - m.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id - - m.cache_dir.mkdir(parents=True, exist_ok=True) return m - @field_validator("default_adapters", mode="before") - def _validate_adapters(cls, value): - """Validate the adapters""" - return {k: Adapter[v] if 
isinstance(v, str) else v for k, v in value.items()} - - @field_serializer("default_adapters") - def _serialize_adapters(self, value: List[Adapter]): - """Serializes the adapters""" - return {k: v.name for k, v in value.items()} - - @field_serializer("cache_dir") - def _serialize_cache_dir(value): - """Serialize the cacha_dir""" - return str(value) - def _compute_checksum(self): """Computes a hash of the dataset. @@ -209,94 +136,6 @@ def zarr_md5sum_manifest(self) -> List[ZarrFileChecksum]: self.md5sum = self._compute_checksum() return self._zarr_md5sum_manifest - @property - def client(self): - """The Polaris Hub client used to interact with the Polaris Hub.""" - - # Import it here to prevent circular imports - from polaris.hub.client import PolarisHubClient - - if self._client is None: - self._client = PolarisHubClient() - return self._client - - @property - def uses_zarr(self) -> bool: - """Whether any of the data in this dataset is stored in a Zarr Archive.""" - return self.zarr_root_path is not None - - @property - def zarr_data(self): - """Get the Zarr data. - - This is different from the Zarr Root, because to optimize the efficiency of - data loading, a user can choose to load the data into memory as a numpy array - - Note: General purpose dataloader. - The goal with Polaris is to provide general purpose datasets that serve as good - options for a _wide variety of use cases_. This also implies you should be able to - optimize things further for a specific use case if needed. - """ - if self._zarr_data is not None: - return self._zarr_data - return self.zarr_root - - @property - def zarr_root(self): - """Get the zarr Group object corresponding to the root. - - Opens the zarr archive in read-write mode if it is not already open. - - Note: Different to `zarr_data` - The `zarr_data` attribute references either to the Zarr archive or to a in-memory copy of the data. - See also [`Dataset.load_to_memory`][polaris.dataset.Dataset.load_to_memory]. 
- """ - - if self._zarr_root is not None: - return self._zarr_root - - if self.zarr_root_path is None: - return None - - # We open the archive in read-only mode if it is saved on the Hub - saved_on_hub = PolarisFileSystem.is_polarisfs_path(self.zarr_root_path) - - if self._warn_about_remote_zarr: - saved_remote = saved_on_hub or not Path(self.zarr_root_path).exists() - - if saved_remote: - logger.warning( - f"You're loading data from a remote location. " - f"To speed up this process, consider caching the dataset first " - f"using {self.__class__.__name__}.cache()" - ) - self._warn_about_remote_zarr = False - - try: - if saved_on_hub: - self._zarr_root = self.client.open_zarr_file(self.owner, self.name, self.zarr_root_path, "r+") - else: - # We use memory mapping by default because our experiments show that it's consistently faster - store = MemoryMappedDirectoryStore(self.zarr_root_path) - self._zarr_root = zarr.open_consolidated(store, mode="r+") - except KeyError as error: - raise InvalidDatasetError( - "A Zarr archive associated with a Polaris dataset has to be consolidated." - ) from error - return self._zarr_root - - @computed_field - @property - def n_rows(self) -> int: - """The number of rows in the dataset.""" - return len(self.rows) - - @computed_field - @property - def n_columns(self) -> int: - """The number of columns in the dataset.""" - return len(self.columns) - @property def rows(self) -> list: """Return all row indices for the dataset""" @@ -307,26 +146,10 @@ def columns(self) -> list: """Return all columns for the dataset""" return self.table.columns.tolist() - def load_to_memory(self): - """ - Load data from zarr files to memeory - - Warning: Make sure the uncompressed dataset fits in-memory. - This method will load the **uncompressed** dataset into memory. Make - sure you actually have enough memory to store the dataset. 
- """ - data = self.zarr_data - - if not isinstance(data, zarr.Group): - raise TypeError( - "The dataset zarr_root is not a valid Zarr archive. " - "Did you call Dataset.load_to_memory() twice?" - ) - - # NOTE (cwognum): If the dataset fits in memory, the most performant is to use plain NumPy arrays. - # Even if we disable chunking and compression in Zarr. - # For more information, see https://github.com/zarr-developers/zarr-python/issues/1395 - self._zarr_data = load_zarr_group_to_memory(data) + @property + def dtypes(self) -> dict[str, np.dtype]: + """Return the dtype for each of the columns for the dataset""" + return {col: self.table[col].dtype for col in self.columns} def get_data(self, row: str | int, col: str, adapters: Optional[List[Adapter]] = None) -> np.ndarray: """Since the dataset might contain pointers to external files, data retrieval is more complicated @@ -378,20 +201,6 @@ def upload_to_hub( """ self.client.upload_dataset(self, access=access, owner=owner) - @classmethod - def from_json(cls, path: str): - """Loads a benchmark from a JSON file. - Overrides the method from the base class to remove the caching dir from the file to load from, - as that should be user dependent. - - Args: - path: Loads a benchmark specification from a JSON file. - """ - with fsspec.open(path, "r") as f: - data = json.load(f) - data.pop("cache_dir", None) - return cls.model_validate(data) - def to_json( self, destination: str, @@ -473,9 +282,6 @@ def cache(self, cache_dir: Optional[str] = None, verify_checksum: bool = True) - return self.cache_dir - def size(self): - return self.rows, self.n_columns - def _split_index_from_path(self, path: str) -> Tuple[str, Optional[int]]: """ Paths can have an additional index appended to them. 
@@ -494,58 +300,7 @@ def _split_index_from_path(self, path: str) -> Tuple[str, Optional[int]]: raise ValueError(f"Invalid index format: {index}") return path, index - def __getitem__(self, item): - """Allows for indexing the dataset directly""" - ret = self.table.loc[item] - if isinstance(ret, pd.Series): - # Load the data from the pointer columns - - if ret.name in self.table.columns: - # Returning a column, the indices are rows - if self.annotations[ret.name].is_pointer: - ret = np.array([self.get_data(k, ret.name) for k in ret.index]) - - elif len(ret) == self.n_rows: - # Returning a row, the indices are columns - ret = { - k: self.get_data(k, ret.name) if self.annotations[ret.name].is_pointer else ret[k] - for k in ret.index - } - - # Returning a dataframe - if isinstance(ret, pd.DataFrame): - for c in ret.columns: - if self.annotations[c].is_pointer: - ret[c] = [self.get_data(item, c) for item in ret.index] - return ret - - return ret - def _repr_dict_(self) -> dict: """Utility function for pretty-printing to the command line and jupyter notebooks""" repr_dict = self.model_dump(exclude={"table", "zarr_md5sum_manifest"}) return repr_dict - - def _repr_html_(self): - """For pretty-printing in Jupyter Notebooks""" - return dict2html(self._repr_dict_()) - - def __len__(self): - return self.n_rows - - def __repr__(self): - return json.dumps(self._repr_dict_(), indent=2) - - def __str__(self): - return self.__repr__() - - def __eq__(self, other): - """Whether two datasets are equal is solely determined by the checksum""" - if not isinstance(other, Dataset): - return False - return self.md5sum == other.md5sum - - def __del__(self): - """Close the connection of the client""" - if self._client is not None: - self._client.close() diff --git a/polaris/dataset/_factory.py b/polaris/dataset/_factory.py index c69cdba8..dfd550b0 100644 --- a/polaris/dataset/_factory.py +++ b/polaris/dataset/_factory.py @@ -6,12 +6,12 @@ import zarr from loguru import logger -from 
polaris.dataset import ColumnAnnotation, Dataset +from polaris.dataset import ColumnAnnotation, DatasetV1 from polaris.dataset._adapters import Adapter from polaris.dataset.converters import Converter, PDBConverter, SDFConverter, ZarrConverter -def create_dataset_from_file(path: str, zarr_root_path: Optional[str] = None) -> Dataset: +def create_dataset_from_file(path: str, zarr_root_path: Optional[str] = None) -> DatasetV1: """ This function is a convenience function to create a dataset from a file. @@ -29,7 +29,7 @@ def create_dataset_from_file(path: str, zarr_root_path: Optional[str] = None) -> def create_dataset_from_files( paths: List[str], zarr_root_path: Optional[str] = None, axis: Literal[0, 1, "index", "columns"] = 0 -) -> Dataset: +) -> DatasetV1: """ This function is a convenience function to create a dataset from multiple files. @@ -265,10 +265,10 @@ def add_from_files(self, paths: List[str], axis: Literal[0, 1, "index", "columns for path in paths: self.add_from_file(path) - def build(self) -> Dataset: + def build(self) -> DatasetV1: """Returns a Dataset based on the current state of the factory.""" zarr.consolidate_metadata(self.zarr_root.store) - return Dataset( + return DatasetV1( table=self._table, annotations=self._annotations, default_adapters=self._adapters, diff --git a/polaris/dataset/_subset.py b/polaris/dataset/_subset.py index 448abbff..9280e454 100644 --- a/polaris/dataset/_subset.py +++ b/polaris/dataset/_subset.py @@ -2,7 +2,7 @@ import numpy as np -from polaris.dataset import Dataset +from polaris.dataset import DatasetV1 from polaris.dataset._adapters import Adapter from polaris.utils.errors import TestAccessError from polaris.utils.types import DatapointType @@ -61,7 +61,7 @@ class Subset: def __init__( self, - dataset: Dataset, + dataset: DatasetV1, indices: List[Union[int, Sequence[int]]], input_cols: Union[List[str], str], target_cols: Union[List[str], str], diff --git a/polaris/experimental/__init__.py 
b/polaris/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/polaris/experimental/_dataset_v2.py b/polaris/experimental/_dataset_v2.py new file mode 100644 index 00000000..a7022581 --- /dev/null +++ b/polaris/experimental/_dataset_v2.py @@ -0,0 +1,5 @@ +from polaris.dataset._base import BaseDataset + + +class DatasetV2(BaseDataset): + """Dataset subclass for Polaris""" diff --git a/polaris/hub/client.py b/polaris/hub/client.py index dca7c239..468ed3b7 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -2,8 +2,7 @@ import ssl from hashlib import md5 from io import BytesIO -from typing import Callable, get_args -from typing import Union +from typing import Callable, Union, get_args from urllib.parse import urljoin import certifi @@ -23,14 +22,12 @@ MultiTaskBenchmarkSpecification, SingleTaskBenchmarkSpecification, ) -from polaris.dataset import Dataset -from polaris.evaluate import BenchmarkResults +from polaris.competition import CompetitionSpecification +from polaris.dataset import CompetitionDataset, DatasetV1 +from polaris.evaluate import BenchmarkResults, CompetitionResults from polaris.evaluate._results import CompetitionPredictions from polaris.hub.external_auth_client import ExternalAuthClient from polaris.hub.oauth import CachedTokenAuth -from polaris.dataset import CompetitionDataset -from polaris.evaluate import CompetitionResults -from polaris.competition import CompetitionSpecification from polaris.hub.polarisfs import PolarisFileSystem from polaris.hub.settings import PolarisHubSettings from polaris.utils.context import ProgressIndicator, tmp_attribute_change @@ -296,7 +293,7 @@ def get_dataset( owner: str | HubOwner, name: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr", - ) -> Dataset: + ) -> DatasetV1: """Load a standard dataset from the Polaris Hub. 
Args: @@ -316,7 +313,7 @@ def _get_dataset( name: str, artifact_type: ArtifactSubtype, verify_checksum: bool = True, - ) -> Dataset: + ) -> DatasetV1: """Loads either a standard or competition dataset from Polaris Hub Args: @@ -360,7 +357,7 @@ def _get_dataset( dataset = CompetitionDataset(**response) md5Sum = response["maskedMd5Sum"] else: - dataset = Dataset(**response) + dataset = DatasetV1(**response) md5Sum = response["md5Sum"] if should_verify_checksum(verify_checksum, dataset): @@ -535,7 +532,7 @@ def upload_results( def upload_dataset( self, - dataset: Dataset, + dataset: DatasetV1, access: AccessType = "private", timeout: TimeoutTypes = (10, 200), owner: HubOwner | str | None = None, @@ -548,7 +545,7 @@ def upload_dataset( def _upload_dataset( self, - dataset: Dataset, + dataset: DatasetV1, artifact_type: ArtifactSubtype, access: AccessType = "private", timeout: TimeoutTypes = (10, 200), diff --git a/polaris/loader/load.py b/polaris/loader/load.py index 6e152f68..797f7b78 100644 --- a/polaris/loader/load.py +++ b/polaris/loader/load.py @@ -7,13 +7,13 @@ MultiTaskBenchmarkSpecification, SingleTaskBenchmarkSpecification, ) -from polaris.dataset import Dataset, create_dataset_from_file +from polaris.dataset import DatasetV1, create_dataset_from_file from polaris.hub.client import PolarisHubClient from polaris.utils.misc import should_verify_checksum from polaris.utils.types import ChecksumStrategy -def load_dataset(path: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr") -> Dataset: +def load_dataset(path: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr") -> DatasetV1: """ Loads a Polaris dataset. 
@@ -41,7 +41,7 @@ def load_dataset(path: str, verify_checksum: ChecksumStrategy = "verify_unless_z # Load from local file if extension == "json": - dataset = Dataset.from_json(path) + dataset = DatasetV1.from_json(path) else: dataset = create_dataset_from_file(path) diff --git a/polaris/mixins/_checksum.py b/polaris/mixins/_checksum.py index 8fccac35..6991e7f5 100644 --- a/polaris/mixins/_checksum.py +++ b/polaris/mixins/_checksum.py @@ -4,8 +4,6 @@ from loguru import logger from pydantic import BaseModel, PrivateAttr, computed_field -from polaris.utils.errors import PolarisChecksumError - class ChecksumMixin(BaseModel, abc.ABC): """ @@ -66,6 +64,9 @@ def verify_checksum(self, md5sum: str | None = None): self.md5sum = self._compute_checksum() if self.md5sum != md5sum: + # Imported here to prevent circular import + from polaris.utils.errors import PolarisChecksumError + raise PolarisChecksumError( f"The specified checksum {md5sum} does not match the computed checksum {self.md5sum}" ) diff --git a/polaris/utils/misc.py b/polaris/utils/misc.py index 9a8199eb..2622fde4 100644 --- a/polaris/utils/misc.py +++ b/polaris/utils/misc.py @@ -3,7 +3,7 @@ from polaris.utils.types import ChecksumStrategy, SlugCompatibleStringType if TYPE_CHECKING: - from polaris.dataset import Dataset + from polaris.dataset import DatasetV1 def listit(t: Any): @@ -21,7 +21,7 @@ def sluggify(sluggable: SlugCompatibleStringType): return sluggable.lower().replace("_", "-") -def should_verify_checksum(strategy: ChecksumStrategy, dataset: "Dataset") -> bool: +def should_verify_checksum(strategy: ChecksumStrategy, dataset: "DatasetV1") -> bool: """ Determines whether a checksum should be verified. 
""" diff --git a/tests/conftest.py b/tests/conftest.py index d9190f58..dfefc00a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,7 @@ SingleTaskBenchmarkSpecification, ) from polaris.competition import CompetitionSpecification -from polaris.dataset import ColumnAnnotation, Dataset, CompetitionDataset +from polaris.dataset import ColumnAnnotation, CompetitionDataset, DatasetV1 from polaris.utils.types import HubOwner @@ -109,7 +109,7 @@ def test_user_owner(): @pytest.fixture(scope="function") def test_dataset(test_data, test_org_owner): - dataset = Dataset( + dataset = DatasetV1( table=test_data, name="test-dataset", source="https://www.example.com", diff --git a/tests/test_dataset.py b/tests/test_dataset.py index db4336e1..2628c609 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -6,7 +6,7 @@ import zarr from datamol.utils import fs -from polaris.dataset import Dataset, Subset, create_dataset_from_file +from polaris.dataset import DatasetV1, Subset, create_dataset_from_file from polaris.loader import load_dataset @@ -27,7 +27,7 @@ def test_load_data(tmp_path, with_slice, with_caching): path = "A#0:5" if with_slice else "A#0" table = pd.DataFrame({"A": [path]}, index=[0]) - dataset = Dataset(table=table, annotations={"A": {"is_pointer": True}}, zarr_root_path=zarr_path) + dataset = DatasetV1(table=table, annotations={"A": {"is_pointer": True}}, zarr_root_path=zarr_path) if with_caching: dataset.cache(fs.join(tmpdir, "cache")) @@ -56,28 +56,28 @@ def test_dataset_checksum(test_dataset): # Without any changes, same hash kwargs = test_dataset.model_dump() - assert Dataset(**kwargs) == test_dataset + assert DatasetV1(**kwargs) == test_dataset # With unimportant changes, same hash kwargs["name"] = "changed" kwargs["description"] = "changed" kwargs["source"] = "https://changed.com" - assert Dataset(**kwargs) == test_dataset + assert DatasetV1(**kwargs) == test_dataset # Check sensitivity to the row and column ordering kwargs["table"] = 
kwargs["table"].iloc[::-1] kwargs["table"] = kwargs["table"][kwargs["table"].columns[::-1]] - assert Dataset(**kwargs) == test_dataset + assert DatasetV1(**kwargs) == test_dataset # Without any changes, but different hash - dataset = Dataset(**kwargs) + dataset = DatasetV1(**kwargs) dataset._md5sum = "invalid" assert dataset != test_dataset # With changes, but same hash kwargs["md5sum"] = test_dataset.md5sum kwargs["table"] = kwargs["table"].iloc[:-1] - assert Dataset(**kwargs) != test_dataset + assert DatasetV1(**kwargs) != test_dataset def test_dataset_from_zarr(zarr_archive, tmpdir): @@ -97,7 +97,7 @@ def test_dataset_from_json(test_dataset, tmpdir): path = fs.join(str(tmpdir), "dataset.json") - new_dataset = Dataset.from_json(path) + new_dataset = DatasetV1.from_json(path) assert test_dataset == new_dataset new_dataset = load_dataset(path) @@ -117,7 +117,7 @@ def test_dataset_from_zarr_to_json_and_back(zarr_archive, tmpdir): dataset = create_dataset_from_file(archive, zarr_dir) path = dataset.to_json(json_dir) - new_dataset = Dataset.from_json(path) + new_dataset = DatasetV1.from_json(path) assert dataset == new_dataset new_dataset = load_dataset(path) @@ -140,7 +140,7 @@ def test_dataset_caching(zarr_archive, tmpdir): def test_dataset_index(): """Small test to check whether the dataset resets its index.""" df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=["X", "Y", "Z"]) - dataset = Dataset(table=df) + dataset = DatasetV1(table=df) subset = Subset(dataset=dataset, indices=[1], input_cols=["A"], target_cols=["B"]) assert next(iter(subset)) == (np.array([2]), np.array([5])) diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 6bb8be46..8440bdbb 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -1,17 +1,18 @@ import os -import pytest + import numpy as np import pandas as pd +import pytest import polaris as po from 
polaris.benchmark import ( MultiTaskBenchmarkSpecification, SingleTaskBenchmarkSpecification, ) +from polaris.dataset import DatasetV1 from polaris.evaluate._metric import Metric from polaris.evaluate._results import BenchmarkResults from polaris.utils.types import HubOwner -from polaris.dataset import Dataset def test_result_to_json(tmpdir: str, test_user_owner: HubOwner): @@ -150,7 +151,7 @@ def test_absolute_average_fold_error(): def test_metric_y_types( - tmpdir: str, test_single_task_benchmark_clf: SingleTaskBenchmarkSpecification, test_data: Dataset + tmpdir: str, test_single_task_benchmark_clf: SingleTaskBenchmarkSpecification, test_data: DatasetV1 ): # here we use train split for testing purpose. _, test = test_single_task_benchmark_clf.get_train_test_split() From 81cee1885d44f8ad5bf6c58aa6b8ed428c11e128 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Mon, 26 Aug 2024 20:06:45 -0400 Subject: [PATCH 02/20] Skeleton structure for tests and Dataset V2. Small changes to shared API --- polaris/dataset/_base.py | 10 ++-- polaris/dataset/_dataset.py | 38 ++++--------- polaris/experimental/_dataset_v2.py | 84 +++++++++++++++++++++++++++++ tests/conftest.py | 18 +++++++ tests/test_dataset_v2.py | 34 ++++++++++++ 5 files changed, 151 insertions(+), 33 deletions(-) diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index c13c9655..52be9c30 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -8,7 +8,6 @@ import numpy as np import pandas as pd import zarr -from datamol.utils import fs as dmfs from loguru import logger from pydantic import ( Field, @@ -302,6 +301,7 @@ def to_json( self, destination: str, if_exists: ZarrConflictResolution = "replace", + load_zarr_from_new_location: bool = False, ) -> str: """ Save the dataset to a destination directory as a JSON file. @@ -310,6 +310,8 @@ def to_json( destination: The _directory_ to save the associated data to. if_exists: Action for handling existing files in the Zarr archive. 
Options are 'raise' to throw an error, 'replace' to overwrite, or 'skip' to proceed without altering the existing files. + load_zarr_from_new_location: Whether to update the current instance to load data from the location + the data is saved to. Only relevant for Zarr-datasets. Returns: The path to the JSON file. @@ -331,11 +333,7 @@ def cache(self, cache_dir: Optional[str] = None, verify_checksum: bool = True) - if cache_dir is not None: self.cache_dir = cache_dir - self.to_json(self.cache_dir) - - if self.uses_zarr: - self.zarr_root_path = dmfs.join(self.cache_dir, "data.zarr") - self._zarr_root = None + self.to_json(self.cache_dir, load_zarr_from_new_location=True) if verify_checksum: self.verify_checksum() diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index c06145b0..ef5eae70 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -1,6 +1,6 @@ import json from hashlib import md5 -from typing import List, Optional, Tuple, Union +from typing import ClassVar, List, Literal, Optional, Tuple, Union import fsspec import numpy as np @@ -48,6 +48,8 @@ class DatasetV1(BaseDataset): # Data table: pd.DataFrame + version: ClassVar[Literal[1]] = 1 + @field_validator("table", mode="before") def _validate_table(cls, v): """ @@ -205,6 +207,7 @@ def to_json( self, destination: str, if_exists: ZarrConflictResolution = "replace", + load_zarr_from_new_location: bool = False, ) -> str: """ Save the dataset to a destination directory as a JSON file. @@ -222,6 +225,8 @@ def to_json( destination: The _directory_ to save the associated data to. if_exists: Action for handling existing files in the Zarr archive. Options are 'raise' to throw an error, 'replace' to overwrite, or 'skip' to proceed without altering the existing files. + load_zarr_from_new_location: Whether to update the current instance to load data from the location + the data is saved to. Only relevant for Zarr-datasets. Returns: The path to the JSON file. 
@@ -250,38 +255,17 @@ def to_json( if_exists=if_exists, ) + if load_zarr_from_new_location: + self.zarr_root_path = new_zarr_root_path + self._zarr_root = None + self._zarr_data = None + self.table.to_parquet(table_path) with fsspec.open(dataset_path, "w") as f: json.dump(serialized, f) return dataset_path - def cache(self, cache_dir: Optional[str] = None, verify_checksum: bool = True) -> str: - """Caches the dataset by downloading all additional data for pointer columns to a local directory. - - Args: - cache_dir: The directory to cache the data to. If not provided, - this will fall back to the `Dataset.cache_dir` attribute - verify_checksum: Whether to verify the checksum of the dataset after caching. - - Returns: - The path to the cache directory. - """ - - if cache_dir is not None: - self.cache_dir = cache_dir - - self.to_json(self.cache_dir) - - if self.uses_zarr: - self.zarr_root_path = dmfs.join(self.cache_dir, "data.zarr") - self._zarr_root = None - - if verify_checksum: - self.verify_checksum() - - return self.cache_dir - def _split_index_from_path(self, path: str) -> Tuple[str, Optional[int]]: """ Paths can have an additional index appended to them. 
diff --git a/polaris/experimental/_dataset_v2.py b/polaris/experimental/_dataset_v2.py index a7022581..0440f5fe 100644 --- a/polaris/experimental/_dataset_v2.py +++ b/polaris/experimental/_dataset_v2.py @@ -1,5 +1,89 @@ +from typing import ClassVar, List, Literal, Optional + +import numpy as np +from pydantic import computed_field + +from polaris.dataset._adapters import Adapter from polaris.dataset._base import BaseDataset +from polaris.dataset.zarr._checksum import ZarrFileChecksum +from polaris.utils.types import AccessType, HubOwner, ZarrConflictResolution class DatasetV2(BaseDataset): """Dataset subclass for Polaris""" + + version: ClassVar[Literal[2]] = 2 + + @property + def rows(self) -> list: + """Return all row indices for the dataset""" + raise NotImplementedError + + @property + def columns(self) -> list: + """Return all columns for the dataset""" + raise NotImplementedError + + @property + def dtypes(self) -> dict[str, np.dtype]: + """Return the dtype for each of the columns for the dataset""" + raise NotImplementedError + + @computed_field + @property + def zarr_md5sum_manifest(self) -> List[ZarrFileChecksum]: + """ + The Zarr Checksum manifest stores the checksums of all files in a Zarr archive. + If the dataset doesn't use Zarr, this will simply return an empty list. + """ + raise NotImplementedError + + def _compute_checksum(self) -> str: + """Compute the checksum of the dataset.""" + raise NotImplementedError + + def get_data(self, row: str | int, col: str, adapters: List[Adapter] | None = None) -> np.ndarray: + """Since the dataset might contain pointers to external files, data retrieval is more complicated + than just indexing the `table` attribute. This method provides an end-point for seamlessly + accessing the underlying data. + + Args: + row: The row index in the `Dataset.table` attribute + col: The column index in the `Dataset.table` attribute + adapters: The adapters to apply to the data before returning it. 
+ If None, will use the default adapters specified for the dataset. + + Returns: + A numpy array with the data at the specified indices. If the column is a pointer column, + the content of the referenced file is loaded to memory. + """ + raise NotImplementedError + + def upload_to_hub(self, access: Optional[AccessType] = "private", owner: HubOwner | str | None = None): + """Uploads the dataset to the Polaris Hub.""" + raise NotImplementedError + + def to_json( + self, + destination: str, + if_exists: ZarrConflictResolution = "replace", + load_zarr_from_new_location: bool = False, + ) -> str: + """ + Save the dataset to a destination directory as a JSON file. + + Args: + destination: The _directory_ to save the associated data to. + if_exists: Action for handling existing files in the Zarr archive. Options are 'raise' to throw + an error, 'replace' to overwrite, or 'skip' to proceed without altering the existing files. + load_zarr_from_new_location: Whether to update the current instance to load data from the location + the data is saved to. Only relevant for Zarr-datasets. + + Returns: + The path to the JSON file. 
+ """ + raise NotImplementedError + + def _repr_dict_(self) -> dict: + """Utility function for pretty-printing to the command line and jupyter notebooks""" + raise NotImplementedError diff --git a/tests/conftest.py b/tests/conftest.py index dfefc00a..c14b5188 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ ) from polaris.competition import CompetitionSpecification from polaris.dataset import ColumnAnnotation, CompetitionDataset, DatasetV1 +from polaris.experimental._dataset_v2 import DatasetV2 from polaris.utils.types import HubOwner @@ -124,6 +125,23 @@ def test_dataset(test_data, test_org_owner): return dataset +@pytest.fixture(scope="function") +def test_dataset_v2(zarr_archive, test_org_owner): + dataset = DatasetV2( + name="test-dataset-v2", + source="https://www.example.com", + annotations={"A": ColumnAnnotation(user_attributes={"unit": "kcal/mol"})}, + tags=["tagA", "tagB"], + user_attributes={"attributeA": "valueA", "attributeB": "valueB"}, + owner=test_org_owner, + license="CC-BY-4.0", + curation_reference="https://www.example.com", + zarr_root_path=zarr_archive, + ) + check_version(dataset) + return dataset + + @pytest.fixture(scope="function") def test_competition_dataset(test_data, test_org_owner): dataset = CompetitionDataset( diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index e69de29b..60e2c335 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -0,0 +1,34 @@ +def test_dataset_v2_get_columns(test_dataset_v2): + pass + + +def test_dataset_v2_get_rows(test_dataset_v2): + pass + + +def test_dataset_v2_get_data(test_dataset_v2): + pass + + +def test_dataset_v2_with_subset(test_dataset_v2): + pass + + +def test_dataset_v2_load_to_memory(test_dataset_v2): + pass + + +def test_dataset_v2_checksum(test_dataset_v2): + pass + + +def test_dataset_v2_serialization(test_dataset_v2): + pass + + +def test_dataset_v2_caching(test_dataset_v2): + pass + + +def 
test_dataset_v1_v2_compatibility(test_dataset_v2): + pass From 613dcb2ca198f98d82fe881b55368cf8fae62127 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Mon, 26 Aug 2024 20:33:12 -0400 Subject: [PATCH 03/20] Implemented the test cases Test-driven development! Yeah --- tests/conftest.py | 4 +- tests/test_dataset.py | 2 +- tests/test_dataset_v2.py | 126 ++++++++++++++++++++++++++++++++++----- 3 files changed, 113 insertions(+), 19 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c14b5188..a0e19c23 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -109,7 +109,7 @@ def test_user_owner(): @pytest.fixture(scope="function") -def test_dataset(test_data, test_org_owner): +def test_dataset(test_data, test_org_owner) -> DatasetV1: dataset = DatasetV1( table=test_data, name="test-dataset", @@ -126,7 +126,7 @@ def test_dataset(test_data, test_org_owner): @pytest.fixture(scope="function") -def test_dataset_v2(zarr_archive, test_org_owner): +def test_dataset_v2(zarr_archive, test_org_owner) -> DatasetV2: dataset = DatasetV2( name="test-dataset-v2", source="https://www.example.com", diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 2628c609..ff409eb1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -131,7 +131,7 @@ def test_dataset_caching(zarr_archive, tmpdir): cached_dataset = create_dataset_from_file(zarr_archive, tmpdir.join("original2")) assert original_dataset == cached_dataset - cache_dir = cached_dataset.cache(tmpdir.join("cached").strpath) + cache_dir = cached_dataset.cache(tmpdir.join("cached").strpath, verify_checksum=True) assert cached_dataset.zarr_root_path.startswith(cache_dir) assert cached_dataset == original_dataset diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index 60e2c335..b74fe829 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -1,34 +1,128 @@ +from time import perf_counter + +import numpy as np + +from polaris.dataset import Subset, zarr +from 
polaris.experimental._dataset_v2 import DatasetV2 + + def test_dataset_v2_get_columns(test_dataset_v2): - pass + assert set(test_dataset_v2.columns) == {"A", "B"} -def test_dataset_v2_get_rows(test_dataset_v2): - pass +def test_dataset_v2_get_rows(test_dataset_v2, zarr_archive): + assert set(test_dataset_v2.rows) == set(range(100)) -def test_dataset_v2_get_data(test_dataset_v2): - pass +def test_dataset_v2_get_data(test_dataset_v2, zarr_archive): + indices = np.random.randint(0, len(test_dataset_v2), 5) + for idx in indices: + assert np.array_equal(test_dataset_v2.get_data(row=idx, col="A"), zarr_archive["A"][idx]) + assert np.array_equal(test_dataset_v2.get_data(row=idx, col="B"), zarr_archive["B"][idx]) -def test_dataset_v2_with_subset(test_dataset_v2): - pass +def test_dataset_v2_with_subset(test_dataset_v2, zarr_archive): + indices = np.random.randint(0, len(test_dataset_v2), 5) + subset = Subset(test_dataset_v2, indices, "A", "B") + for i, (x, y) in enumerate(subset): + idx = indices[i] + assert np.array_equal(x, zarr_archive["A"][idx]) + assert np.array_equal(y, zarr_archive["B"][idx]) def test_dataset_v2_load_to_memory(test_dataset_v2): - pass + subset = Subset( + dataset=test_dataset_v2, + indices=range(100), + input_cols=["A"], + target_cols=["B"], + ) + + t1 = perf_counter() + for x in subset: + pass + d1 = perf_counter() - t1 + + test_dataset_v2.load_to_memory() + + t2 = perf_counter() + for x in subset: + pass + d2 = perf_counter() - t2 + + assert d2 < d1 + + +def test_dataset_v2_checksum(test_dataset_v2, tmpdir): + # Make sure the `md5sum` is part of the model dump even if not initiated yet. + # This is important for uploads to the Hub. 
+ assert test_dataset_v2._md5sum is None + assert "md5sum" in test_dataset_v2.model_dump() + + # (1) Without any changes, same hash + kwargs = test_dataset_v2.model_dump() + assert DatasetV2(**kwargs) == test_dataset_v2 + + # (2) With unimportant changes, same hash + kwargs["name"] = "changed" + kwargs["description"] = "changed" + kwargs["source"] = "https://changed.com" + assert DatasetV2(**kwargs) == test_dataset_v2 + + # (3) Without any changes, but different hash + dataset = DatasetV2(**kwargs) + dataset._md5sum = "invalid" + assert dataset != test_dataset_v2 + + # (4) With changes, but same hash + # Reset hash + kwargs["md5sum"] = test_dataset_v2.md5sum + + # Copy Zarr data to local + dataset = DatasetV2(**kwargs) + save_dir = tmpdir.join("save_dir") + dataset.to_json(save_dir, load_zarr_from_new_location=True) + + # Make changes to Zarr archive copy + root = zarr.open(dataset.zarr_root_path, "w") + root["A"][0] = np.zeros(2048) + + # Checksum should be different + assert dataset != test_dataset_v2 + + +def test_dataset_v2_serialization(test_dataset_v2, tmpdir): + save_dir = tmpdir.join("save_dir") + path = test_dataset_v2.to_json(save_dir) + new_dataset = DatasetV2.from_json(path) + assert test_dataset_v2 == new_dataset + +def test_dataset_v2_caching(test_dataset_v2, tmpdir): + cache_dir = tmpdir.join("cache").strpath + test_dataset_v2.cache(cache_dir, verify_checksum=True) + test_dataset_v2.zarr_root_path.startswith(cache_dir) -def test_dataset_v2_checksum(test_dataset_v2): - pass +def test_dataset_v1_v2_compatibility(test_dataset, tmpdir): + # A DataFrame is ultimately a collection of labeled numpy arrays + # We can thus also saved these same arrays to a Zarr archive + df = test_dataset.table -def test_dataset_v2_serialization(test_dataset_v2): - pass + path = tmpdir.join("data/v1v2.zarr") + root = zarr.open(path, "w") + for c in df.columns: + root.array(c, data=df[c].values) -def test_dataset_v2_caching(test_dataset_v2): - pass + kwargs = 
test_dataset.model_dump(exclude=["table", "zarr_root_path"]) + dataset = DatasetV2(**kwargs, zarr_root_path=path) + subset_1 = Subset(dataset=dataset, indices=range(100), input_cols=["smiles"], target_cols=["calc"]) + subset_2 = Subset(dataset=dataset, indices=range(100), input_cols=["smiles"], target_cols=["calc"]) -def test_dataset_v1_v2_compatibility(test_dataset_v2): - pass + for idx in range(5): + x1, y1 = subset_1[idx] + x2, y2 = subset_2[idx] + assert x1 == x2 + assert y1 == y2 From 27c73abf41bd9a7b0d07cfc47e3422b1e884f437 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Mon, 26 Aug 2024 21:29:49 -0400 Subject: [PATCH 04/20] Basic test cases passed Now the fun starts... --- polaris/benchmark/_base.py | 1 + polaris/dataset/_base.py | 10 ++-- polaris/dataset/_dataset.py | 15 ++++-- polaris/dataset/_subset.py | 2 +- polaris/dataset/zarr/_checksum.py | 4 +- polaris/experimental/_dataset_v2.py | 75 +++++++++++++++++++++++++---- tests/test_dataset_v2.py | 30 +++++++----- 7 files changed, 104 insertions(+), 33 deletions(-) diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index e2ec4902..241e8547 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -345,6 +345,7 @@ def n_classes(self) -> dict[str, int]: target_type = self.target_types[target] if target_type is None or target_type == TargetType.REGRESSION: continue + # TODO: Don't use table attribute n_classes[target] = self.dataset.table.loc[:, target].nunique() return n_classes diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index 52be9c30..da2de067 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -115,10 +115,12 @@ def _serialize_adapters(self, value: List[Adapter]): """Serializes the adapters""" return {k: v.name for k, v in value.items()} - @field_serializer("cache_dir") - def _serialize_cache_dir(value): - """Serialize the cacha_dir""" - return str(value) + @field_serializer("cache_dir", "zarr_root_path") + def 
_serialize_paths(value): + """Serialize the paths""" + if value is not None: + value = str(value) + return value @computed_field @property diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index ef5eae70..2e92a6e4 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -1,5 +1,6 @@ import json from hashlib import md5 +from pathlib import Path from typing import ClassVar, List, Literal, Optional, Tuple, Union import fsspec @@ -120,7 +121,7 @@ def _compute_checksum(self): # If the Zarr archive exists, we hash its contents too. if self.uses_zarr: zarr_hash, self._zarr_md5sum_manifest = compute_zarr_checksum(self.zarr_root_path) - hash_fn.update(zarr_hash.encode()) + hash_fn.update(zarr_hash.digest.encode()) checksum = hash_fn.hexdigest() return checksum @@ -231,10 +232,12 @@ def to_json( Returns: The path to the JSON file. """ - dmfs.mkdir(destination, exist_ok=True) - table_path = dmfs.join(destination, "table.parquet") - dataset_path = dmfs.join(destination, "dataset.json") - new_zarr_root_path = dmfs.join(destination, "data.zarr") + destination = Path(destination) + destination.mkdir(exist_ok=True, parents=True) + + table_path = str(destination / "table.parquet") + dataset_path = str(destination / "dataset.json") + new_zarr_root_path = str(destination / "data.zarr") # Lu: Avoid serilizing and sending None to hub app. serialized = self.model_dump(exclude={"cache_dir"}, exclude_none=True) @@ -242,6 +245,8 @@ def to_json( # Copy over Zarr data to the destination if self.uses_zarr: + serialized["zarrRootPath"] = new_zarr_root_path + self._warn_about_remote_zarr = False logger.info(f"Copying Zarr archive to {new_zarr_root_path}. This may take a while.") diff --git a/polaris/dataset/_subset.py b/polaris/dataset/_subset.py index 9280e454..7b11822f 100644 --- a/polaris/dataset/_subset.py +++ b/polaris/dataset/_subset.py @@ -80,7 +80,7 @@ def __init__( # NOTE (cwognum): Note to future self. 
As we're starting to think about competition-style benchmarks, # we will likely split up datasets. In that case, this default iloc_to_loc mapping won't work. # By that time, we should probably be able to overwrite this mapping. - self._iloc_to_loc = self.dataset.table.index + self._iloc_to_loc = self.dataset.rows # For the iterator implementation self._pointer = 0 diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py index a06f9491..8a1a8348 100644 --- a/polaris/dataset/zarr/_checksum.py +++ b/polaris/dataset/zarr/_checksum.py @@ -52,7 +52,7 @@ ZARR_DIGEST_PATTERN = "([0-9a-f]{32})-([0-9]+)-([0-9]+)" -def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, List["ZarrFileChecksum"]]: +def compute_zarr_checksum(zarr_root_path: str) -> Tuple["_ZarrDirectoryDigest", List["ZarrFileChecksum"]]: r""" Implements an algorithm to compute the Zarr checksum. @@ -145,7 +145,7 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, List["ZarrFileCheck zarr_md5sum_manifest.append(ZarrFileChecksum(path=str(relpath), md5sum=digest, size=size)) # Compute digest - return tree.process().digest, zarr_md5sum_manifest + return tree.process(), zarr_md5sum_manifest class ZarrFileChecksum(BaseModel): diff --git a/polaris/experimental/_dataset_v2.py b/polaris/experimental/_dataset_v2.py index 0440f5fe..df892e19 100644 --- a/polaris/experimental/_dataset_v2.py +++ b/polaris/experimental/_dataset_v2.py @@ -1,11 +1,16 @@ +import json +from pathlib import Path from typing import ClassVar, List, Literal, Optional +import fsspec import numpy as np +import zarr +from loguru import logger from pydantic import computed_field from polaris.dataset._adapters import Adapter from polaris.dataset._base import BaseDataset -from polaris.dataset.zarr._checksum import ZarrFileChecksum +from polaris.dataset.zarr._checksum import ZarrFileChecksum, compute_zarr_checksum from polaris.utils.types import AccessType, HubOwner, ZarrConflictResolution @@ -17,17 +22,17 @@ class 
DatasetV2(BaseDataset): @property def rows(self) -> list: """Return all row indices for the dataset""" - raise NotImplementedError + return np.arange(len(self.zarr_root[self.columns[0]])) @property def columns(self) -> list: """Return all columns for the dataset""" - raise NotImplementedError + return list(self.zarr_root.keys()) @property def dtypes(self) -> dict[str, np.dtype]: """Return the dtype for each of the columns for the dataset""" - raise NotImplementedError + return {col: self.zarr_root[col].dtype for col in self.columns} @computed_field @property @@ -36,13 +41,18 @@ def zarr_md5sum_manifest(self) -> List[ZarrFileChecksum]: The Zarr Checksum manifest stores the checksums of all files in a Zarr archive. If the dataset doesn't use Zarr, this will simply return an empty list. """ - raise NotImplementedError + if len(self._zarr_md5sum_manifest) == 0 and not self.has_md5sum: + # The manifest is set as an instance variable + # as a side-effect of the compute_checksum method + self.md5sum = self._compute_checksum() + return self._zarr_md5sum_manifest def _compute_checksum(self) -> str: """Compute the checksum of the dataset.""" - raise NotImplementedError + zarr_hash, self._zarr_md5sum_manifest = compute_zarr_checksum(self.zarr_root_path) + return zarr_hash.md5 - def get_data(self, row: str | int, col: str, adapters: List[Adapter] | None = None) -> np.ndarray: + def get_data(self, row: int, col: str, adapters: List[Adapter] | None = None) -> np.ndarray: """Since the dataset might contain pointers to external files, data retrieval is more complicated than just indexing the `table` attribute. This method provides an end-point for seamlessly accessing the underlying data. @@ -57,10 +67,24 @@ def get_data(self, row: str | int, col: str, adapters: List[Adapter] | None = No A numpy array with the data at the specified indices. If the column is a pointer column, the content of the referenced file is loaded to memory. 
""" - raise NotImplementedError + # Fetch adapters for dataset and a given column + adapters = adapters or self.default_adapters + adapter = adapters.get(col) + + # Get the data + arr = self.zarr_root[col][row] + + # Adapt the input to the specified format + if adapter is not None: + arr = adapter(arr) + + return arr def upload_to_hub(self, access: Optional[AccessType] = "private", owner: HubOwner | str | None = None): """Uploads the dataset to the Polaris Hub.""" + + # NOTE (cwognum): Leaving this for a later PR, because I want + # to do it simultaneously with a PR on the Hub side. raise NotImplementedError def to_json( @@ -82,8 +106,39 @@ def to_json( Returns: The path to the JSON file. """ - raise NotImplementedError + destination = Path(destination) + destination.mkdir(exist_ok=True, parents=True) + + dataset_path = destination / "dataset.json" + new_zarr_root_path = destination / "data.zarr" + + # Lu: Avoid serilizing and sending None to hub app. + serialized = self.model_dump(exclude={"cache_dir"}, exclude_none=True) + serialized["zarrRootPath"] = str(new_zarr_root_path) + + # Copy over Zarr data to the destination + self._warn_about_remote_zarr = False + + logger.info(f"Copying Zarr archive to {new_zarr_root_path}. 
This may take a while.") + dest = zarr.open(new_zarr_root_path, "w") + + zarr.copy_store( + source=self.zarr_root.store.store, + dest=dest.store, + log=logger.debug, + if_exists=if_exists, + ) + + if load_zarr_from_new_location: + self.zarr_root_path = new_zarr_root_path + self._zarr_root = None + self._zarr_data = None + + with fsspec.open(dataset_path, "w") as f: + json.dump(serialized, f) + return str(dataset_path) def _repr_dict_(self) -> dict: """Utility function for pretty-printing to the command line and jupyter notebooks""" - raise NotImplementedError + repr_dict = self.model_dump(exclude={"zarr_md5sum_manifest"}) + return repr_dict diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index b74fe829..a5a242bb 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -1,9 +1,13 @@ from time import perf_counter +import numcodecs import numpy as np +import pytest +import zarr -from polaris.dataset import Subset, zarr +from polaris.dataset import Subset from polaris.experimental._dataset_v2 import DatasetV2 +from polaris.utils.errors import PolarisChecksumError def test_dataset_v2_get_columns(test_dataset_v2): @@ -15,19 +19,21 @@ def test_dataset_v2_get_rows(test_dataset_v2, zarr_archive): def test_dataset_v2_get_data(test_dataset_v2, zarr_archive): + root = zarr.open(zarr_archive, "r") indices = np.random.randint(0, len(test_dataset_v2), 5) for idx in indices: - assert np.array_equal(test_dataset_v2.get_data(row=idx, col="A"), zarr_archive["A"][idx]) - assert np.array_equal(test_dataset_v2.get_data(row=idx, col="B"), zarr_archive["B"][idx]) + assert np.array_equal(test_dataset_v2.get_data(row=idx, col="A"), root["A"][idx]) + assert np.array_equal(test_dataset_v2.get_data(row=idx, col="B"), root["B"][idx]) def test_dataset_v2_with_subset(test_dataset_v2, zarr_archive): + root = zarr.open(zarr_archive, "r") indices = np.random.randint(0, len(test_dataset_v2), 5) subset = Subset(test_dataset_v2, indices, "A", "B") for i, (x, y) in 
enumerate(subset): idx = indices[i] - assert np.array_equal(x, zarr_archive["A"][idx]) - assert np.array_equal(y, zarr_archive["B"][idx]) + assert np.array_equal(x, root["A"][idx]) + assert np.array_equal(y, root["B"][idx]) def test_dataset_v2_load_to_memory(test_dataset_v2): @@ -84,11 +90,12 @@ def test_dataset_v2_checksum(test_dataset_v2, tmpdir): dataset.to_json(save_dir, load_zarr_from_new_location=True) # Make changes to Zarr archive copy - root = zarr.open(dataset.zarr_root_path, "w") + root = zarr.open(dataset.zarr_root_path, "a") root["A"][0] = np.zeros(2048) # Checksum should be different - assert dataset != test_dataset_v2 + with pytest.raises(PolarisChecksumError): + dataset.verify_checksum() def test_dataset_v2_serialization(test_dataset_v2, tmpdir): @@ -101,7 +108,7 @@ def test_dataset_v2_serialization(test_dataset_v2, tmpdir): def test_dataset_v2_caching(test_dataset_v2, tmpdir): cache_dir = tmpdir.join("cache").strpath test_dataset_v2.cache(cache_dir, verify_checksum=True) - test_dataset_v2.zarr_root_path.startswith(cache_dir) + assert str(test_dataset_v2.zarr_root_path).startswith(cache_dir) def test_dataset_v1_v2_compatibility(test_dataset, tmpdir): @@ -112,11 +119,12 @@ def test_dataset_v1_v2_compatibility(test_dataset, tmpdir): path = tmpdir.join("data/v1v2.zarr") root = zarr.open(path, "w") - for c in df.columns: - root.array(c, data=df[c].values) + root.array("smiles", data=df["smiles"].values, dtype=object, object_codec=numcodecs.VLenUTF8()) + root.array("calc", data=df["calc"].values) + zarr.consolidate_metadata(path) kwargs = test_dataset.model_dump(exclude=["table", "zarr_root_path"]) - dataset = DatasetV2(**kwargs, zarr_root_path=path) + dataset = DatasetV2(**kwargs, zarr_root_path=str(path)) subset_1 = Subset(dataset=dataset, indices=range(100), input_cols=["smiles"], target_cols=["calc"]) subset_2 = Subset(dataset=dataset, indices=range(100), input_cols=["smiles"], target_cols=["calc"]) From 648421644bc668ed39158c79aa8224cb9bb4e004 Mon 
Sep 17 00:00:00 2001 From: Cas Wognum Date: Tue, 27 Aug 2024 16:42:48 -0600 Subject: [PATCH 05/20] Added additional validation --- polaris/benchmark/_base.py | 1 - polaris/dataset/_column.py | 10 +-- polaris/dataset/_dataset.py | 12 ++- polaris/experimental/_dataset_v2.py | 95 ++++++++++++++++++-- tests/test_dataset_v2.py | 132 +++++++++++++++++++++++++++- 5 files changed, 230 insertions(+), 20 deletions(-) diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index 241e8547..7bbefebd 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -111,7 +111,6 @@ class BenchmarkSpecification(BaseArtifactModel, ChecksumMixin): def _validate_dataset(cls, v): """ Allows either passing a Dataset object or the kwargs to create one - TODO (cwognum): Allow multiple datasets to be used as part of a benchmark """ if isinstance(v, dict): v = DatasetV1(**v) diff --git a/polaris/dataset/_column.py b/polaris/dataset/_column.py index d533a0fc..2eb5be73 100644 --- a/polaris/dataset/_column.py +++ b/polaris/dataset/_column.py @@ -42,22 +42,22 @@ class ColumnAnnotation(BaseModel): """ is_pointer: bool = False - modality: Union[str, Modality] = Modality.UNKNOWN + modality: Modality = Modality.UNKNOWN description: Optional[str] = None user_attributes: Dict[str, str] = Field(default_factory=dict) - dtype: Union[np.dtype, str, None] = None + dtype: np.dtype | None = None content_type: Union[KnownContentType, str, None] = None model_config = ConfigDict(arbitrary_types_allowed=True, alias_generator=to_camel, populate_by_name=True) - @field_validator("modality") - def _validate_modality(cls, v, values): + @field_validator("modality", mode="before") + def _validate_modality(cls, v): """Tries to convert a string to the Enum""" if isinstance(v, str): v = Modality[v.upper()] return v - @field_validator("dtype") + @field_validator("dtype", mode="before") def _validate_dtype(cls, v): """Tries to convert a string to the Enum""" if isinstance(v, str): diff --git 
a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 2e92a6e4..ff683779 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -76,13 +76,21 @@ def _validate_table(cls, v): def _validate_model(cls, m: "DatasetV1"): """Verifies some dependencies between properties""" + # NOTE (cwognum): A good chunk of the below code is shared with the DatasetV2 class. + # I tried moving it to the BaseDataset class, but I'm not understanding Pydantic's behavior very well. + # It seems to not always trigger when part of the base class. + # Verify that all annotations are for columns that exist if any(k not in m.columns for k in m.annotations): - raise InvalidDatasetError("There are annotations for columns that do not exist") + raise InvalidDatasetError( + f"There are annotations for columns that do not exist. Columns: {m.columns}. Annotations: {list(m.annotations.keys())}" + ) # Verify that all adapters are for columns that exist if any(k not in m.columns for k in m.default_adapters.keys()): - raise InvalidDatasetError("There are default adapters for columns that do not exist") + raise InvalidDatasetError( + f"There are default adapters for columns that do not exist. Columns: {m.columns}. 
Adapters: {list(m.annotations.keys())}" + ) # Set a default for missing annotations and convert strings to Modality for c in m.columns: diff --git a/polaris/experimental/_dataset_v2.py b/polaris/experimental/_dataset_v2.py index df892e19..7d1c22d1 100644 --- a/polaris/experimental/_dataset_v2.py +++ b/polaris/experimental/_dataset_v2.py @@ -6,23 +6,93 @@ import numpy as np import zarr from loguru import logger -from pydantic import computed_field +from pydantic import computed_field, model_validator from polaris.dataset._adapters import Adapter from polaris.dataset._base import BaseDataset +from polaris.dataset._column import ColumnAnnotation from polaris.dataset.zarr._checksum import ZarrFileChecksum, compute_zarr_checksum +from polaris.utils.errors import InvalidDatasetError from polaris.utils.types import AccessType, HubOwner, ZarrConflictResolution +_INDEX_ARRAY_KEY = "__index__" + class DatasetV2(BaseDataset): """Dataset subclass for Polaris""" version: ClassVar[Literal[2]] = 2 + # Redefine this to make it a required field + zarr_root_path: str + + @model_validator(mode="after") + @classmethod + def _validate_model(cls, m: "DatasetV2"): + """Verifies some dependencies between properties""" + + # NOTE (cwognum): A good chunk of the below code is shared with the DatasetV1 class. + # I tried moving it to the BaseDataset class, but I'm not understanding Pydantic's behavior very well. + # It seems to not always trigger when part of the base class. + + # Verify that all annotations are for columns that exist + if any(k not in m.columns for k in m.annotations): + raise InvalidDatasetError( + f"There are annotations for columns that do not exist. Columns: {m.columns}. Annotations: {list(m.annotations.keys())}" + ) + + # Verify that all adapters are for columns that exist + if any(k not in m.columns for k in m.default_adapters.keys()): + raise InvalidDatasetError( + f"There are default adapters for columns that do not exist. Columns: {m.columns}. 
Adapters: {list(m.annotations.keys())}" + ) + + # Set a default for missing annotations and convert strings to Modality + for c in m.columns: + if c not in m.annotations: + m.annotations[c] = ColumnAnnotation() + if m.annotations[c].is_pointer: + raise InvalidDatasetError("Pointer columns are not supported in DatasetV2") + m.annotations[c].dtype = m.dtypes[c] + + # Since the keys for subgroups are not ordered, we have no easy way to index these groups. + # Any subgroup should therefore have a special array that defines the index for that group. + for group in m.zarr_root.group_keys(): + if _INDEX_ARRAY_KEY not in m.zarr_root[group].array_keys(): + raise InvalidDatasetError(f"Group {group} does not have an index array.") + + index_arr = m.zarr_root[group][_INDEX_ARRAY_KEY] + if len(index_arr) != len(m.zarr_root[group]) - 1: + raise InvalidDatasetError( + f"Length of index array for group {group} does not match the size of the group." + ) + if any(x not in m.zarr_root[group] for x in index_arr): + raise InvalidDatasetError( + f"Keys of index array for group {group} does not match the group members." + ) + + # Check the structure of the Zarr archive + # All arrays or groups in the root should have the same length. 
+ lengths = {len(m.zarr_root[k]) for k in m.zarr_root.array_keys()} + lengths.update({len(m.zarr_root[k][_INDEX_ARRAY_KEY]) for k in m.zarr_root.group_keys()}) + if len(lengths) > 1: + raise InvalidDatasetError( + f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}" + ) + return m + + @property + def n_rows(self) -> list: + """Return all row indices for the dataset""" + example = self.zarr_root[self.columns[0]] + if isinstance(example, zarr.Group): + return len(example[_INDEX_ARRAY_KEY]) + return len(example) + @property def rows(self) -> list: """Return all row indices for the dataset""" - return np.arange(len(self.zarr_root[self.columns[0]])) + return np.arange(len(self)) @property def columns(self) -> list: @@ -32,7 +102,12 @@ def columns(self) -> list: @property def dtypes(self) -> dict[str, np.dtype]: """Return the dtype for each of the columns for the dataset""" - return {col: self.zarr_root[col].dtype for col in self.columns} + dtypes = {} + for arr in self.zarr_root.array_keys(): + dtypes[arr] = self.zarr_root[arr].dtype + for group in self.zarr_root.group_keys(): + dtypes[group] = np.dtype(object) + return dtypes @computed_field @property @@ -72,7 +147,11 @@ def get_data(self, row: int, col: str, adapters: List[Adapter] | None = None) -> adapter = adapters.get(col) # Get the data - arr = self.zarr_root[col][row] + group_or_array = self.zarr_data[col] + + if isinstance(group_or_array, zarr.Group): + row = group_or_array[_INDEX_ARRAY_KEY][row] + arr = group_or_array[row] # Adapt the input to the specified format if adapter is not None: @@ -109,12 +188,12 @@ def to_json( destination = Path(destination) destination.mkdir(exist_ok=True, parents=True) - dataset_path = destination / "dataset.json" - new_zarr_root_path = destination / "data.zarr" + dataset_path = str(destination / "dataset.json") + new_zarr_root_path = str(destination / "data.zarr") # Lu: Avoid serilizing and sending None to hub app. 
serialized = self.model_dump(exclude={"cache_dir"}, exclude_none=True) - serialized["zarrRootPath"] = str(new_zarr_root_path) + serialized["zarrRootPath"] = new_zarr_root_path # Copy over Zarr data to the destination self._warn_about_remote_zarr = False @@ -136,7 +215,7 @@ def to_json( with fsspec.open(dataset_path, "w") as f: json.dump(serialized, f) - return str(dataset_path) + return dataset_path def _repr_dict_(self) -> dict: """Utility function for pretty-printing to the command line and jupyter notebooks""" diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index a5a242bb..45b9753b 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -1,12 +1,16 @@ +from copy import deepcopy from time import perf_counter import numcodecs import numpy as np import pytest import zarr +from pydantic import ValidationError from polaris.dataset import Subset -from polaris.experimental._dataset_v2 import DatasetV2 +from polaris.dataset._factory import DatasetFactory +from polaris.dataset.converters._pdb import PDBConverter +from polaris.experimental._dataset_v2 import _INDEX_ARRAY_KEY, DatasetV2 from polaris.utils.errors import PolarisChecksumError @@ -120,17 +124,137 @@ def test_dataset_v1_v2_compatibility(test_dataset, tmpdir): root = zarr.open(path, "w") root.array("smiles", data=df["smiles"].values, dtype=object, object_codec=numcodecs.VLenUTF8()) - root.array("calc", data=df["calc"].values) + root.array("iupac", data=df["iupac"].values, dtype=object, object_codec=numcodecs.VLenUTF8()) + for col in set(df.columns) - {"smiles", "iupac"}: + root.array(col, data=df[col].values) zarr.consolidate_metadata(path) kwargs = test_dataset.model_dump(exclude=["table", "zarr_root_path"]) dataset = DatasetV2(**kwargs, zarr_root_path=str(path)) - subset_1 = Subset(dataset=dataset, indices=range(100), input_cols=["smiles"], target_cols=["calc"]) - subset_2 = Subset(dataset=dataset, indices=range(100), input_cols=["smiles"], target_cols=["calc"]) + subset_1 = 
Subset(dataset=test_dataset, indices=range(5), input_cols=["smiles"], target_cols=["calc"]) + subset_2 = Subset(dataset=dataset, indices=range(5), input_cols=["smiles"], target_cols=["calc"]) for idx in range(5): x1, y1 = subset_1[idx] x2, y2 = subset_2[idx] assert x1 == x2 assert y1 == y2 + + +def test_dataset_v2_with_pdbs(pdb_paths, tmpdir): + # The PDB example is interesting because it creates a more complex Zarr archive + # that includes subgroups + zarr_root_path = str(tmpdir.join("pdbs.zarr")) + factory = DatasetFactory(zarr_root_path) + + # Build a V1 dataset + converter = PDBConverter() + factory.register_converter("pdb", converter) + factory.add_from_files(pdb_paths, axis=0) + dataset_v1 = factory.build() + + # Build a V2 dataset based on the V1 dataset + + # Add the magic index column to the Zarr subgroup + root = zarr.open(zarr_root_path, "a") + ordered_keys = [v.split("/")[-1] for v in dataset_v1.table["pdb"].values] + root["pdb"].array(_INDEX_ARRAY_KEY, data=ordered_keys, dtype=object, object_codec=numcodecs.VLenUTF8()) + zarr.consolidate_metadata(zarr_root_path) + + # Update annotations to no longer have pointer columns + annotations = deepcopy(dataset_v1.annotations) + for anno in annotations.values(): + anno.is_pointer = False + + # Create the dataset + dataset_v2 = DatasetV2( + zarr_root_path=zarr_root_path, + annotations=annotations, + default_adapters=dataset_v1.default_adapters, + ) + + assert len(dataset_v1) == len(dataset_v2) + for idx in range(len(dataset_v1)): + pdb_1 = dataset_v1.get_data(idx, "pdb") + pdb_2 = dataset_v2.get_data(idx, "pdb") + assert pdb_1 == pdb_2 + + +def test_dataset_v2_indexing(zarr_archive): + # Create a subgroup with 100 arrays + root = zarr.open(zarr_archive, "a") + subgroup = root.create_group("X") + for i in range(100): + subgroup.array(f"{i}", data=np.array([i] * 100)) + + # Index it in reverse (element 0 is the last element in the array) + indices = [f"{idx}" for idx in range(100)][::-1] + 
subgroup.array(_INDEX_ARRAY_KEY, data=indices, dtype=object, object_codec=numcodecs.VLenUTF8()) + zarr.consolidate_metadata(zarr_archive) + + # Create the dataset + dataset = DatasetV2(zarr_root_path=zarr_archive) + + # Check that the dataset is indexed correctly + for idx in range(100): + assert np.array_equal(dataset.get_data(row=idx, col="X"), np.array([99 - idx] * 100)) + + +def test_dataset_v2_validation_index_array(zarr_archive): + root = zarr.open(zarr_archive, "a") + + # Create subgroup that lacks the index array + subgroup = root.create_group("X") + zarr.consolidate_metadata(zarr_archive) + + with pytest.raises(ValidationError, match="does not have an index array"): + DatasetV2(zarr_root_path=zarr_archive) + + indices = [f"{idx}" for idx in range(100)] + indices[-1] = "invalid" + + # Create index array that doesn't match group length (zero arrays in group, but 100 indices) + subgroup.array(_INDEX_ARRAY_KEY, data=indices, dtype=object, object_codec=numcodecs.VLenUTF8()) + zarr.consolidate_metadata(zarr_archive) + + with pytest.raises(ValidationError, match="Length of index array"): + DatasetV2(zarr_root_path=zarr_archive) + + for i in range(100): + subgroup.array(f"{i}", data=np.random.random(100)) + zarr.consolidate_metadata(zarr_archive) + + # Create index array that has invalid keys (last keys = 'invalid' rather than '99') + with pytest.raises(ValidationError, match="Keys of index array"): + DatasetV2(zarr_root_path=zarr_archive) + + +def test_dataset_v2_validation_consistent_lengths(zarr_archive): + root = zarr.open(zarr_archive, "a") + + # Change the length of one of the arrays + root["A"].append(np.random.random((1, 2048))) + zarr.consolidate_metadata(zarr_archive) + + # Subgroup has a false number of indices + with pytest.raises(ValidationError, match="should have the same length"): + DatasetV2(zarr_root_path=zarr_archive) + + # Make the length of the two arrays equal again + # shouldn't crash + root["B"].append(np.random.random((1, 2048))) + 
zarr.consolidate_metadata(zarr_archive) + DatasetV2(zarr_root_path=zarr_archive) + + # Create subgroup with inconsistent length + subgroup = root.create_group("X") + for i in range(100): + subgroup.array(f"{i}", data=np.random.random(100)) + indices = [f"{idx}" for idx in range(100)] + subgroup.array(_INDEX_ARRAY_KEY, data=indices, dtype=object, object_codec=numcodecs.VLenUTF8()) + zarr.consolidate_metadata(zarr_archive) + + # Subgroup has a false number of indices + with pytest.raises(ValidationError, match="should have the same length"): + DatasetV2(zarr_root_path=zarr_archive) From df33bfc50a39edaeba0ec131f03ae0183b500a1b Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Tue, 27 Aug 2024 17:00:44 -0600 Subject: [PATCH 06/20] Improved docs --- polaris/dataset/_base.py | 5 ---- polaris/dataset/_dataset.py | 8 +++---- polaris/experimental/_dataset_v2.py | 37 +++++++++++++++++++++++------ 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index da2de067..e4e8708b 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -47,11 +47,6 @@ class BaseDataset(BaseArtifactModel, ChecksumMixin, abc.ABC): A Dataset can have multiple modalities or targets, can be sparse and can be part of one or multiple [`BenchmarkSpecification`][polaris.benchmark.BenchmarkSpecification] objects. - Info: Pointer columns - Whereas a `Dataset` contains all information required to construct a dataset, it is not ready yet. - For complex data, such as images, we support storing the content in external blobs of data. - In that case, the table contains _pointers_ to these blobs that are dynamically loaded when needed. - Attributes: default_adapters: The adapters that the Dataset recommends to use by default to change the format of the data for specific columns. 
diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index ff683779..f5e53ff7 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -24,14 +24,12 @@ class DatasetV1(BaseDataset): - """A Polaris Dataset, implemented as a [Pydantic](https://docs.pydantic.dev/latest/) model. + """First version of a Polaris Dataset. - At its core, a dataset in Polaris is a tabular data structure that stores data-points in a row-wise manner. - A Dataset can have multiple modalities or targets, can be sparse and can be part of one or multiple - [`BenchmarkSpecification`][polaris.benchmark.BenchmarkSpecification] objects. + Stores datapoints in a Pandas DataFrame and implements _pointer columns_ to support the storage of XXL data + outside of the DataFrame in a Zarr archive. Info: Pointer columns - Whereas a `Dataset` contains all information required to construct a dataset, it is not ready yet. For complex data, such as images, we support storing the content in external blobs of data. In that case, the table contains _pointers_ to these blobs that are dynamically loaded when needed. diff --git a/polaris/experimental/_dataset_v2.py b/polaris/experimental/_dataset_v2.py index 7d1c22d1..4fdda101 100644 --- a/polaris/experimental/_dataset_v2.py +++ b/polaris/experimental/_dataset_v2.py @@ -19,7 +19,25 @@ class DatasetV2(BaseDataset): - """Dataset subclass for Polaris""" + """Second version of a Polaris Dataset. + + This version gets rid of the DataFrame and stores all data in a Zarr archive. + + V1 stored all datapoints in a Pandas DataFrame. Because a DataFrame is always loaded to memory, + this was a bottleneck when the number of data points grew large. Even with the pointer columns, you still + need to load all pointers into memory. V2 therefore switches to a Zarr-only format. + + Info: This feature is still experimental + The DatasetV2 is in active development and will likely undergo breaking changes before release. 
+ + Attributes: + zarr_root_path: The path to the Zarr archive. Different from V1, this is now required. + + For additional meta-data attributes, see the [`BaseDataset`][polaris._dataset.BaseDataset] class. + + Raises: + InvalidDatasetError: If the dataset does not conform to the Pydantic data-model specification. + """ version: ClassVar[Literal[2]] = 2 @@ -91,7 +109,12 @@ def n_rows(self) -> list: @property def rows(self) -> list: - """Return all row indices for the dataset""" + """Return all row indices for the dataset + + Warning: Memory consumption + This feature is added for completeness sake, but when datasets get large could consume a lot of memory. + E.g. storing a billion indices with np.in64 would consume 8GB of memory. Use with caution. + """ return np.arange(len(self)) @property @@ -128,13 +151,11 @@ def _compute_checksum(self) -> str: return zarr_hash.md5 def get_data(self, row: int, col: str, adapters: List[Adapter] | None = None) -> np.ndarray: - """Since the dataset might contain pointers to external files, data retrieval is more complicated - than just indexing the `table` attribute. This method provides an end-point for seamlessly - accessing the underlying data. + """Indexes the Zarr archive. Args: - row: The row index in the `Dataset.table` attribute - col: The column index in the `Dataset.table` attribute + row: The index of the data to fetch. + col: The label of a direct child of the Zarr root. adapters: The adapters to apply to the data before returning it. If None, will use the default adapters specified for the dataset. @@ -149,6 +170,8 @@ def get_data(self, row: int, col: str, adapters: List[Adapter] | None = None) -> # Get the data group_or_array = self.zarr_data[col] + # If it is a group, there is no deterministic order for the child keys. + # We therefore use a special array that defines the index. 
if isinstance(group_or_array, zarr.Group): row = group_or_array[_INDEX_ARRAY_KEY][row] arr = group_or_array[row] From 7d4b718144e86e2080d9f987127945efbe13f98b Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Tue, 27 Aug 2024 17:13:23 -0600 Subject: [PATCH 07/20] Fixed some reference errors in the docs --- docs/api/dataset.md | 6 ++++++ polaris/dataset/__init__.py | 4 +--- polaris/dataset/_dataset.py | 5 ++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/docs/api/dataset.md b/docs/api/dataset.md index ec1087e6..2b3cb7c4 100644 --- a/docs/api/dataset.md +++ b/docs/api/dataset.md @@ -2,6 +2,12 @@ options: filters: ["!^_"] +--- + +::: polaris.dataset._base.BaseDataset + options: + filters: ["!^_"] + --- ::: polaris.dataset.ColumnAnnotation diff --git a/polaris/dataset/__init__.py b/polaris/dataset/__init__.py index ab059de8..d215072d 100644 --- a/polaris/dataset/__init__.py +++ b/polaris/dataset/__init__.py @@ -1,11 +1,9 @@ from polaris.dataset._column import ColumnAnnotation, KnownContentType, Modality from polaris.dataset._competition_dataset import CompetitionDataset -from polaris.dataset._dataset import DatasetV1 +from polaris.dataset._dataset import DatasetV1 as Dataset from polaris.dataset._factory import DatasetFactory, create_dataset_from_file, create_dataset_from_files from polaris.dataset._subset import Subset -Dataset = DatasetV1 - __all__ = [ "ColumnAnnotation", "Dataset", diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index f5e53ff7..d81b4d79 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -37,7 +37,7 @@ class DatasetV1(BaseDataset): table: The core data-structure, storing data-points in a row-wise manner. Can be specified as either a path to a `.parquet` file or a `pandas.DataFrame`. - For additional meta-data attributes, see the [`BaseDataset`][polaris._dataset.BaseDataset] class. + For additional meta-data attributes, see the [`BaseDataset`][polaris.dataset._base.BaseDataset] class. 
Raises: InvalidDatasetError: If the dataset does not conform to the Pydantic data-model specification. @@ -222,8 +222,7 @@ def to_json( Warning: Multiple files Perhaps unintuitive, this method creates multiple files. - 1. `/path/to/destination/dataset.json`: This file can be loaded with - [`Dataset.from_json`][polaris.dataset.Dataset.from_json]. + 1. `/path/to/destination/dataset.json`: This file can be loaded with `Dataset.from_json`. 2. `/path/to/destination/table.parquet`: The `Dataset.table` attribute is saved here. 3. _(Optional)_ `/path/to/destination/data/*`: Any additional blobs of data referenced by the pointer columns will be stored here. From 0484d68aabc0a0586c20f68feb8e96ff032086d8 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Tue, 27 Aug 2024 17:32:15 -0600 Subject: [PATCH 08/20] Disable use of iloc to loc mapping for Dataset V2 --- polaris/dataset/_subset.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/polaris/dataset/_subset.py b/polaris/dataset/_subset.py index 7b11822f..00671e8c 100644 --- a/polaris/dataset/_subset.py +++ b/polaris/dataset/_subset.py @@ -62,11 +62,11 @@ class Subset: def __init__( self, dataset: DatasetV1, - indices: List[Union[int, Sequence[int]]], - input_cols: Union[List[str], str], - target_cols: Union[List[str], str], - adapters: Optional[List[Adapter]] = None, - featurization_fn: Optional[Callable] = None, + indices: List[int | Sequence[int]], + input_cols: List[str] | str, + target_cols: List[str] | str, + adapters: List[Adapter] | None = None, + featurization_fn: Callable | None = None, hide_targets: bool = False, ): self.dataset = dataset @@ -77,10 +77,9 @@ def __init__( self._adapters = adapters self._featurization_fn = featurization_fn - # NOTE (cwognum): Note to future self. As we're starting to think about competition-style benchmarks, - # we will likely split up datasets. In that case, this default iloc_to_loc mapping won't work. 
- # By that time, we should probably be able to overwrite this mapping. - self._iloc_to_loc = self.dataset.rows + # Storing all indices in memory can be memory consuming for XXL datasets. + # This is why we constrain the iloc to loc mapping to be the identity function for Dataset V2. + self._iloc_to_loc = self.dataset.rows if isinstance(self.dataset, DatasetV1) else None # For the iterator implementation self._pointer = 0 @@ -167,9 +166,15 @@ def as_array(self, data_type: Union[Literal["x"], Literal["y"], Literal["xy"]]): # We reset the index of the Pandas Table during Dataset class validation. # We can thus always assume that .iloc[idx] is the same as .loc[idx]. if data_type == "x": - ret = [self._get_single_input(self._iloc_to_loc[idx]) for idx in self.indices] + ret = [ + self._get_single_input(idx if self._iloc_to_loc is None else self._iloc_to_loc[idx]) + for idx in self.indices + ] else: - ret = [self._get_single_output(self._iloc_to_loc[idx]) for idx in self.indices] + ret = [ + self._get_single_output(idx if self._iloc_to_loc is None else self._iloc_to_loc[idx]) + for idx in self.indices + ] if not ((self.is_multi_input and data_type == "x") or (self.is_multi_task and data_type == "y")): # If the target format is not a dict, we can just create the array directly. @@ -202,9 +207,10 @@ def __getitem__(self, item) -> DatapointType: """ idx = self.indices[item] + idx = idx if self._iloc_to_loc is None else self._iloc_to_loc[idx] # Load the input modalities - ins = self._get_single_input(self._iloc_to_loc[idx]) + ins = self._get_single_input(idx) if self._hide_targets: # If we are not allowed to access the targets, we return the inputs only. 
@@ -212,7 +218,7 @@ def __getitem__(self, item) -> DatapointType: return ins # Retrieve the targets - outs = self._get_single_output(self._iloc_to_loc[idx]) + outs = self._get_single_output(idx) return ins, outs def __iter__(self): From ca76f9d948e31c363feadbc1dbca8cbe253f6aa8 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Tue, 27 Aug 2024 17:53:13 -0600 Subject: [PATCH 09/20] Updated import to prevent circular import --- polaris/dataset/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/polaris/dataset/__init__.py b/polaris/dataset/__init__.py index d215072d..c4578611 100644 --- a/polaris/dataset/__init__.py +++ b/polaris/dataset/__init__.py @@ -1,5 +1,6 @@ from polaris.dataset._column import ColumnAnnotation, KnownContentType, Modality from polaris.dataset._competition_dataset import CompetitionDataset +from polaris.dataset._dataset import DatasetV1 from polaris.dataset._dataset import DatasetV1 as Dataset from polaris.dataset._factory import DatasetFactory, create_dataset_from_file, create_dataset_from_files from polaris.dataset._subset import Subset From f0b7c4b5f3c219e6e4bfb84147eb92b3412c3658 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Tue, 27 Aug 2024 17:54:29 -0600 Subject: [PATCH 10/20] Ruff check and format --- polaris/dataset/_column.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/polaris/dataset/_column.py b/polaris/dataset/_column.py index 18ff6e8b..8383d78b 100644 --- a/polaris/dataset/_column.py +++ b/polaris/dataset/_column.py @@ -1,5 +1,5 @@ import enum -from typing import Dict, Optional, Union +from typing import Dict, Optional import numpy as np from numpy.typing import DTypeLike @@ -63,7 +63,7 @@ def _validate_content_type(cls, v, values): if isinstance(v, str): v = KnownContentType[v.upper()] return v - + @field_validator("dtype", mode="before") def _validate_dtype(cls, v): """Tries to convert a string to the Enum""" From a968e0be7081dd6134776dcf41c7a80282d2c2f4 Mon Sep 17 00:00:00 2001 From: Andrew 
Quirke <75542075+Andrewq11@users.noreply.github.com> Date: Sun, 1 Sep 2024 15:46:11 -0400 Subject: [PATCH 11/20] Adding new Zarr manifest generation to DatasetV2 class (#185) * updates for calculating zarr manifests & adding basic tests for it * moving cache_dir assignment to DatasetV1 and DatasetV2 model validators * Updating argument types for parquet utils * Updating argument types for md5 util * fixing DatasetV1 export & dataset model validators * PR feedback updates * Adding test that checks the length of the manifest after update * PR feedback --- polaris/dataset/__init__.py | 2 + polaris/dataset/_base.py | 38 ++++++--------- polaris/dataset/_dataset.py | 6 ++- polaris/experimental/_dataset_v2.py | 32 ++++++------- polaris/utils/v2_manifest.py | 71 +++++++++++++++++++++++++++++ tests/test_dataset_v2.py | 47 +++++++++++-------- 6 files changed, 134 insertions(+), 62 deletions(-) create mode 100644 polaris/utils/v2_manifest.py diff --git a/polaris/dataset/__init__.py b/polaris/dataset/__init__.py index c4578611..0239cd0b 100644 --- a/polaris/dataset/__init__.py +++ b/polaris/dataset/__init__.py @@ -2,6 +2,7 @@ from polaris.dataset._competition_dataset import CompetitionDataset from polaris.dataset._dataset import DatasetV1 from polaris.dataset._dataset import DatasetV1 as Dataset +from polaris.dataset._dataset import DatasetV1 from polaris.dataset._factory import DatasetFactory, create_dataset_from_file, create_dataset_from_files from polaris.dataset._subset import Subset @@ -15,4 +16,5 @@ "DatasetFactory", "create_dataset_from_file", "create_dataset_from_files", + "DatasetV1", ] diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index e4e8708b..638f0ebe 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -1,8 +1,8 @@ import abc import json -import uuid from pathlib import Path from typing import Dict, List, MutableMapping, Optional, Union +import uuid import fsspec import numpy as np @@ -21,7 +21,7 @@ from polaris._artifact 
import BaseArtifactModel from polaris.dataset._adapters import Adapter from polaris.dataset._column import ColumnAnnotation -from polaris.dataset.zarr import MemoryMappedDirectoryStore, ZarrFileChecksum +from polaris.dataset.zarr import MemoryMappedDirectoryStore from polaris.dataset.zarr._utils import load_zarr_group_to_memory from polaris.hub.polarisfs import PolarisFileSystem from polaris.mixins import ChecksumMixin @@ -83,23 +83,9 @@ class BaseDataset(BaseArtifactModel, ChecksumMixin, abc.ABC): # Private attributes _zarr_root: Optional[zarr.Group] = PrivateAttr(None) _zarr_data: Optional[MutableMapping[str, np.ndarray]] = PrivateAttr(None) - _zarr_md5sum_manifest: List[ZarrFileChecksum] = PrivateAttr(default_factory=list) _client = PrivateAttr(None) # Optional[PolarisHubClient] _warn_about_remote_zarr: bool = PrivateAttr(True) - @model_validator(mode="after") - @classmethod - def _validate_model(cls, m: "BaseDataset"): - """Verifies some dependencies between properties""" - - # Set the default cache dir if none and make sure it exists - if m.cache_dir is None: - dataset_id = m._md5sum if m.has_md5sum else str(uuid.uuid4()) - m.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id - - m.cache_dir.mkdir(parents=True, exist_ok=True) - return m - @field_validator("default_adapters", mode="before") def _validate_adapters(cls, value): """Validate the adapters""" @@ -117,15 +103,17 @@ def _serialize_paths(value): value = str(value) return value - @computed_field - @property - @abc.abstractmethod - def zarr_md5sum_manifest(self) -> List[ZarrFileChecksum]: - """ - The Zarr Checksum manifest stores the checksums of all files in a Zarr archive. - If the dataset doesn't use Zarr, this will simply return an empty list. 
- """ - raise NotImplementedError + @model_validator(mode="after") + def _validate_base_dataset_model(cls, m: "BaseDataset"): + # + # Set the default cache dir if none and make sure it exists + if m.cache_dir is None: + dataset_id = m._md5sum if m.has_md5sum else str(uuid.uuid4()) + m.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id + + m.cache_dir.mkdir(parents=True, exist_ok=True) + + return m @property def client(self): diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index d81b4d79..451296d5 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -9,7 +9,7 @@ import zarr from datamol.utils import fs as dmfs from loguru import logger -from pydantic import computed_field, field_validator, model_validator +from pydantic import PrivateAttr, computed_field, field_validator, model_validator from polaris.dataset._adapters import Adapter from polaris.dataset._base import BaseDataset @@ -48,6 +48,7 @@ class DatasetV1(BaseDataset): table: pd.DataFrame version: ClassVar[Literal[1]] = 1 + _zarr_md5sum_manifest: List[ZarrFileChecksum] = PrivateAttr(default_factory=list) @field_validator("table", mode="before") def _validate_table(cls, v): @@ -71,7 +72,7 @@ def _validate_table(cls, v): @model_validator(mode="after") @classmethod - def _validate_model(cls, m: "DatasetV1"): + def _validate_v1_dataset_model(cls, m: "DatasetV1"): """Verifies some dependencies between properties""" # NOTE (cwognum): A good chunk of the below code is shared with the DatasetV2 class. 
@@ -103,6 +104,7 @@ def _validate_model(cls, m: "DatasetV1"): raise InvalidDatasetError( "The zarr_root_path should only be specified when there are pointer columns" ) + return m def _compute_checksum(self): diff --git a/polaris/experimental/_dataset_v2.py b/polaris/experimental/_dataset_v2.py index 4fdda101..4b5a34a5 100644 --- a/polaris/experimental/_dataset_v2.py +++ b/polaris/experimental/_dataset_v2.py @@ -6,13 +6,14 @@ import numpy as np import zarr from loguru import logger -from pydantic import computed_field, model_validator +from pydantic import model_validator, PrivateAttr from polaris.dataset._adapters import Adapter from polaris.dataset._base import BaseDataset from polaris.dataset._column import ColumnAnnotation -from polaris.dataset.zarr._checksum import ZarrFileChecksum, compute_zarr_checksum from polaris.utils.errors import InvalidDatasetError + +from polaris.utils.v2_manifest import calculate_file_md5, generate_zarr_manifest from polaris.utils.types import AccessType, HubOwner, ZarrConflictResolution _INDEX_ARRAY_KEY = "__index__" @@ -40,13 +41,14 @@ class DatasetV2(BaseDataset): """ version: ClassVar[Literal[2]] = 2 + _zarr_md5sum_manifest_path: str | None = PrivateAttr(None) # Redefine this to make it a required field zarr_root_path: str @model_validator(mode="after") @classmethod - def _validate_model(cls, m: "DatasetV2"): + def _validate_v2_dataset_model(cls, m: "DatasetV2"): """Verifies some dependencies between properties""" # NOTE (cwognum): A good chunk of the below code is shared with the DatasetV1 class. 
@@ -97,6 +99,7 @@ def _validate_model(cls, m: "DatasetV2"): raise InvalidDatasetError( f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}" ) + return m @property @@ -132,23 +135,18 @@ def dtypes(self) -> dict[str, np.dtype]: dtypes[group] = np.dtype(object) return dtypes - @computed_field @property - def zarr_md5sum_manifest(self) -> List[ZarrFileChecksum]: - """ - The Zarr Checksum manifest stores the checksums of all files in a Zarr archive. - If the dataset doesn't use Zarr, this will simply return an empty list. - """ - if len(self._zarr_md5sum_manifest) == 0 and not self.has_md5sum: - # The manifest is set as an instance variable - # as a side-effect of the compute_checksum method - self.md5sum = self._compute_checksum() - return self._zarr_md5sum_manifest + def zarr_manifest_path(self) -> str: + if self._zarr_md5sum_manifest_path is None: + zarr_manifest_path = generate_zarr_manifest(self.zarr_root_path, self.cache_dir) + self._zarr_md5sum_manifest_path = zarr_manifest_path + + return self._zarr_md5sum_manifest_path def _compute_checksum(self) -> str: - """Compute the checksum of the dataset.""" - zarr_hash, self._zarr_md5sum_manifest = compute_zarr_checksum(self.zarr_root_path) - return zarr_hash.md5 + """Compute the checksum of the Zarr manifest file.""" + manifest_md5 = calculate_file_md5(self.zarr_manifest_path) + return manifest_md5 def get_data(self, row: int, col: str, adapters: List[Adapter] | None = None) -> np.ndarray: """Indexes the Zarr archive. 
diff --git a/polaris/utils/v2_manifest.py b/polaris/utils/v2_manifest.py new file mode 100644 index 00000000..3b091512 --- /dev/null +++ b/polaris/utils/v2_manifest.py @@ -0,0 +1,71 @@ +from hashlib import md5 +import os +import pyarrow as pa +import pyarrow.parquet as pq + +# PyArrow table schema for the V2 Zarr manifest file +ZARR_MANIFEST_SCHEMA = pa.schema([("path", pa.string()), ("checksum", pa.string())]) + + +def generate_zarr_manifest(zarr_root_path: str, output_dir: str): + """ + Entry point function which triggers the creation of a Zarr manifest for a V2 dataset. + + Parameters: + zarr_root_path: The path to the root of a Zarr archive + output_dir: The path to the directory which will hold the generated manifest + """ + + zarr_manifest_path = f"{output_dir}/zarr_manifest.parquet" + + with pq.ParquetWriter(zarr_manifest_path, ZARR_MANIFEST_SCHEMA) as writer: + recursively_build_manifest(zarr_root_path, writer, zarr_root_path) + + return zarr_manifest_path + + +def recursively_build_manifest(dir_path: str, writer: pq.ParquetWriter, zarr_root_path: str) -> str: + """ + Recursive function that traverses a Zarr archive to build a V2 manifest file. + + Parameters: + dir_path: The path to the current directory being processed in the archive + writer: Writer object for incrementally adding rows to the manifest Parquet file + zarr_root_path: The root path which triggered the first recursive call + """ + + # Get iterator of items located in the directory at `dir_path` + with os.scandir(dir_path) as it: + # + # Loop through directory items in iterator + for entry in it: + if entry.is_dir(): + # + # If item is a directory, recurse into that directory + recursively_build_manifest(entry.path, writer, zarr_root_path) + elif entry.is_file(): + # + # If item is a file, calculate its relative path and chunk checksum. Then, append that + # to the Zarr manifest parquet. 
+ table = pa.Table.from_pydict( + { + "path": [os.path.relpath(entry.path, zarr_root_path)], + "checksum": [calculate_file_md5(entry.path)], + }, + schema=ZARR_MANIFEST_SCHEMA, + ) + writer.write_table(table) + + +def calculate_file_md5(file_path: str): + """Calculates the md5 hash for a file at a given path""" + + md5_hash = md5() + with open(file_path, "rb") as file: + # + # Read the file in chunks to avoid using too much memory for large files + for chunk in iter(lambda: file.read(4096), b""): + md5_hash.update(chunk) + + # Return the hex representation of the digest + return md5_hash.hexdigest() diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index 45b9753b..6a241410 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -1,8 +1,10 @@ from copy import deepcopy +import os from time import perf_counter import numcodecs import numpy as np +import pandas as pd import pytest import zarr from pydantic import ValidationError @@ -11,7 +13,7 @@ from polaris.dataset._factory import DatasetFactory from polaris.dataset.converters._pdb import PDBConverter from polaris.experimental._dataset_v2 import _INDEX_ARRAY_KEY, DatasetV2 -from polaris.utils.errors import PolarisChecksumError +from polaris.utils.v2_manifest import generate_zarr_manifest def test_dataset_v2_get_columns(test_dataset_v2): @@ -84,23 +86,6 @@ def test_dataset_v2_checksum(test_dataset_v2, tmpdir): dataset._md5sum = "invalid" assert dataset != test_dataset_v2 - # (4) With changes, but same hash - # Reset hash - kwargs["md5sum"] = test_dataset_v2.md5sum - - # Copy Zarr data to local - dataset = DatasetV2(**kwargs) - save_dir = tmpdir.join("save_dir") - dataset.to_json(save_dir, load_zarr_from_new_location=True) - - # Make changes to Zarr archive copy - root = zarr.open(dataset.zarr_root_path, "a") - root["A"][0] = np.zeros(2048) - - # Checksum should be different - with pytest.raises(PolarisChecksumError): - dataset.verify_checksum() - def 
test_dataset_v2_serialization(test_dataset_v2, tmpdir): save_dir = tmpdir.join("save_dir") @@ -258,3 +243,29 @@ def test_dataset_v2_validation_consistent_lengths(zarr_archive): # Subgroup has a false number of indices with pytest.raises(ValidationError, match="should have the same length"): DatasetV2(zarr_root_path=zarr_archive) + + +def test_zarr_manifest(test_dataset_v2): + # Assert the manifest Parquet is created + assert test_dataset_v2.zarr_manifest_path is not None + assert os.path.isfile(test_dataset_v2.zarr_manifest_path) + + # Assert the manifest contains 204 rows (the number "204" is chosen because + # the Zarr archive defined in `conftest.py` contains 204 unique files) + df = pd.read_parquet(test_dataset_v2.zarr_manifest_path) + assert len(df) == 204 + + # Assert the manifest hash is calculated + assert test_dataset_v2.md5sum is not None + + # Add array to Zarr archive to change the number of chunks in the dataset + root = zarr.open(test_dataset_v2.zarr_root_path, "a") + root.array("C", data=np.random.random((100, 2048)), chunks=(1, None)) + + generate_zarr_manifest(test_dataset_v2.zarr_root_path, test_dataset_v2.cache_dir) + + # Get the length of the updated manifest file + post_change_manifest_length = len(pd.read_parquet(test_dataset_v2.zarr_manifest_path)) + + # Ensure Zarr manifest has an additional 100 chunks + 1 array metadata file + assert post_change_manifest_length == 305 From 18bde88056ce19c1b298c4ae9342f7636bbce112 Mon Sep 17 00:00:00 2001 From: Andrew Quirke Date: Sun, 1 Sep 2024 15:49:03 -0400 Subject: [PATCH 12/20] fixing code check test --- polaris/dataset/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/polaris/dataset/__init__.py b/polaris/dataset/__init__.py index 0239cd0b..e82249cc 100644 --- a/polaris/dataset/__init__.py +++ b/polaris/dataset/__init__.py @@ -2,7 +2,6 @@ from polaris.dataset._competition_dataset import CompetitionDataset from polaris.dataset._dataset import DatasetV1 from polaris.dataset._dataset import 
DatasetV1 as Dataset -from polaris.dataset._dataset import DatasetV1 from polaris.dataset._factory import DatasetFactory, create_dataset_from_file, create_dataset_from_files from polaris.dataset._subset import Subset From 75fe31043de01ed4798bfb93ce1c9588f53df0c8 Mon Sep 17 00:00:00 2001 From: cwognum Date: Tue, 3 Sep 2024 10:47:47 -0400 Subject: [PATCH 13/20] Move code to dataset base class --- polaris/dataset/_base.py | 20 ++++++++++++++++++- polaris/dataset/_dataset.py | 23 ---------------------- polaris/experimental/_dataset_v2.py | 30 ++--------------------------- 3 files changed, 21 insertions(+), 52 deletions(-) diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index 638f0ebe..f87ab7c6 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -1,8 +1,8 @@ import abc import json +import uuid from pathlib import Path from typing import Dict, List, MutableMapping, Optional, Union -import uuid import fsspec import numpy as np @@ -105,6 +105,24 @@ def _serialize_paths(value): @model_validator(mode="after") def _validate_base_dataset_model(cls, m: "BaseDataset"): + # Verify that all annotations are for columns that exist + if any(k not in m.columns for k in m.annotations): + raise InvalidDatasetError( + f"There are annotations for columns that do not exist. Columns: {m.columns}. Annotations: {list(m.annotations.keys())}" + ) + + # Verify that all adapters are for columns that exist + if any(k not in m.columns for k in m.default_adapters.keys()): + raise InvalidDatasetError( + f"There are default adapters for columns that do not exist. Columns: {m.columns}. 
Adapters: {list(m.annotations.keys())}" + ) + + # Set a default for missing annotations and convert strings to Modality + for c in m.columns: + if c not in m.annotations: + m.annotations[c] = ColumnAnnotation() + m.annotations[c].dtype = m.dtypes[c] + # # Set the default cache dir if none and make sure it exists if m.cache_dir is None: diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 451296d5..45bf69f0 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -13,7 +13,6 @@ from polaris.dataset._adapters import Adapter from polaris.dataset._base import BaseDataset -from polaris.dataset._column import ColumnAnnotation from polaris.dataset.zarr import ZarrFileChecksum, compute_zarr_checksum from polaris.utils.errors import InvalidDatasetError from polaris.utils.types import AccessType, HubOwner, ZarrConflictResolution @@ -75,28 +74,6 @@ def _validate_table(cls, v): def _validate_v1_dataset_model(cls, m: "DatasetV1"): """Verifies some dependencies between properties""" - # NOTE (cwognum): A good chunk of the below code is shared with the DatasetV2 class. - # I tried moving it to the BaseDataset class, but I'm not understanding Pydantic's behavior very well. - # It seems to not always trigger when part of the base class. - - # Verify that all annotations are for columns that exist - if any(k not in m.columns for k in m.annotations): - raise InvalidDatasetError( - f"There are annotations for columns that do not exist. Columns: {m.columns}. Annotations: {list(m.annotations.keys())}" - ) - - # Verify that all adapters are for columns that exist - if any(k not in m.columns for k in m.default_adapters.keys()): - raise InvalidDatasetError( - f"There are default adapters for columns that do not exist. Columns: {m.columns}. 
Adapters: {list(m.annotations.keys())}" - ) - - # Set a default for missing annotations and convert strings to Modality - for c in m.columns: - if c not in m.annotations: - m.annotations[c] = ColumnAnnotation() - m.annotations[c].dtype = m.dtypes[c] - has_pointers = any(anno.is_pointer for anno in m.annotations.values()) if has_pointers and m.zarr_root_path is None: raise InvalidDatasetError("A zarr_root_path needs to be specified when there are pointer columns") diff --git a/polaris/experimental/_dataset_v2.py b/polaris/experimental/_dataset_v2.py index 4b5a34a5..fc0f6fd3 100644 --- a/polaris/experimental/_dataset_v2.py +++ b/polaris/experimental/_dataset_v2.py @@ -6,15 +6,13 @@ import numpy as np import zarr from loguru import logger -from pydantic import model_validator, PrivateAttr +from pydantic import PrivateAttr, model_validator from polaris.dataset._adapters import Adapter from polaris.dataset._base import BaseDataset -from polaris.dataset._column import ColumnAnnotation from polaris.utils.errors import InvalidDatasetError - -from polaris.utils.v2_manifest import calculate_file_md5, generate_zarr_manifest from polaris.utils.types import AccessType, HubOwner, ZarrConflictResolution +from polaris.utils.v2_manifest import calculate_file_md5, generate_zarr_manifest _INDEX_ARRAY_KEY = "__index__" @@ -51,30 +49,6 @@ class DatasetV2(BaseDataset): def _validate_v2_dataset_model(cls, m: "DatasetV2"): """Verifies some dependencies between properties""" - # NOTE (cwognum): A good chunk of the below code is shared with the DatasetV1 class. - # I tried moving it to the BaseDataset class, but I'm not understanding Pydantic's behavior very well. - # It seems to not always trigger when part of the base class. - - # Verify that all annotations are for columns that exist - if any(k not in m.columns for k in m.annotations): - raise InvalidDatasetError( - f"There are annotations for columns that do not exist. Columns: {m.columns}. 
Annotations: {list(m.annotations.keys())}" - ) - - # Verify that all adapters are for columns that exist - if any(k not in m.columns for k in m.default_adapters.keys()): - raise InvalidDatasetError( - f"There are default adapters for columns that do not exist. Columns: {m.columns}. Adapters: {list(m.annotations.keys())}" - ) - - # Set a default for missing annotations and convert strings to Modality - for c in m.columns: - if c not in m.annotations: - m.annotations[c] = ColumnAnnotation() - if m.annotations[c].is_pointer: - raise InvalidDatasetError("Pointer columns are not supported in DatasetV2") - m.annotations[c].dtype = m.dtypes[c] - # Since the keys for subgroups are not ordered, we have no easy way to index these groups. # Any subgroup should therefore have a special array that defines the index for that group. for group in m.zarr_root.group_keys(): From 024e71db7fa267827916f596df2094cb77bf3586 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Wed, 4 Sep 2024 22:00:56 -0400 Subject: [PATCH 14/20] Addressed most feedback on the PR, still need to revisit the __getitem__ method --- polaris/dataset/_base.py | 75 +++++------------- polaris/dataset/_column.py | 24 +----- polaris/dataset/_dataset.py | 59 +++++++++++--- polaris/dataset/_subset.py | 18 ++--- polaris/dataset/converters/_pdb.py | 4 +- polaris/dataset/converters/_sdf.py | 4 +- polaris/dataset/zarr/__init__.py | 8 +- .../zarr/_manifest.py} | 3 +- polaris/experimental/_dataset_v2.py | 79 +++++++++++++++---- polaris/mixins/_checksum.py | 5 +- polaris/utils/errors.py | 7 +- tests/test_dataset.py | 6 +- tests/test_dataset_v2.py | 11 +-- tests/test_evaluate.py | 14 ++-- 14 files changed, 172 insertions(+), 145 deletions(-) rename polaris/{utils/v2_manifest.py => dataset/zarr/_manifest.py} (99%) diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index f87ab7c6..05f9a81a 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -1,12 +1,10 @@ import abc import json -import uuid from 
pathlib import Path from typing import Dict, List, MutableMapping, Optional, Union import fsspec import numpy as np -import pandas as pd import zarr from loguru import logger from pydantic import ( @@ -24,8 +22,6 @@ from polaris.dataset.zarr import MemoryMappedDirectoryStore from polaris.dataset.zarr._utils import load_zarr_group_to_memory from polaris.hub.polarisfs import PolarisFileSystem -from polaris.mixins import ChecksumMixin -from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.dict2html import dict2html from polaris.utils.errors import InvalidDatasetError from polaris.utils.types import ( @@ -40,10 +36,12 @@ _CACHE_SUBDIR = "datasets" -class BaseDataset(BaseArtifactModel, ChecksumMixin, abc.ABC): +class BaseDataset(BaseArtifactModel, abc.ABC): """Base data-model for a Polaris dataset, implemented as a [Pydantic](https://docs.pydantic.dev/latest/) model. - At its core, a dataset in Polaris is a tabular data structure that stores data-points in a row-wise manner. + At its core, a dataset in Polaris can _conceptually_ be thought of as tabular data structure that stores data-points + in a row-wise manner, where each column correspond to a variable associated with that datapoint. + A Dataset can have multiple modalities or targets, can be sparse and can be part of one or multiple [`BenchmarkSpecification`][polaris.benchmark.BenchmarkSpecification] objects. @@ -59,6 +57,7 @@ class BaseDataset(BaseArtifactModel, ChecksumMixin, abc.ABC): source: The data source, e.g. a DOI, Github repo or URI. license: The dataset license. Polaris only supports some Creative Commons licenses. See [`SupportedLicenseType`][polaris.utils.types.SupportedLicenseType] for accepted ID values. curation_reference: A reference to the curation process, e.g. a DOI, Github repo or URI. + cache_dir: Where the dataset would be cached if you call the `cache()` method. For additional meta-data attributes, see the [`BaseArtifactModel`][polaris._artifact.BaseArtifactModel] class. 
Raises: @@ -78,7 +77,7 @@ class BaseDataset(BaseArtifactModel, ChecksumMixin, abc.ABC): curation_reference: Optional[HttpUrlString] = None # Config - cache_dir: Optional[Path] = None # Where to cache the data to if cache() is called. + cache_dir: Optional[Path] = None # Private attributes _zarr_root: Optional[zarr.Group] = PrivateAttr(None) @@ -123,14 +122,6 @@ def _validate_base_dataset_model(cls, m: "BaseDataset"): m.annotations[c] = ColumnAnnotation() m.annotations[c].dtype = m.dtypes[c] - # - # Set the default cache dir if none and make sure it exists - if m.cache_dir is None: - dataset_id = m._md5sum if m.has_md5sum else str(uuid.uuid4()) - m.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id - - m.cache_dir.mkdir(parents=True, exist_ok=True) - return m @property @@ -166,7 +157,7 @@ def zarr_data(self): return self.zarr_root @property - def zarr_root(self): + def zarr_root(self) -> zarr.Group | None: """Get the zarr Group object corresponding to the root. Opens the zarr archive in read-write mode if it is not already open. @@ -223,13 +214,13 @@ def n_columns(self) -> int: @property @abc.abstractmethod - def rows(self) -> list: + def rows(self) -> list[str | int]: """Return all row indices for the dataset""" raise NotImplementedError @property @abc.abstractmethod - def columns(self) -> list: + def columns(self) -> list[str]: """Return all columns for the dataset""" raise NotImplementedError @@ -261,7 +252,7 @@ def load_to_memory(self): self._zarr_data = load_zarr_group_to_memory(data) @abc.abstractmethod - def get_data(self, row: str | int, col: str, adapters: Optional[List[Adapter]] = None) -> np.ndarray: + def get_data(self, row: str | int, col: str, adapters: dict[str, Adapter] | None = None) -> np.ndarray: """Since the dataset might contain pointers to external files, data retrieval is more complicated than just indexing the `table` attribute. This method provides an end-point for seamlessly accessing the underlying data. 
@@ -279,20 +270,18 @@ def get_data(self, row: str | int, col: str, adapters: Optional[List[Adapter]] = raise NotImplementedError @abc.abstractmethod - def upload_to_hub( - self, access: Optional[AccessType] = "private", owner: Union[HubOwner, str, None] = None - ): + def upload_to_hub(self, access: AccessType = "private", owner: Union[HubOwner, str, None] = None): """Uploads the dataset to the Polaris Hub.""" raise NotImplementedError @classmethod def from_json(cls, path: str): - """Loads a benchmark from a JSON file. + """Loads a dataset from a JSON file. Overrides the method from the base class to remove the caching dir from the file to load from, as that should be user dependent. Args: - path: Loads a benchmark specification from a JSON file. + path: Loads a dataset specification from a JSON file. """ with fsspec.open(path, "r") as f: data = json.load(f) @@ -321,21 +310,15 @@ def to_json( """ raise NotImplementedError - def cache(self, cache_dir: Optional[str] = None, verify_checksum: bool = True) -> str: + def cache(self, verify_checksum: bool = True) -> str: """Caches the dataset by downloading all additional data for pointer columns to a local directory. Args: - cache_dir: The directory to cache the data to. If not provided, - this will fall back to the `Dataset.cache_dir` attribute verify_checksum: Whether to verify the checksum of the dataset after caching. Returns: The path to the cache directory. 
""" - - if cache_dir is not None: - self.cache_dir = cache_dir - self.to_json(self.cache_dir, load_zarr_from_new_location=True) if verify_checksum: @@ -343,35 +326,13 @@ def cache(self, cache_dir: Optional[str] = None, verify_checksum: bool = True) - return self.cache_dir - def size(self): - return self.rows, self.n_columns + def size(self) -> tuple[int, int]: + return self.n_rows, self.n_columns + @abc.abstractmethod def __getitem__(self, item): """Allows for indexing the dataset directly""" - ret = self.table.loc[item] - if isinstance(ret, pd.Series): - # Load the data from the pointer columns - - if ret.name in self.table.columns: - # Returning a column, the indices are rows - if self.annotations[ret.name].is_pointer: - ret = np.array([self.get_data(k, ret.name) for k in ret.index]) - - elif len(ret) == self.n_rows: - # Returning a row, the indices are columns - ret = { - k: self.get_data(k, ret.name) if self.annotations[ret.name].is_pointer else ret[k] - for k in ret.index - } - - # Returning a dataframe - if isinstance(ret, pd.DataFrame): - for c in ret.columns: - if self.annotations[c].is_pointer: - ret[c] = [self.get_data(item, c) for item in ret.index] - return ret - - return ret + raise NotImplementedError @abc.abstractmethod def _repr_dict_(self) -> dict: diff --git a/polaris/dataset/_column.py b/polaris/dataset/_column.py index 8383d78b..3ece7f96 100644 --- a/polaris/dataset/_column.py +++ b/polaris/dataset/_column.py @@ -1,5 +1,5 @@ import enum -from typing import Dict, Optional +from typing import Dict, Literal, Optional, TypeAlias import numpy as np from numpy.typing import DTypeLike @@ -18,11 +18,7 @@ class Modality(enum.Enum): IMAGE = "image" -class KnownContentType(enum.Enum): - """Used to specify column's IANA content type in a dataset.""" - - SMILES = "chemical/x-smiles" - PDB = "chemical/x-pdb" +KnownContentType: TypeAlias = Literal["chemical/x-smiles", "chemical/x-pdb"] class ColumnAnnotation(BaseModel): @@ -46,7 +42,7 @@ class 
ColumnAnnotation(BaseModel): description: Optional[str] = None user_attributes: Dict[str, str] = Field(default_factory=dict) dtype: np.dtype | None = None - content_type: KnownContentType | None = None + content_type: KnownContentType | str | None = None model_config = ConfigDict(arbitrary_types_allowed=True, alias_generator=to_camel, populate_by_name=True) @@ -57,13 +53,6 @@ def _validate_modality(cls, v): v = Modality[v.upper()] return v - @field_validator("content_type", mode="before") - def _validate_content_type(cls, v, values): - """Tries to convert a string to the Enum""" - if isinstance(v, str): - v = KnownContentType[v.upper()] - return v - @field_validator("dtype", mode="before") def _validate_dtype(cls, v): """Tries to convert a string to the Enum""" @@ -76,13 +65,6 @@ def _serialize_modality(self, v: Modality): """Return the modality as a string, keeping it serializable""" return v.name - @field_serializer("content_type") - def _serialize_content_type(self, v: KnownContentType): - """Return the content_type as a string, keeping it serializable""" - if v is not None: - v = v.name - return v - @field_serializer("dtype") def _serialize_dtype(self, v: Optional[DTypeLike]): """Return the dtype as a string, keeping it serializable""" diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 45bf69f0..90232377 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -1,7 +1,8 @@ import json +import uuid from hashlib import md5 from pathlib import Path -from typing import ClassVar, List, Literal, Optional, Tuple, Union +from typing import ClassVar, List, Literal, Union import fsspec import numpy as np @@ -12,8 +13,10 @@ from pydantic import PrivateAttr, computed_field, field_validator, model_validator from polaris.dataset._adapters import Adapter -from polaris.dataset._base import BaseDataset +from polaris.dataset._base import _CACHE_SUBDIR, BaseDataset from polaris.dataset.zarr import ZarrFileChecksum, 
compute_zarr_checksum +from polaris.mixins._checksum import ChecksumMixin +from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.errors import InvalidDatasetError from polaris.utils.types import AccessType, HubOwner, ZarrConflictResolution @@ -22,7 +25,7 @@ _INDEX_SEP = "#" -class DatasetV1(BaseDataset): +class DatasetV1(BaseDataset, ChecksumMixin): """First version of a Polaris Dataset. Stores datapoints in a Pandas DataFrame and implements _pointer columns_ to support the storage of XXL data @@ -82,9 +85,15 @@ def _validate_v1_dataset_model(cls, m: "DatasetV1"): "The zarr_root_path should only be specified when there are pointer columns" ) + # Set the default cache dir if none and make sure it exists + if m.cache_dir is None: + dataset_id = m._md5sum if m.has_md5sum else str(uuid.uuid4()) + m.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id + m.cache_dir.mkdir(parents=True, exist_ok=True) + return m - def _compute_checksum(self): + def _compute_checksum(self) -> str: """Computes a hash of the dataset. This is meant to uniquely identify the dataset and can be used to verify the version. 
@@ -125,12 +134,12 @@ def zarr_md5sum_manifest(self) -> List[ZarrFileChecksum]: return self._zarr_md5sum_manifest @property - def rows(self) -> list: + def rows(self) -> list[str | int]: """Return all row indices for the dataset""" return self.table.index.tolist() @property - def columns(self) -> list: + def columns(self) -> list[str]: """Return all columns for the dataset""" return self.table.columns.tolist() @@ -139,7 +148,7 @@ def dtypes(self) -> dict[str, np.dtype]: """Return the dtype for each of the columns for the dataset""" return {col: self.table[col].dtype for col in self.columns} - def get_data(self, row: str | int, col: str, adapters: Optional[List[Adapter]] = None) -> np.ndarray: + def get_data(self, row: str | int, col: str, adapters: dict[str, Adapter] | None = None) -> np.ndarray: """Since the dataset might contain pointers to external files, data retrieval is more complicated than just indexing the `table` attribute. This method provides an end-point for seamlessly accessing the underlying data. @@ -156,7 +165,8 @@ def get_data(self, row: str | int, col: str, adapters: Optional[List[Adapter]] = """ # Fetch adapters for dataset and a given column - adapters = adapters or self.default_adapters + # Partially override if the adapters parameter is specified. + adapters = {**self.default_adapters, **(adapters or {})} adapter = adapters.get(col) # If not a pointer, return it here. Apply adapter if specified. @@ -180,9 +190,7 @@ def get_data(self, row: str | int, col: str, adapters: Optional[List[Adapter]] = return arr - def upload_to_hub( - self, access: Optional[AccessType] = "private", owner: Union[HubOwner, str, None] = None - ): + def upload_to_hub(self, access: AccessType = "private", owner: Union[HubOwner, str, None] = None): """ Very light, convenient wrapper around the [`PolarisHubClient.upload_dataset`][polaris.hub.client.PolarisHubClient.upload_dataset] method. 
@@ -255,7 +263,7 @@ def to_json( return dataset_path - def _split_index_from_path(self, path: str) -> Tuple[str, Optional[int]]: + def _split_index_from_path(self, path: str) -> tuple[str, int | None]: """ Paths can have an additional index appended to them. This extracts that index from the path. @@ -273,6 +281,33 @@ def _split_index_from_path(self, path: str) -> Tuple[str, Optional[int]]: raise ValueError(f"Invalid index format: {index}") return path, index + def __getitem__(self, item): + """Allows for indexing the dataset directly""" + ret = self.table.loc[item] + if isinstance(ret, pd.Series): + # Load the data from the pointer columns + + if ret.name in self.table.columns: + # Returning a column, the indices are rows + if self.annotations[ret.name].is_pointer: + ret = np.array([self.get_data(k, ret.name) for k in ret.index]) + + elif len(ret) == self.n_rows: + # Returning a row, the indices are columns + ret = { + k: self.get_data(k, ret.name) if self.annotations[ret.name].is_pointer else ret[k] + for k in ret.index + } + + # Returning a dataframe + if isinstance(ret, pd.DataFrame): + for c in ret.columns: + if self.annotations[c].is_pointer: + ret[c] = [self.get_data(item, c) for item in ret.index] + return ret + + return ret + def _repr_dict_(self) -> dict: """Utility function for pretty-printing to the command line and jupyter notebooks""" repr_dict = self.model_dump(exclude={"table", "zarr_md5sum_manifest"}) diff --git a/polaris/dataset/_subset.py b/polaris/dataset/_subset.py index 00671e8c..05cbee80 100644 --- a/polaris/dataset/_subset.py +++ b/polaris/dataset/_subset.py @@ -79,7 +79,11 @@ def __init__( # Storing all indices in memory can be memory consuming for XXL datasets. # This is why we constrain the iloc to loc mapping to be the identity function for Dataset V2. 
- self._iloc_to_loc = self.dataset.rows if isinstance(self.dataset, DatasetV1) else None + match self.dataset: + case DatasetV1(): + self._iloc_to_loc = lambda idx: self.dataset.rows[idx] + case _: + self._iloc_to_loc = lambda idx: idx # For the iterator implementation self._pointer = 0 @@ -166,15 +170,9 @@ def as_array(self, data_type: Union[Literal["x"], Literal["y"], Literal["xy"]]): # We reset the index of the Pandas Table during Dataset class validation. # We can thus always assume that .iloc[idx] is the same as .loc[idx]. if data_type == "x": - ret = [ - self._get_single_input(idx if self._iloc_to_loc is None else self._iloc_to_loc[idx]) - for idx in self.indices - ] + ret = [self._get_single_input(self._iloc_to_loc(idx)) for idx in self.indices] else: - ret = [ - self._get_single_output(idx if self._iloc_to_loc is None else self._iloc_to_loc[idx]) - for idx in self.indices - ] + ret = [self._get_single_output(self._iloc_to_loc(idx)) for idx in self.indices] if not ((self.is_multi_input and data_type == "x") or (self.is_multi_task and data_type == "y")): # If the target format is not a dict, we can just create the array directly. 
@@ -207,7 +205,7 @@ def __getitem__(self, item) -> DatapointType: """ idx = self.indices[item] - idx = idx if self._iloc_to_loc is None else self._iloc_to_loc[idx] + idx = self._iloc_to_loc(idx) # Load the input modalities ins = self._get_single_input(idx) diff --git a/polaris/dataset/converters/_pdb.py b/polaris/dataset/converters/_pdb.py index ac0a0e15..5694bd52 100644 --- a/polaris/dataset/converters/_pdb.py +++ b/polaris/dataset/converters/_pdb.py @@ -7,7 +7,7 @@ import zarr from fastpdb import struc -from polaris.dataset import ColumnAnnotation, Modality, KnownContentType +from polaris.dataset import ColumnAnnotation, Modality from polaris.dataset._adapters import Adapter from polaris.dataset.converters._base import Converter, FactoryProduct @@ -190,7 +190,7 @@ def convert(self, path, factory: "DatasetFactory", append: bool = False) -> Fact # Set the annotations annotations = { self.pdb_column: ColumnAnnotation( - is_pointer=True, modality=Modality.PROTEIN_3D, content_type=KnownContentType.PDB + is_pointer=True, modality=Modality.PROTEIN_3D, content_type="chemical/x-pdb" ) } diff --git a/polaris/dataset/converters/_sdf.py b/polaris/dataset/converters/_sdf.py index 5a993fb7..2cde7acb 100644 --- a/polaris/dataset/converters/_sdf.py +++ b/polaris/dataset/converters/_sdf.py @@ -5,7 +5,7 @@ import pandas as pd from rdkit import Chem -from polaris.dataset import ColumnAnnotation, Modality, KnownContentType +from polaris.dataset import ColumnAnnotation, Modality from polaris.dataset._adapters import Adapter from polaris.dataset.converters._base import Converter, FactoryProduct @@ -149,7 +149,7 @@ def _get_name(mol: dm.Mol): annotations = {self.mol_column: ColumnAnnotation(is_pointer=True, modality=Modality.MOLECULE_3D)} if self.smiles_column is not None: annotations[self.smiles_column] = ColumnAnnotation( - modality=Modality.MOLECULE, content_type=KnownContentType.SMILES + modality=Modality.MOLECULE, content_type="chemical/x-smiles" ) # Return the dataframe and the 
annotations diff --git a/polaris/dataset/zarr/__init__.py b/polaris/dataset/zarr/__init__.py index 57f500ed..b936b607 100644 --- a/polaris/dataset/zarr/__init__.py +++ b/polaris/dataset/zarr/__init__.py @@ -1,4 +1,10 @@ from ._checksum import ZarrFileChecksum, compute_zarr_checksum +from ._manifest import generate_zarr_manifest from ._memmap import MemoryMappedDirectoryStore -__all__ = ["MemoryMappedDirectoryStore", "compute_zarr_checksum", "ZarrFileChecksum"] +__all__ = [ + "MemoryMappedDirectoryStore", + "compute_zarr_checksum", + "ZarrFileChecksum", + "generate_zarr_manifest", +] diff --git a/polaris/utils/v2_manifest.py b/polaris/dataset/zarr/_manifest.py similarity index 99% rename from polaris/utils/v2_manifest.py rename to polaris/dataset/zarr/_manifest.py index 3b091512..b7578ce4 100644 --- a/polaris/utils/v2_manifest.py +++ b/polaris/dataset/zarr/_manifest.py @@ -1,5 +1,6 @@ -from hashlib import md5 import os +from hashlib import md5 + import pyarrow as pa import pyarrow.parquet as pq diff --git a/polaris/experimental/_dataset_v2.py b/polaris/experimental/_dataset_v2.py index fc0f6fd3..75bef30f 100644 --- a/polaris/experimental/_dataset_v2.py +++ b/polaris/experimental/_dataset_v2.py @@ -1,18 +1,21 @@ import json +import re +import uuid from pathlib import Path -from typing import ClassVar, List, Literal, Optional +from typing import ClassVar, Literal import fsspec import numpy as np import zarr from loguru import logger -from pydantic import PrivateAttr, model_validator +from pydantic import PrivateAttr, computed_field, model_validator from polaris.dataset._adapters import Adapter -from polaris.dataset._base import BaseDataset +from polaris.dataset._base import _CACHE_SUBDIR, BaseDataset +from polaris.dataset.zarr._manifest import calculate_file_md5, generate_zarr_manifest +from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.errors import InvalidDatasetError from polaris.utils.types import AccessType, HubOwner, ZarrConflictResolution 
-from polaris.utils.v2_manifest import calculate_file_md5, generate_zarr_manifest _INDEX_ARRAY_KEY = "__index__" @@ -40,6 +43,7 @@ class DatasetV2(BaseDataset): version: ClassVar[Literal[2]] = 2 _zarr_md5sum_manifest_path: str | None = PrivateAttr(None) + _md5sum: str | None = PrivateAttr(None) # Redefine this to make it a required field zarr_root_path: str @@ -74,10 +78,16 @@ def _validate_v2_dataset_model(cls, m: "DatasetV2"): f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}" ) + # Set the default cache dir if none and make sure it exists + if m.cache_dir is None: + dataset_id = m._md5sum if m.has_md5sum else str(uuid.uuid4()) + m.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id + m.cache_dir.mkdir(parents=True, exist_ok=True) + return m @property - def n_rows(self) -> list: + def n_rows(self) -> int: """Return all row indices for the dataset""" example = self.zarr_root[self.columns[0]] if isinstance(example, zarr.Group): @@ -85,17 +95,17 @@ def n_rows(self) -> list: return len(example) @property - def rows(self) -> list: + def rows(self) -> np.ndarray[int]: """Return all row indices for the dataset Warning: Memory consumption This feature is added for completeness sake, but when datasets get large could consume a lot of memory. E.g. storing a billion indices with np.in64 would consume 8GB of memory. Use with caution. 
""" - return np.arange(len(self)) + return np.arange(len(self), dtype=int) @property - def columns(self) -> list: + def columns(self) -> list[str]: """Return all columns for the dataset""" return list(self.zarr_root.keys()) @@ -117,12 +127,32 @@ def zarr_manifest_path(self) -> str: return self._zarr_md5sum_manifest_path - def _compute_checksum(self) -> str: - """Compute the checksum of the Zarr manifest file.""" - manifest_md5 = calculate_file_md5(self.zarr_manifest_path) - return manifest_md5 + @computed_field + @property + def md5sum(self) -> str: + """ + Lazily compute the checksum once needed. + + The checksum of the DatasetV2 is computed from the Zarr Manifest and is _not_ deterministic. + """ + if not self.has_md5sum: + logger.info("Computing the checksum. This can be slow for large datasets.") + self.md5sum = calculate_file_md5(self.zarr_manifest_path) + return self._md5sum + + @md5sum.setter + def md5sum(self, value: str): + """Set the checksum.""" + if not re.fullmatch(r"^[a-f0-9]{32}$", value): + raise ValueError("The checksum should be the 32-character hexdigest of a 128 bit MD5 hash.") + self._md5sum = value + + @property + def has_md5sum(self) -> bool: + """Whether the md5sum for this class has been computed and stored.""" + return self._md5sum is not None - def get_data(self, row: int, col: str, adapters: List[Adapter] | None = None) -> np.ndarray: + def get_data(self, row: int, col: str, adapters: dict[str, Adapter] | None = None) -> np.ndarray: """Indexes the Zarr archive. Args: @@ -136,7 +166,8 @@ def get_data(self, row: int, col: str, adapters: List[Adapter] | None = None) -> the content of the referenced file is loaded to memory. """ # Fetch adapters for dataset and a given column - adapters = adapters or self.default_adapters + # Partially override if the adapters parameter is specified. 
+ adapters = {**self.default_adapters, **(adapters or {})} adapter = adapters.get(col) # Get the data @@ -154,7 +185,7 @@ def get_data(self, row: int, col: str, adapters: List[Adapter] | None = None) -> return arr - def upload_to_hub(self, access: Optional[AccessType] = "private", owner: HubOwner | str | None = None): + def upload_to_hub(self, access: AccessType = "private", owner: HubOwner | str | None = None): """Uploads the dataset to the Polaris Hub.""" # NOTE (cwognum): Leaving this for a later PR, because I want @@ -212,6 +243,24 @@ def to_json( json.dump(serialized, f) return dataset_path + def cache(self) -> str: + """Caches the dataset by downloading all additional data for pointer columns to a local directory. + + Args: + verify_checksum: Whether to verify the checksum of the dataset after caching. + + Returns: + The path to the cache directory. + """ + # NOTE (cwognum): We don't support a deterministic checksum for the Dataset V2 yet, + # so verification doesn't make sense. See also: + # https://github.com/polaris-hub/polaris/issues/188 + super().cache(verify_checksum=False) + + def __getitem__(self, item): + """Allows for indexing the dataset directly""" + raise NotImplementedError + def _repr_dict_(self) -> dict: """Utility function for pretty-printing to the command line and jupyter notebooks""" repr_dict = self.model_dump(exclude={"zarr_md5sum_manifest"}) diff --git a/polaris/mixins/_checksum.py b/polaris/mixins/_checksum.py index 6991e7f5..8fccac35 100644 --- a/polaris/mixins/_checksum.py +++ b/polaris/mixins/_checksum.py @@ -4,6 +4,8 @@ from loguru import logger from pydantic import BaseModel, PrivateAttr, computed_field +from polaris.utils.errors import PolarisChecksumError + class ChecksumMixin(BaseModel, abc.ABC): """ @@ -64,9 +66,6 @@ def verify_checksum(self, md5sum: str | None = None): self.md5sum = self._compute_checksum() if self.md5sum != md5sum: - # Imported here to prevent circular import - from polaris.utils.errors import 
PolarisChecksumError - raise PolarisChecksumError( f"The specified checksum {md5sum} does not match the computed checksum {self.md5sum}" ) diff --git a/polaris/utils/errors.py b/polaris/utils/errors.py index 3e800847..ad7726bf 100644 --- a/polaris/utils/errors.py +++ b/polaris/utils/errors.py @@ -1,7 +1,5 @@ from httpx import Response -from polaris.mixins._format_text import FormattingMixin # Imported with full path to avoid circular import - class InvalidDatasetError(ValueError): pass @@ -34,7 +32,7 @@ class InvalidZarrChecksum(Exception): pass -class PolarisHubError(Exception, FormattingMixin): +class PolarisHubError(Exception): def __init__(self, message: str = "", response: Response | None = None): prefix = "The request to the Polaris Hub failed." @@ -50,7 +48,6 @@ def __init__(self, response: Response | None = None): "You are not logged in to Polaris or your login has expired. " "You can use the Polaris CLI to easily authenticate yourself again with `polaris login --overwrite`." ) - message = self.format(message, [self.BOLD, self.YELLOW]) super().__init__(message, response) @@ -60,7 +57,6 @@ def __init__(self, response: Response | None = None): "Note: If you can confirm that you are authorized to perform this action, " "please call 'polaris login --overwrite' and try again. If the issue persists, please reach out to the Polaris team for support." ) - message = self.format(message, [self.BOLD, self.YELLOW]) super().__init__(message, response) @@ -70,5 +66,4 @@ def __init__(self, response: Response | None = None): "Note: If this artifact exists and you can confirm that you are authorized to retrieve it, " "please call 'polaris login --overwrite' and try again. If the issue persists, please reach out to the Polaris team for support." 
) - message = self.format(message, [self.BOLD, self.YELLOW]) super().__init__(message, response) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index ff409eb1..7b22f0b3 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -30,7 +30,8 @@ def test_load_data(tmp_path, with_slice, with_caching): dataset = DatasetV1(table=table, annotations={"A": {"is_pointer": True}}, zarr_root_path=zarr_path) if with_caching: - dataset.cache(fs.join(tmpdir, "cache")) + dataset.cache_dir = fs.join(tmpdir, "cache") + dataset.cache() data = dataset.get_data(row=0, col="A") @@ -131,7 +132,8 @@ def test_dataset_caching(zarr_archive, tmpdir): cached_dataset = create_dataset_from_file(zarr_archive, tmpdir.join("original2")) assert original_dataset == cached_dataset - cache_dir = cached_dataset.cache(tmpdir.join("cached").strpath, verify_checksum=True) + cached_dataset.cache_dir = tmpdir.join("cached").strpath + cache_dir = cached_dataset.cache(verify_checksum=True) assert cached_dataset.zarr_root_path.startswith(cache_dir) assert cached_dataset == original_dataset diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index 6a241410..a23719b7 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -1,5 +1,5 @@ -from copy import deepcopy import os +from copy import deepcopy from time import perf_counter import numcodecs @@ -12,15 +12,15 @@ from polaris.dataset import Subset from polaris.dataset._factory import DatasetFactory from polaris.dataset.converters._pdb import PDBConverter +from polaris.dataset.zarr._manifest import generate_zarr_manifest from polaris.experimental._dataset_v2 import _INDEX_ARRAY_KEY, DatasetV2 -from polaris.utils.v2_manifest import generate_zarr_manifest def test_dataset_v2_get_columns(test_dataset_v2): assert set(test_dataset_v2.columns) == {"A", "B"} -def test_dataset_v2_get_rows(test_dataset_v2, zarr_archive): +def test_dataset_v2_get_rows(test_dataset_v2): assert set(test_dataset_v2.rows) == set(range(100)) @@ 
-65,7 +65,7 @@ def test_dataset_v2_load_to_memory(test_dataset_v2): assert d2 < d1 -def test_dataset_v2_checksum(test_dataset_v2, tmpdir): +def test_dataset_v2_checksum(test_dataset_v2): # Make sure the `md5sum` is part of the model dump even if not initiated yet. # This is important for uploads to the Hub. assert test_dataset_v2._md5sum is None @@ -96,7 +96,8 @@ def test_dataset_v2_serialization(test_dataset_v2, tmpdir): def test_dataset_v2_caching(test_dataset_v2, tmpdir): cache_dir = tmpdir.join("cache").strpath - test_dataset_v2.cache(cache_dir, verify_checksum=True) + test_dataset_v2.cache_dir = cache_dir + test_dataset_v2.cache() assert str(test_dataset_v2.zarr_root_path).startswith(cache_dir) diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 8440bdbb..f1e76eee 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -45,7 +45,7 @@ def test_result_to_json(tmpdir: str, test_user_owner: HubOwner): assert po.__version__ == result.polaris_version -def test_metrics_singletask_reg(tmpdir: str, test_single_task_benchmark: SingleTaskBenchmarkSpecification): +def test_metrics_singletask_reg(test_single_task_benchmark: SingleTaskBenchmarkSpecification): _, test = test_single_task_benchmark.get_train_test_split() predictions = np.random.random(size=test.inputs.shape[0]) result = test_single_task_benchmark.evaluate(predictions) @@ -60,7 +60,7 @@ def test_metrics_singletask_reg(tmpdir: str, test_single_task_benchmark: SingleT assert metric in result.results.Metric.tolist() -def test_metrics_multitask_reg(tmpdir: str, test_multi_task_benchmark: MultiTaskBenchmarkSpecification): +def test_metrics_multitask_reg(test_multi_task_benchmark: MultiTaskBenchmarkSpecification): train, test = test_multi_task_benchmark.get_train_test_split() predictions = { target_col: np.random.random(size=test.inputs.shape[0]) for target_col in train.target_cols @@ -70,9 +70,7 @@ def test_metrics_multitask_reg(tmpdir: str, test_multi_task_benchmark: MultiTask assert 
metric in result.results.Metric.tolist() -def test_metrics_singletask_clf( - tmpdir: str, test_single_task_benchmark_clf: SingleTaskBenchmarkSpecification -): +def test_metrics_singletask_clf(test_single_task_benchmark_clf: SingleTaskBenchmarkSpecification): _, test = test_single_task_benchmark_clf.get_train_test_split() predictions = np.random.randint(2, size=test.inputs.shape[0]) probabilities = np.random.uniform(size=test.inputs.shape[0]) @@ -82,7 +80,7 @@ def test_metrics_singletask_clf( def test_metrics_singletask_multicls_clf( - tmpdir: str, test_single_task_benchmark_multi_clf: SingleTaskBenchmarkSpecification + test_single_task_benchmark_multi_clf: SingleTaskBenchmarkSpecification, ): _, test = test_single_task_benchmark_multi_clf.get_train_test_split() predictions = np.random.randint(3, size=test.inputs.shape[0]) @@ -93,7 +91,7 @@ def test_metrics_singletask_multicls_clf( assert metric in result.results.Metric.tolist() -def test_metrics_multitask_clf(tmpdir: str, test_multi_task_benchmark_clf: MultiTaskBenchmarkSpecification): +def test_metrics_multitask_clf(test_multi_task_benchmark_clf: MultiTaskBenchmarkSpecification): train, test = test_multi_task_benchmark_clf.get_train_test_split() predictions = { target_col: np.random.randint(2, size=test.inputs.shape[0]) for target_col in train.target_cols @@ -151,7 +149,7 @@ def test_absolute_average_fold_error(): def test_metric_y_types( - tmpdir: str, test_single_task_benchmark_clf: SingleTaskBenchmarkSpecification, test_data: DatasetV1 + test_single_task_benchmark_clf: SingleTaskBenchmarkSpecification, test_data: DatasetV1 ): # here we use train split for testing purpose. 
_, test = test_single_task_benchmark_clf.get_train_test_split() From 13fa9f16906c3c9e1c03708e716a8c9d56988ec5 Mon Sep 17 00:00:00 2001 From: cwognum Date: Thu, 5 Sep 2024 18:56:16 -0400 Subject: [PATCH 15/20] Worked on the __getitem__ method --- docs/quickstart.md | 4 +-- polaris/benchmark/_base.py | 5 ++-- polaris/dataset/_base.py | 19 ++++++++++---- polaris/dataset/_dataset.py | 33 +++--------------------- polaris/experimental/_dataset_v2.py | 8 ++---- polaris/utils/types.py | 13 ++++++++++ tests/test_dataset.py | 39 +++++++++++++++++++++++++++++ tests/test_dataset_v2.py | 19 ++++++++++++++ 8 files changed, 96 insertions(+), 44 deletions(-) diff --git a/docs/quickstart.md b/docs/quickstart.md index f2faa180..90c6a9e8 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -82,8 +82,8 @@ dataset.get_data( # Or, similarly: dataset[dataset.rows[0], dataset.columns[0]] -# Get the first 10 rows in memory -dataset[:10] +# Get an entire row +dataset[0] ``` ## Core concepts diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index 7bbefebd..cac3aa4c 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -233,11 +233,12 @@ def _validate_target_types(cls, v, info: ValidationInfo): for target in target_cols: if target not in v: - val = dataset[:, target] + val = dataset.table.loc[:, target] # Non numeric columns can be targets (e.g. prediction molecular reactions), # but in that case we currently don't infer the target type. 
if not np.issubdtype(val.dtype, np.number): + v[target] = None continue # remove the nans for mutiple task dataset when the table is sparse @@ -341,7 +342,7 @@ def n_classes(self) -> dict[str, int]: """The number of classes for each of the target columns.""" n_classes = {} for target in self.target_cols: - target_type = self.target_types[target] + target_type = self.target_types.get(target) if target_type is None or target_type == TargetType.REGRESSION: continue # TODO: Don't use table attribute diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index 05f9a81a..227e3868 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -1,7 +1,7 @@ import abc import json from pathlib import Path -from typing import Dict, List, MutableMapping, Optional, Union +from typing import Any, Dict, List, MutableMapping, Optional, Union import fsspec import numpy as np @@ -26,6 +26,7 @@ from polaris.utils.errors import InvalidDatasetError from polaris.utils.types import ( AccessType, + DatasetIndex, HttpUrlString, HubOwner, SupportedLicenseType, @@ -252,7 +253,9 @@ def load_to_memory(self): self._zarr_data = load_zarr_group_to_memory(data) @abc.abstractmethod - def get_data(self, row: str | int, col: str, adapters: dict[str, Adapter] | None = None) -> np.ndarray: + def get_data( + self, row: str | int, col: str, adapters: dict[str, Adapter] | None = None + ) -> np.ndarray | Any: """Since the dataset might contain pointers to external files, data retrieval is more complicated than just indexing the `table` attribute. This method provides an end-point for seamlessly accessing the underlying data. 
@@ -329,10 +332,16 @@ def cache(self, verify_checksum: bool = True) -> str: def size(self) -> tuple[int, int]: return self.n_rows, self.n_columns - @abc.abstractmethod - def __getitem__(self, item): + def __getitem__(self, item: DatasetIndex) -> Any | np.ndarray | dict[str, np.ndarray]: """Allows for indexing the dataset directly""" - raise NotImplementedError + + # If a tuple, we assume it's the row and column index pair + if isinstance(item, tuple): + row, col = item + return self.get_data(row, col) + + # Otherwise, we assume you're indexing the row + return {col: self.get_data(item, col) for col in self.columns} @abc.abstractmethod def _repr_dict_(self) -> dict: diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 90232377..255bcdcf 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -2,7 +2,7 @@ import uuid from hashlib import md5 from pathlib import Path -from typing import ClassVar, List, Literal, Union +from typing import Any, ClassVar, List, Literal, Union import fsspec import numpy as np @@ -148,7 +148,9 @@ def dtypes(self) -> dict[str, np.dtype]: """Return the dtype for each of the columns for the dataset""" return {col: self.table[col].dtype for col in self.columns} - def get_data(self, row: str | int, col: str, adapters: dict[str, Adapter] | None = None) -> np.ndarray: + def get_data( + self, row: str | int, col: str, adapters: dict[str, Adapter] | None = None + ) -> np.ndarray | Any: """Since the dataset might contain pointers to external files, data retrieval is more complicated than just indexing the `table` attribute. This method provides an end-point for seamlessly accessing the underlying data. 
@@ -281,33 +283,6 @@ def _split_index_from_path(self, path: str) -> tuple[str, int | None]: raise ValueError(f"Invalid index format: {index}") return path, index - def __getitem__(self, item): - """Allows for indexing the dataset directly""" - ret = self.table.loc[item] - if isinstance(ret, pd.Series): - # Load the data from the pointer columns - - if ret.name in self.table.columns: - # Returning a column, the indices are rows - if self.annotations[ret.name].is_pointer: - ret = np.array([self.get_data(k, ret.name) for k in ret.index]) - - elif len(ret) == self.n_rows: - # Returning a row, the indices are columns - ret = { - k: self.get_data(k, ret.name) if self.annotations[ret.name].is_pointer else ret[k] - for k in ret.index - } - - # Returning a dataframe - if isinstance(ret, pd.DataFrame): - for c in ret.columns: - if self.annotations[c].is_pointer: - ret[c] = [self.get_data(item, c) for item in ret.index] - return ret - - return ret - def _repr_dict_(self) -> dict: """Utility function for pretty-printing to the command line and jupyter notebooks""" repr_dict = self.model_dump(exclude={"table", "zarr_md5sum_manifest"}) diff --git a/polaris/experimental/_dataset_v2.py b/polaris/experimental/_dataset_v2.py index 75bef30f..cdda80f1 100644 --- a/polaris/experimental/_dataset_v2.py +++ b/polaris/experimental/_dataset_v2.py @@ -2,7 +2,7 @@ import re import uuid from pathlib import Path -from typing import ClassVar, Literal +from typing import Any, ClassVar, Literal import fsspec import numpy as np @@ -152,7 +152,7 @@ def has_md5sum(self) -> bool: """Whether the md5sum for this class has been computed and stored.""" return self._md5sum is not None - def get_data(self, row: int, col: str, adapters: dict[str, Adapter] | None = None) -> np.ndarray: + def get_data(self, row: int, col: str, adapters: dict[str, Adapter] | None = None) -> np.ndarray | Any: """Indexes the Zarr archive. 
Args: @@ -257,10 +257,6 @@ def cache(self) -> str: # https://github.com/polaris-hub/polaris/issues/188 super().cache(verify_checksum=False) - def __getitem__(self, item): - """Allows for indexing the dataset directly""" - raise NotImplementedError - def _repr_dict_(self) -> dict: """Utility function for pretty-printing to the command line and jupyter notebooks""" repr_dict = self.model_dump(exclude={"zarr_md5sum_manifest"}) diff --git a/polaris/utils/types.py b/polaris/utils/types.py index e1cac444..b34b6e55 100644 --- a/polaris/utils/types.py +++ b/polaris/utils/types.py @@ -118,6 +118,19 @@ Type to specify which action to take to verify the data integrity of an artifact through a checksum. """ +RowIndex: TypeAlias = int | str +ColumnIndex: TypeAlias = str +DatasetIndex: TypeAlias = RowIndex | tuple[RowIndex, ColumnIndex] +""" +To index a dataset using square brackets, we have a few options: + +- A single row, e.g. dataset[0] +- Specify a specific value, e.g. dataset[0, "col1"] + +There are more exciting options we could implement, such as slicing, +but this gets complex. 
+""" + class HubOwner(BaseModel): """An owner of an artifact on the Polaris Hub diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 7b22f0b3..5319bb7e 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -179,3 +179,42 @@ def test_checksum_verification(test_dataset): test_dataset.md5sum = "0" * 32 with pytest.raises(ValueError): test_dataset.verify_checksum() + + +def test_dataset__get_item__(): + """Test the __getitem__() interface for the dataset.""" + + table = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}, index=["X", "Y", "Z"]) + dataset = DatasetV1(table=table) + + # Get a specific cell + assert dataset["X", "A"] == 1 + assert dataset["X", "B"] == 4 + assert dataset["Y", "A"] == 2 + assert dataset["Y", "B"] == 5 + assert dataset["Z", "A"] == 3 + assert dataset["Z", "B"] == 6 + + # Get a row + assert dataset["X"] == {"A": 1, "B": 4, "C": 7} + assert dataset["Y"] == {"A": 2, "B": 5, "C": 8} + assert dataset["Z"] == {"A": 3, "B": 6, "C": 9} + + +def test_dataset__get_item__with_pointer_columns(zarr_archive, tmpdir): + """Test the __getitem__() interface for a dataset with pointer columns (i.e. 
part of the data stored in Zarr).""" + + dataset = create_dataset_from_file(zarr_archive, tmpdir.join("data")) + root = zarr.open(zarr_archive) + + # Get a specific cell + assert np.array_equal(dataset[0, "A"], root["A"][0, :]) + + # Get a specific row + def _check_row_equality(d1, d2): + assert len(d1) == len(d2) + for k in d1: + assert np.array_equal(d1[k], d2[k]) + + _check_row_equality(dataset[0], {"A": root["A"][0, :], "B": root["B"][0, :]}) + _check_row_equality(dataset[10], {"A": root["A"][10, :], "B": root["B"][10, :]}) diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index a23719b7..c1ef2b5f 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -270,3 +270,22 @@ def test_zarr_manifest(test_dataset_v2): # Ensure Zarr manifest has an additional 100 chunks + 1 array metadata file assert post_change_manifest_length == 305 + + +def test_dataset_v2__get_item__(test_dataset_v2, zarr_archive): + """Test the __getitem__() interface for the dataset V2.""" + + # Ground truth + root = zarr.open(zarr_archive) + + # Get a specific cell + assert np.array_equal(test_dataset_v2[0, "A"], root["A"][0, :]) + + # Get a specific row + def _check_row_equality(d1, d2): + assert len(d1) == len(d2) + for k in d1: + assert np.array_equal(d1[k], d2[k]) + + _check_row_equality(test_dataset_v2[0], {"A": root["A"][0, :], "B": root["B"][0, :]}) + _check_row_equality(test_dataset_v2[10], {"A": root["A"][10, :], "B": root["B"][10, :]}) From 0e04c1f9c09822dd8407551b73d705614ee6b283 Mon Sep 17 00:00:00 2001 From: cwognum Date: Fri, 6 Sep 2024 12:17:38 -0400 Subject: [PATCH 16/20] Address special case of pointer columns --- polaris/benchmark/_base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index cac3aa4c..8f9edbd1 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -233,6 +233,14 @@ def _validate_target_types(cls, v, info: ValidationInfo): for target in 
target_cols: if target not in v: + # Skip inferring the target type for pointer columns. + # This would be complex to implement properly. + # For these columns, dataset creators can still manually specify the target type. + anno = dataset.annotations.get(target) + if anno is not None and anno.is_pointer: + v[target] = None + continue + val = dataset.table.loc[:, target] # Non numeric columns can be targets (e.g. prediction molecular reactions), From d3a18d59484351ef1cad841edc2e06b7fd327802 Mon Sep 17 00:00:00 2001 From: cwognum Date: Fri, 6 Sep 2024 17:43:48 -0400 Subject: [PATCH 17/20] Renamed md5sum to zarr_manifest_md5sum for clarity, remove equality test from the v2 dataset and moved the verify_checksum parameter to v1 --- polaris/dataset/_base.py | 15 +--------- polaris/dataset/_dataset.py | 17 +++++++++++ polaris/experimental/_dataset_v2.py | 46 ++++++++++------------------- tests/test_dataset_v2.py | 28 +++--------------- 4 files changed, 38 insertions(+), 68 deletions(-) diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index 227e3868..b271901a 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -313,20 +313,13 @@ def to_json( """ raise NotImplementedError - def cache(self, verify_checksum: bool = True) -> str: + def cache(self) -> str: """Caches the dataset by downloading all additional data for pointer columns to a local directory. - Args: - verify_checksum: Whether to verify the checksum of the dataset after caching. - Returns: The path to the cache directory. 
""" self.to_json(self.cache_dir, load_zarr_from_new_location=True) - - if verify_checksum: - self.verify_checksum() - return self.cache_dir def size(self) -> tuple[int, int]: @@ -361,12 +354,6 @@ def __repr__(self): def __str__(self): return self.__repr__() - def __eq__(self, other): - """Whether two datasets are equal is solely determined by the checksum""" - if not isinstance(other, BaseDataset): - return False - return self.md5sum == other.md5sum - def __del__(self): """Close the connection of the client""" if self._client is not None: diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 255bcdcf..378dd65e 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -265,6 +265,17 @@ def to_json( return dataset_path + def cache(self, verify_checksum: bool = False): + """Cache the dataset to the cache directory. + + Args: + verify_checksum: Whether to verify the checksum of the dataset after caching. + """ + dst = super().cache() + if verify_checksum: + self.verify_checksum() + return dst + def _split_index_from_path(self, path: str) -> tuple[str, int | None]: """ Paths can have an additional index appended to them. 
@@ -287,3 +298,9 @@ def _repr_dict_(self) -> dict: """Utility function for pretty-printing to the command line and jupyter notebooks""" repr_dict = self.model_dump(exclude={"table", "zarr_md5sum_manifest"}) return repr_dict + + def __eq__(self, other): + """Whether two datasets are equal is solely determined by the checksum""" + if not isinstance(other, DatasetV1): + return False + return self.md5sum == other.md5sum diff --git a/polaris/experimental/_dataset_v2.py b/polaris/experimental/_dataset_v2.py index cdda80f1..f3d971b2 100644 --- a/polaris/experimental/_dataset_v2.py +++ b/polaris/experimental/_dataset_v2.py @@ -42,8 +42,8 @@ class DatasetV2(BaseDataset): """ version: ClassVar[Literal[2]] = 2 - _zarr_md5sum_manifest_path: str | None = PrivateAttr(None) - _md5sum: str | None = PrivateAttr(None) + _zarr_manifest_path: str | None = PrivateAttr(None) + _zarr_manifest_md5sum: str | None = PrivateAttr(None) # Redefine this to make it a required field zarr_root_path: str @@ -80,7 +80,7 @@ def _validate_v2_dataset_model(cls, m: "DatasetV2"): # Set the default cache dir if none and make sure it exists if m.cache_dir is None: - dataset_id = m._md5sum if m.has_md5sum else str(uuid.uuid4()) + dataset_id = m._zarr_manifest_md5sum if m.has_zarr_manifest_md5sum else str(uuid.uuid4()) m.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id m.cache_dir.mkdir(parents=True, exist_ok=True) @@ -121,36 +121,36 @@ def dtypes(self) -> dict[str, np.dtype]: @property def zarr_manifest_path(self) -> str: - if self._zarr_md5sum_manifest_path is None: + if self._zarr_manifest_path is None: zarr_manifest_path = generate_zarr_manifest(self.zarr_root_path, self.cache_dir) - self._zarr_md5sum_manifest_path = zarr_manifest_path + self._zarr_manifest_path = zarr_manifest_path - return self._zarr_md5sum_manifest_path + return self._zarr_manifest_path @computed_field @property - def md5sum(self) -> str: + def zarr_manifest_md5sum(self) -> str: """ Lazily compute the checksum once 
needed. The checksum of the DatasetV2 is computed from the Zarr Manifest and is _not_ deterministic. """ - if not self.has_md5sum: + if not self.has_zarr_manifest_md5sum: logger.info("Computing the checksum. This can be slow for large datasets.") - self.md5sum = calculate_file_md5(self.zarr_manifest_path) - return self._md5sum + self.zarr_manifest_md5sum = calculate_file_md5(self.zarr_manifest_path) + return self._zarr_manifest_md5sum - @md5sum.setter - def md5sum(self, value: str): + @zarr_manifest_md5sum.setter + def zarr_manifest_md5sum(self, value: str): """Set the checksum.""" if not re.fullmatch(r"^[a-f0-9]{32}$", value): raise ValueError("The checksum should be the 32-character hexdigest of a 128 bit MD5 hash.") - self._md5sum = value + self._zarr_manifest_md5sum = value @property - def has_md5sum(self) -> bool: - """Whether the md5sum for this class has been computed and stored.""" - return self._md5sum is not None + def has_zarr_manifest_md5sum(self) -> bool: + """Whether the md5sum for this dataset's zarr manifest has been computed and stored.""" + return self._zarr_manifest_md5sum is not None def get_data(self, row: int, col: str, adapters: dict[str, Adapter] | None = None) -> np.ndarray | Any: """Indexes the Zarr archive. @@ -243,20 +243,6 @@ def to_json( json.dump(serialized, f) return dataset_path - def cache(self) -> str: - """Caches the dataset by downloading all additional data for pointer columns to a local directory. - - Args: - verify_checksum: Whether to verify the checksum of the dataset after caching. - - Returns: - The path to the cache directory. - """ - # NOTE (cwognum): We don't support a deterministic checksum for the Dataset V2 yet, - # so verification doesn't make sense. 
See also: - # https://github.com/polaris-hub/polaris/issues/188 - super().cache(verify_checksum=False) - def _repr_dict_(self) -> dict: """Utility function for pretty-printing to the command line and jupyter notebooks""" repr_dict = self.model_dump(exclude={"zarr_md5sum_manifest"}) diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index c1ef2b5f..d0141152 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -65,33 +65,13 @@ def test_dataset_v2_load_to_memory(test_dataset_v2): assert d2 < d1 -def test_dataset_v2_checksum(test_dataset_v2): - # Make sure the `md5sum` is part of the model dump even if not initiated yet. - # This is important for uploads to the Hub. - assert test_dataset_v2._md5sum is None - assert "md5sum" in test_dataset_v2.model_dump() - - # (1) Without any changes, same hash - kwargs = test_dataset_v2.model_dump() - assert DatasetV2(**kwargs) == test_dataset_v2 - - # (2) With unimportant changes, same hash - kwargs["name"] = "changed" - kwargs["description"] = "changed" - kwargs["source"] = "https://changed.com" - assert DatasetV2(**kwargs) == test_dataset_v2 - - # (3) Without any changes, but different hash - dataset = DatasetV2(**kwargs) - dataset._md5sum = "invalid" - assert dataset != test_dataset_v2 - - def test_dataset_v2_serialization(test_dataset_v2, tmpdir): save_dir = tmpdir.join("save_dir") path = test_dataset_v2.to_json(save_dir) new_dataset = DatasetV2.from_json(path) - assert test_dataset_v2 == new_dataset + for i in range(5): + assert np.array_equal(new_dataset.get_data(i, "A"), test_dataset_v2.get_data(i, "A")) + assert np.array_equal(new_dataset.get_data(i, "B"), test_dataset_v2.get_data(i, "B")) def test_dataset_v2_caching(test_dataset_v2, tmpdir): @@ -257,7 +237,7 @@ def test_zarr_manifest(test_dataset_v2): assert len(df) == 204 # Assert the manifest hash is calculated - assert test_dataset_v2.md5sum is not None + assert test_dataset_v2.zarr_manifest_md5sum is not None # Add array to Zarr archive 
to change the number of chunks in the dataset root = zarr.open(test_dataset_v2.zarr_root_path, "a") From 7bf7ac83997fe1d639bc2df6a340afc1e537f252 Mon Sep 17 00:00:00 2001 From: cwognum Date: Wed, 11 Sep 2024 14:17:16 -0400 Subject: [PATCH 18/20] Fix missing import --- polaris/dataset/_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 1e8f9273..c95293ff 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -2,7 +2,7 @@ import uuid from hashlib import md5 from pathlib import Path -from typing import Any, ClassVar, List, Literal, Union +from typing import Any, ClassVar, List, Literal, Optional, Union import fsspec import numpy as np @@ -21,8 +21,8 @@ from polaris.utils.types import ( AccessType, HubOwner, - ZarrConflictResolution, TimeoutTypes, + ZarrConflictResolution, ) # Constants From 6d351221d6de15e8efa582197abe1fea3293020c Mon Sep 17 00:00:00 2001 From: cwognum Date: Wed, 11 Sep 2024 15:17:57 -0400 Subject: [PATCH 19/20] Added PR feedback --- polaris/benchmark/_base.py | 17 +++++------ polaris/dataset/_base.py | 39 +++++++++++------------- polaris/dataset/_competition_dataset.py | 4 +-- polaris/dataset/_dataset.py | 32 +++++++++++++------- polaris/dataset/_subset.py | 2 +- polaris/experimental/_dataset_v2.py | 40 ++++++++++++++++--------- 6 files changed, 76 insertions(+), 58 deletions(-) diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index 8f9edbd1..e73d7f35 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -161,7 +161,7 @@ def _validate_main_metric(cls, v): return v @model_validator(mode="after") - def _validate_split(cls, m: "BenchmarkSpecification"): + def _validate_split(self): """ Verifies that: 1) There are no empty test partitions @@ -170,7 +170,7 @@ def _validate_split(cls, m: "BenchmarkSpecification"): 4) There is no overlap between the train and test set 5) No row exists in the test set 
where all labels are missing/empty """ - split = m.split + split = self.split # Train partition can be empty (zero-shot) # Test partitions cannot be empty @@ -213,13 +213,13 @@ def _validate_split(cls, m: "BenchmarkSpecification"): raise InvalidBenchmarkError("The test set contains duplicate indices") # All indices are valid given the dataset - dataset = m.dataset + dataset = self.dataset if dataset is not None: max_i = len(dataset) if any(i < 0 or i >= max_i for i in chain(train_idx_list, full_test_idx_set)): raise InvalidBenchmarkError("The predefined split contains invalid indices") - return m + return self @field_validator("target_types") def _validate_target_types(cls, v, info: ValidationInfo): @@ -262,15 +262,14 @@ def _validate_target_types(cls, v, info: ValidationInfo): return v @model_validator(mode="after") - @classmethod - def _validate_model(cls, m: "BenchmarkSpecification"): + def _validate_model(self): """ Sets a default metric if missing. """ # Set a default main metric if not set yet - if m.main_metric is None: - m.main_metric = m.metrics[0] - return m + if self.main_metric is None: + self.main_metric = self.metrics[0] + return self @field_serializer("metrics", "main_metric") def _serialize_metrics(self, v): diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index b271901a..04c173b7 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -1,9 +1,8 @@ import abc import json from pathlib import Path -from typing import Any, Dict, List, MutableMapping, Optional, Union +from typing import Any, Dict, MutableMapping, Optional, Union -import fsspec import numpy as np import zarr from loguru import logger @@ -92,7 +91,7 @@ def _validate_adapters(cls, value): return {k: Adapter[v] if isinstance(v, str) else v for k, v in value.items()} @field_serializer("default_adapters") - def _serialize_adapters(self, value: List[Adapter]): + def _serialize_adapters(self, value: dict[str, Adapter]): """Serializes the adapters""" return {k: 
v.name for k, v in value.items()} @@ -104,26 +103,26 @@ def _serialize_paths(value): return value @model_validator(mode="after") - def _validate_base_dataset_model(cls, m: "BaseDataset"): + def _validate_base_dataset_model(self): # Verify that all annotations are for columns that exist - if any(k not in m.columns for k in m.annotations): + if any(k not in self.columns for k in self.annotations): raise InvalidDatasetError( - f"There are annotations for columns that do not exist. Columns: {m.columns}. Annotations: {list(m.annotations.keys())}" + f"There are annotations for columns that do not exist. Columns: {self.columns}. Annotations: {list(self.annotations.keys())}" ) # Verify that all adapters are for columns that exist - if any(k not in m.columns for k in m.default_adapters.keys()): + if any(k not in self.columns for k in self.default_adapters.keys()): raise InvalidDatasetError( - f"There are default adapters for columns that do not exist. Columns: {m.columns}. Adapters: {list(m.annotations.keys())}" + f"There are default adapters for columns that do not exist. Columns: {self.columns}. Adapters: {list(self.annotations.keys())}" ) # Set a default for missing annotations and convert strings to Modality - for c in m.columns: - if c not in m.annotations: - m.annotations[c] = ColumnAnnotation() - m.annotations[c].dtype = m.dtypes[c] + for c in self.columns: + if c not in self.annotations: + self.annotations[c] = ColumnAnnotation() + self.annotations[c].dtype = self.dtypes[c] - return m + return self @property def client(self): @@ -277,19 +276,15 @@ def upload_to_hub(self, access: AccessType = "private", owner: Union[HubOwner, s """Uploads the dataset to the Polaris Hub.""" raise NotImplementedError - @classmethod + @abc.abstractclassmethod def from_json(cls, path: str): - """Loads a dataset from a JSON file. - Overrides the method from the base class to remove the caching dir from the file to load from, - as that should be user dependent. 
+ """ + Loads a dataset from a JSON file. Args: - path: Loads a dataset specification from a JSON file. + path: The path to the JSON file to load the dataset from. """ - with fsspec.open(path, "r") as f: - data = json.load(f) - data.pop("cache_dir", None) - return cls.model_validate(data) + raise NotImplementedError @abc.abstractmethod def to_json( diff --git a/polaris/dataset/_competition_dataset.py b/polaris/dataset/_competition_dataset.py index 7c642e90..a7a9a532 100644 --- a/polaris/dataset/_competition_dataset.py +++ b/polaris/dataset/_competition_dataset.py @@ -13,8 +13,8 @@ class CompetitionDataset(DatasetV1): """ @model_validator(mode="after") - def _validate_model(cls, m: "CompetitionDataset"): + def _validate_model(self): """We reject the instantiation of competition datasets which leverage Zarr for the time being""" - if m.uses_zarr: + if self.uses_zarr: raise InvalidCompetitionError("Pointer columns are not currently supported in competitions.") diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index c95293ff..299edba5 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -78,25 +78,24 @@ def _validate_table(cls, v): return v @model_validator(mode="after") - @classmethod - def _validate_v1_dataset_model(cls, m: "DatasetV1"): + def _validate_v1_dataset_model(self): """Verifies some dependencies between properties""" - has_pointers = any(anno.is_pointer for anno in m.annotations.values()) - if has_pointers and m.zarr_root_path is None: + has_pointers = any(anno.is_pointer for anno in self.annotations.values()) + if has_pointers and self.zarr_root_path is None: raise InvalidDatasetError("A zarr_root_path needs to be specified when there are pointer columns") - if not has_pointers and m.zarr_root_path is not None: + if not has_pointers and self.zarr_root_path is not None: raise InvalidDatasetError( "The zarr_root_path should only be specified when there are pointer columns" ) # Set the default cache dir if none 
and make sure it exists - if m.cache_dir is None: - dataset_id = m._md5sum if m.has_md5sum else str(uuid.uuid4()) - m.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id - m.cache_dir.mkdir(parents=True, exist_ok=True) + if self.cache_dir is None: + dataset_id = self._md5sum if self.has_md5sum else str(uuid.uuid4()) + self.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id + self.cache_dir.mkdir(parents=True, exist_ok=True) - return m + return self def _compute_checksum(self) -> str: """Computes a hash of the dataset. @@ -209,6 +208,19 @@ def upload_to_hub( """ self.client.upload_dataset(self, access=access, owner=owner, timeout=timeout) + @classmethod + def from_json(cls, path: str): + """ + Loads a dataset from a JSON file. + + Args: + path: The path to the JSON file to load the dataset from. + """ + with fsspec.open(path, "r") as f: + data = json.load(f) + data.pop("cache_dir", None) + return cls.model_validate(data) + def to_json( self, destination: str, diff --git a/polaris/dataset/_subset.py b/polaris/dataset/_subset.py index 05cbee80..dfb95869 100644 --- a/polaris/dataset/_subset.py +++ b/polaris/dataset/_subset.py @@ -65,7 +65,7 @@ def __init__( indices: List[int | Sequence[int]], input_cols: List[str] | str, target_cols: List[str] | str, - adapters: List[Adapter] | None = None, + adapters: dict[str, Adapter] | None = None, featurization_fn: Callable | None = None, hide_targets: bool = False, ): diff --git a/polaris/experimental/_dataset_v2.py b/polaris/experimental/_dataset_v2.py index f3d971b2..dc34aff5 100644 --- a/polaris/experimental/_dataset_v2.py +++ b/polaris/experimental/_dataset_v2.py @@ -49,42 +49,41 @@ class DatasetV2(BaseDataset): zarr_root_path: str @model_validator(mode="after") - @classmethod - def _validate_v2_dataset_model(cls, m: "DatasetV2"): + def _validate_v2_dataset_model(self): """Verifies some dependencies between properties""" # Since the keys for subgroups are not ordered, we have no easy way to index 
these groups. # Any subgroup should therefore have a special array that defines the index for that group. - for group in m.zarr_root.group_keys(): - if _INDEX_ARRAY_KEY not in m.zarr_root[group].array_keys(): + for group in self.zarr_root.group_keys(): + if _INDEX_ARRAY_KEY not in self.zarr_root[group].array_keys(): raise InvalidDatasetError(f"Group {group} does not have an index array.") - index_arr = m.zarr_root[group][_INDEX_ARRAY_KEY] - if len(index_arr) != len(m.zarr_root[group]) - 1: + index_arr = self.zarr_root[group][_INDEX_ARRAY_KEY] + if len(index_arr) != len(self.zarr_root[group]) - 1: raise InvalidDatasetError( f"Length of index array for group {group} does not match the size of the group." ) - if any(x not in m.zarr_root[group] for x in index_arr): + if any(x not in self.zarr_root[group] for x in index_arr): raise InvalidDatasetError( f"Keys of index array for group {group} does not match the group members." ) # Check the structure of the Zarr archive # All arrays or groups in the root should have the same length. 
- lengths = {len(m.zarr_root[k]) for k in m.zarr_root.array_keys()} - lengths.update({len(m.zarr_root[k][_INDEX_ARRAY_KEY]) for k in m.zarr_root.group_keys()}) + lengths = {len(self.zarr_root[k]) for k in self.zarr_root.array_keys()} + lengths.update({len(self.zarr_root[k][_INDEX_ARRAY_KEY]) for k in self.zarr_root.group_keys()}) if len(lengths) > 1: raise InvalidDatasetError( f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}" ) # Set the default cache dir if none and make sure it exists - if m.cache_dir is None: - dataset_id = m._zarr_manifest_md5sum if m.has_zarr_manifest_md5sum else str(uuid.uuid4()) - m.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id - m.cache_dir.mkdir(parents=True, exist_ok=True) + if self.cache_dir is None: + dataset_id = self._zarr_manifest_md5sum if self.has_zarr_manifest_md5sum else str(uuid.uuid4()) + self.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id + self.cache_dir.mkdir(parents=True, exist_ok=True) - return m + return self @property def n_rows(self) -> int: @@ -192,6 +191,19 @@ def upload_to_hub(self, access: AccessType = "private", owner: HubOwner | str | # to do it simultaneously with a PR on the Hub side. raise NotImplementedError + @classmethod + def from_json(cls, path: str): + """ + Loads a dataset from a JSON file. + + Args: + path: The path to the JSON file to load the dataset from. 
+ """ + with fsspec.open(path, "r") as f: + data = json.load(f) + data.pop("cache_dir", None) + return cls.model_validate(data) + def to_json( self, destination: str, From 8ae8e5ec7139d5ee20d7fe0c458a56b751ce9b75 Mon Sep 17 00:00:00 2001 From: cwognum Date: Wed, 11 Sep 2024 15:21:10 -0400 Subject: [PATCH 20/20] Update decorators --- polaris/dataset/_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index 04c173b7..7e5e6133 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -276,7 +276,8 @@ def upload_to_hub(self, access: AccessType = "private", owner: Union[HubOwner, s """Uploads the dataset to the Polaris Hub.""" raise NotImplementedError - @abc.abstractclassmethod + @classmethod + @abc.abstractmethod def from_json(cls, path: str): """ Loads a dataset from a JSON file.