Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ dataset: Dataset = get_dataset("rel-amazon", download=True)
<details>
<summary>Details on downloading and caching behavior.</summary>

RelBench datasets (and tasks) are cached to disk (usually at `~/.cache/relbench`). If not present in cache, `download=True` downloads the data, verifies it against the known hash, and caches it. If present, `download=True` performs the verification and avoids downloading if verification succeeds. This is the recommended way.
RelBench datasets (and tasks) are cached to disk (usually at `~/.cache/relbench`; set the `RELBENCH_CACHE_DIR` environment variable to use a different location). If the data is not present in the cache, `download=True` downloads it, verifies it against the known hash, and caches it. If it is present, `download=True` performs the verification and skips the download if verification succeeds. This is the recommended way.

`download=False` uses the cached data without verification, if present, or processes and caches the data from scratch / raw sources otherwise.
</details>
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "relbench"
version = "1.1.0_dev_20250721"
version = "1.1.0_dev_20251121"
description = "RelBench: Relational Deep Learning Benchmark"
authors = [{name = "RelBench Team", email = "relbench@cs.stanford.edu"}]
readme = "README.md"
Expand Down
6 changes: 4 additions & 2 deletions relbench/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
stack,
trial,
)
from relbench.utils import get_relbench_cache_dir

dataset_registry = {}

Expand All @@ -28,6 +29,7 @@
path=pooch.os_cache("relbench"),
base_url="https://relbench.stanford.edu/download/",
registry=hashes,
env="RELBENCH_CACHE_DIR",
)


Expand All @@ -50,7 +52,7 @@ def register_dataset(
can pass `cache_dir` as a keyword argument in `kwargs`.
"""

cache_dir = f"{pooch.os_cache('relbench')}/{name}"
cache_dir = f"{get_relbench_cache_dir()}/{name}"
kwargs = {"cache_dir": cache_dir, **kwargs}
dataset_registry[name] = (cls, args, kwargs)

Expand Down Expand Up @@ -110,7 +112,7 @@ def get_dataset(name: str, download=True) -> Dataset:
cls, args, kwargs = (
mimic.MimicDataset,
(),
{"cache_dir": f"{pooch.os_cache('relbench')}/{name}"},
{"cache_dir": f"{get_relbench_cache_dir()}/{name}"},
)
else:
cls, args, kwargs = dataset_registry[name]
Expand Down
5 changes: 4 additions & 1 deletion relbench/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import pkgutil
from collections import defaultdict
from functools import lru_cache
Expand All @@ -20,6 +21,7 @@
stack,
trial,
)
from relbench.utils import get_relbench_cache_dir

task_registry = defaultdict(dict)

Expand All @@ -30,6 +32,7 @@
path=pooch.os_cache("relbench"),
base_url="https://relbench.stanford.edu/download/",
registry=hashes,
env="RELBENCH_CACHE_DIR",
)


Expand All @@ -54,7 +57,7 @@ def register_task(
can pass `cache_dir` as a keyword argument in `kwargs`.
"""

cache_dir = f"{pooch.os_cache('relbench')}/{dataset_name}/tasks/{task_name}"
cache_dir = f"{get_relbench_cache_dir()}/{dataset_name}/tasks/{task_name}"
kwargs = {"cache_dir": cache_dir, **kwargs}
task_registry[dataset_name][task_name] = (cls, args, kwargs)

Expand Down
4 changes: 4 additions & 0 deletions relbench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,7 @@ def clean_datetime(df: pd.DataFrame, col: str) -> pd.DataFrame:
f"{percentage_removed:.2f}%"
)
return df


def get_relbench_cache_dir() -> str:
    """Return the root directory under which RelBench caches datasets and tasks.

    The ``RELBENCH_CACHE_DIR`` environment variable takes precedence when it is
    set to a non-empty value. Otherwise the default returned by
    ``pooch.os_cache("relbench")`` is used.

    The path is resolved the same way pooch resolves the ``env=`` override
    passed to ``pooch.create``: empty values are ignored and ``~`` is expanded.
    This keeps the ``cache_dir`` paths built from this function in agreement
    with the directory pooch actually downloads into.

    Returns:
        The cache directory as a ``str``. The fallback from
        ``pooch.os_cache`` (a ``pathlib.Path``) is converted, so the declared
        return type always holds.
    """
    path = os.getenv("RELBENCH_CACHE_DIR") or pooch.os_cache("relbench")
    # pooch.os_cache returns a pathlib.Path, and pooch expands "~" in env-var
    # paths; normalize both cases so callers always get an expanded str.
    return os.path.expanduser(str(path))