diff --git a/README.md b/README.md index 1cca8c78..c4efa6ed 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ dataset: Dataset = get_dataset("rel-amazon", download=True)
Details on downloading and caching behavior. -RelBench datasets (and tasks) are cached to disk (usually at `~/.cache/relbench`). If not present in cache, `download=True` downloads the data, verifies it against the known hash, and caches it. If present, `download=True` performs the verification and avoids downloading if verification succeeds. This is the recommended way. +RelBench datasets (and tasks) are cached to disk (usually at `~/.cache/relbench`; the location can be changed by setting the `RELBENCH_CACHE_DIR` environment variable). If not present in cache, `download=True` downloads the data, verifies it against the known hash, and caches it. If present, `download=True` performs the verification and avoids downloading if verification succeeds. This is the recommended way. `download=False` uses the cached data without verification, if present, or processes and caches the data from scratch / raw sources otherwise.
diff --git a/pyproject.toml b/pyproject.toml index 6586ded6..874a7ea6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "relbench" -version = "1.1.0_dev_20250721" +version = "1.1.0_dev_20251121" description = "RelBench: Relational Deep Learning Benchmark" authors = [{name = "RelBench Team", email = "relbench@cs.stanford.edu"}] readme = "README.md" diff --git a/relbench/datasets/__init__.py b/relbench/datasets/__init__.py index 9da0bf29..79c37b31 100644 --- a/relbench/datasets/__init__.py +++ b/relbench/datasets/__init__.py @@ -18,6 +18,7 @@ stack, trial, ) +from relbench.utils import get_relbench_cache_dir dataset_registry = {} @@ -28,6 +29,7 @@ path=pooch.os_cache("relbench"), base_url="https://relbench.stanford.edu/download/", registry=hashes, + env="RELBENCH_CACHE_DIR", ) @@ -50,7 +52,7 @@ def register_dataset( can pass `cache_dir` as a keyword argument in `kwargs`. """ - cache_dir = f"{pooch.os_cache('relbench')}/{name}" + cache_dir = f"{get_relbench_cache_dir()}/{name}" kwargs = {"cache_dir": cache_dir, **kwargs} dataset_registry[name] = (cls, args, kwargs) @@ -110,7 +112,7 @@ def get_dataset(name: str, download=True) -> Dataset: cls, args, kwargs = ( mimic.MimicDataset, (), - {"cache_dir": f"{pooch.os_cache('relbench')}/{name}"}, + {"cache_dir": f"{get_relbench_cache_dir()}/{name}"}, ) else: cls, args, kwargs = dataset_registry[name] diff --git a/relbench/tasks/__init__.py b/relbench/tasks/__init__.py index d50cf69e..a7417f6e 100644 --- a/relbench/tasks/__init__.py +++ b/relbench/tasks/__init__.py @@ -1,4 +1,5 @@ import json +import os import pkgutil from collections import defaultdict from functools import lru_cache @@ -20,6 +21,7 @@ stack, trial, ) +from relbench.utils import get_relbench_cache_dir task_registry = defaultdict(dict) @@ -30,6 +32,7 @@ path=pooch.os_cache("relbench"), base_url="https://relbench.stanford.edu/download/", registry=hashes, + env="RELBENCH_CACHE_DIR", ) @@ -54,7 
+57,7 @@ def register_task( can pass `cache_dir` as a keyword argument in `kwargs`. """ - cache_dir = f"{pooch.os_cache('relbench')}/{dataset_name}/tasks/{task_name}" + cache_dir = f"{get_relbench_cache_dir()}/{dataset_name}/tasks/{task_name}" kwargs = {"cache_dir": cache_dir, **kwargs} task_registry[dataset_name][task_name] = (cls, args, kwargs) diff --git a/relbench/utils.py b/relbench/utils.py index dfc20c90..acc75139 100644 --- a/relbench/utils.py +++ b/relbench/utils.py @@ -67,3 +67,10 @@ def clean_datetime(df: pd.DataFrame, col: str) -> pd.DataFrame: f"{percentage_removed:.2f}%" ) return df + + +def get_relbench_cache_dir() -> str: + """Return the RelBench cache root: `RELBENCH_CACHE_DIR` when set and non-empty, otherwise `pooch.os_cache("relbench")`.""" + # Expand "~" and coerce to `str` (pooch.os_cache yields a Path) so the result honors the `-> str` annotation and lines up with the directory pooch resolves for the same `env`. + cache_dir = os.getenv("RELBENCH_CACHE_DIR") or pooch.os_cache("relbench") + return os.path.expanduser(str(cache_dir))