From c633d8f0d0fa2266bae452104bf1af2f220f0f49 Mon Sep 17 00:00:00 2001 From: Jonas De Schouwer Date: Sat, 1 Nov 2025 14:06:01 -0700 Subject: [PATCH 1/4] added driver-race-compete task to rel-f1 dataset --- relbench/tasks/__init__.py | 1 + relbench/tasks/f1.py | 50 ++++++++++++++++++++++++++++++++++++-- relbench/tasks/hashes.json | 1 + 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/relbench/tasks/__init__.py b/relbench/tasks/__init__.py index 804109f2..d50cf69e 100644 --- a/relbench/tasks/__init__.py +++ b/relbench/tasks/__init__.py @@ -199,6 +199,7 @@ def get_task(dataset_name: str, task_name: str, download=False) -> BaseTask: register_task("rel-f1", "driver-position", f1.DriverPositionTask) register_task("rel-f1", "driver-dnf", f1.DriverDNFTask) register_task("rel-f1", "driver-top3", f1.DriverTop3Task) +register_task("rel-f1", "driver-race-compete", f1.DriverRaceCompeteTask) register_task( "rel-f1", "results-position", diff --git a/relbench/tasks/f1.py b/relbench/tasks/f1.py index 2896f693..39938368 100644 --- a/relbench/tasks/f1.py +++ b/relbench/tasks/f1.py @@ -1,8 +1,8 @@ import duckdb import pandas as pd -from relbench.base import Database, EntityTask, Table, TaskType -from relbench.metrics import accuracy, average_precision, f1, mae, r2, rmse, roc_auc +from relbench.base import Database, EntityTask, RecommendationTask, Table, TaskType +from relbench.metrics import accuracy, average_precision, f1, link_prediction_map, link_prediction_precision, link_prediction_recall, mae, r2, rmse, roc_auc class DriverPositionTask(EntityTask): @@ -160,3 +160,49 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab pkey_col=None, time_col=self.time_col, ) + +class DriverRaceCompeteTask(RecommendationTask): + r"""Predict in which races a driver will compete in the next 1 year.""" + + task_type = TaskType.LINK_PREDICTION + src_entity_col = "driverId" + src_entity_table = "drivers" + dst_entity_col = "raceId" + dst_entity_table = 
"races" + target_col = "raceId" + time_col = "date" + timedelta = pd.Timedelta(days=365) + metrics = [link_prediction_precision, link_prediction_recall, link_prediction_map] + eval_k = 10 + + def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table: + timestamp_df = pd.DataFrame({"timestamp": timestamps}) + results = db.table_dict["results"].df + + df = duckdb.sql( + f""" + SELECT + t.timestamp as date, + re.driverId as driverId, + LIST(DISTINCT re.raceId) as raceId + FROM + timestamp_df t + LEFT JOIN + results re + ON + re.date <= t.timestamp + INTERVAL '{self.timedelta}' + and re.date > t.timestamp + GROUP BY t.timestamp, re.driverId + ; + """ + ).df() + + return Table( + df=df, + fkey_col_to_pkey_table={ + self.src_entity_col: self.src_entity_table, + self.dst_entity_col: self.dst_entity_table, + }, + pkey_col=None, + time_col=self.time_col, + ) diff --git a/relbench/tasks/hashes.json b/relbench/tasks/hashes.json index 67480633..958f4290 100644 --- a/relbench/tasks/hashes.json +++ b/relbench/tasks/hashes.json @@ -16,6 +16,7 @@ "rel-f1/tasks/driver-dnf.zip": "58553e0ecebff60e9f8c12202ae2d1109b206b6f8a1ec0589af2540ac2982178", "rel-f1/tasks/driver-position.zip": "775b28a51604169539bbe712a2f0d15158c112bc6abf316cdd0995087a7ae03e", "rel-f1/tasks/driver-top3.zip": "1a16abf993cbe58524054cf710bada8538e7c75d6f8f388c0b137ab3575a9a47", + "rel-f1/tasks/driver-race-compete.zip": "2485543002c31baedb4f8f6739420a508b940a58d25239f5412284bb0ff9c2cc", "rel-hm/tasks/item-sales.zip": "92a2c71ebd6dc5ab67c14c33a3a45c9ccafee0e5f0c7c698871a88f74e8a0867", "rel-hm/tasks/user-churn.zip": "2ef2030e308c57b5bcb4b2df1458cc3c21b7286e0b658d222be010b4f90e9265", "rel-hm/tasks/user-item-purchase.zip": "c8a8bb98e1b94bb612cc2694676d9d53a29743f11f19b8829ded5a725b7afab9", From 3deedcfade9e8d64626a62cc1f4a2db413190e15 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 1 Nov 2025 21:11:53 +0000 Subject: [PATCH 
2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- relbench/tasks/f1.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/relbench/tasks/f1.py b/relbench/tasks/f1.py index 39938368..47e3fa03 100644 --- a/relbench/tasks/f1.py +++ b/relbench/tasks/f1.py @@ -2,7 +2,18 @@ import pandas as pd from relbench.base import Database, EntityTask, RecommendationTask, Table, TaskType -from relbench.metrics import accuracy, average_precision, f1, link_prediction_map, link_prediction_precision, link_prediction_recall, mae, r2, rmse, roc_auc +from relbench.metrics import ( + accuracy, + average_precision, + f1, + link_prediction_map, + link_prediction_precision, + link_prediction_recall, + mae, + r2, + rmse, + roc_auc, +) class DriverPositionTask(EntityTask): @@ -161,6 +172,7 @@ def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Tab time_col=self.time_col, ) + class DriverRaceCompeteTask(RecommendationTask): r"""Predict in which races a driver will compete in the next 1 year.""" From 285985b35310bef7129309b58287f8137bd53a68 Mon Sep 17 00:00:00 2001 From: Jonas De Schouwer Date: Fri, 21 Nov 2025 17:14:03 -0800 Subject: [PATCH 3/4] setting RelBench cache dir based on RELBENCH_CACHE_DIR env variable --- README.md | 2 +- pyproject.toml | 2 +- relbench/datasets/__init__.py | 6 ++++-- relbench/tasks/__init__.py | 5 ++++- relbench/utils.py | 3 +++ 5 files changed, 13 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 1cca8c78..c4efa6ed 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ dataset: Dataset = get_dataset("rel-amazon", download=True)
Details on downloading and caching behavior. -RelBench datasets (and tasks) are cached to disk (usually at `~/.cache/relbench`). If not present in cache, `download=True` downloads the data, verifies it against the known hash, and caches it. If present, `download=True` performs the verification and avoids downloading if verification succeeds. This is the recommended way. +RelBench datasets (and tasks) are cached to disk (usually at `~/.cache/relbench`; the location can be changed by setting the `RELBENCH_CACHE_DIR` environment variable). If not present in cache, `download=True` downloads the data, verifies it against the known hash, and caches it. If present, `download=True` performs the verification and avoids downloading if verification succeeds. This is the recommended way. `download=False` uses the cached data without verification, if present, or processes and caches the data from scratch / raw sources otherwise.
diff --git a/pyproject.toml b/pyproject.toml index 6586ded6..874a7ea6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "relbench" -version = "1.1.0_dev_20250721" +version = "1.1.0_dev_20251121" description = "RelBench: Relational Deep Learning Benchmark" authors = [{name = "RelBench Team", email = "relbench@cs.stanford.edu"}] readme = "README.md" diff --git a/relbench/datasets/__init__.py b/relbench/datasets/__init__.py index 9da0bf29..3ba32836 100644 --- a/relbench/datasets/__init__.py +++ b/relbench/datasets/__init__.py @@ -18,6 +18,7 @@ stack, trial, ) +from relbench.utils import get_relbench_cache_dir dataset_registry = {} @@ -28,6 +29,7 @@ path=pooch.os_cache("relbench"), base_url="https://relbench.stanford.edu/download/", registry=hashes, + env="RELBENCH_CACHE_DIR" ) @@ -50,7 +52,7 @@ def register_dataset( can pass `cache_dir` as a keyword argument in `kwargs`. """ - cache_dir = f"{pooch.os_cache('relbench')}/{name}" + cache_dir = f"{get_relbench_cache_dir()}/{name}" kwargs = {"cache_dir": cache_dir, **kwargs} dataset_registry[name] = (cls, args, kwargs) @@ -110,7 +112,7 @@ def get_dataset(name: str, download=True) -> Dataset: cls, args, kwargs = ( mimic.MimicDataset, (), - {"cache_dir": f"{pooch.os_cache('relbench')}/{name}"}, + {"cache_dir": f"{get_relbench_cache_dir()}/{name}"}, ) else: cls, args, kwargs = dataset_registry[name] diff --git a/relbench/tasks/__init__.py b/relbench/tasks/__init__.py index d50cf69e..3855f55b 100644 --- a/relbench/tasks/__init__.py +++ b/relbench/tasks/__init__.py @@ -3,6 +3,7 @@ from collections import defaultdict from functools import lru_cache from typing import List +import os import pooch @@ -20,6 +21,7 @@ stack, trial, ) +from relbench.utils import get_relbench_cache_dir task_registry = defaultdict(dict) @@ -30,6 +32,7 @@ path=pooch.os_cache("relbench"), base_url="https://relbench.stanford.edu/download/", registry=hashes, + env="RELBENCH_CACHE_DIR" ) @@ 
-54,7 +57,7 @@ def register_task( can pass `cache_dir` as a keyword argument in `kwargs`. """ - cache_dir = f"{pooch.os_cache('relbench')}/{dataset_name}/tasks/{task_name}" + cache_dir = f"{get_relbench_cache_dir()}/{dataset_name}/tasks/{task_name}" kwargs = {"cache_dir": cache_dir, **kwargs} task_registry[dataset_name][task_name] = (cls, args, kwargs) diff --git a/relbench/utils.py b/relbench/utils.py index dfc20c90..92957355 100644 --- a/relbench/utils.py +++ b/relbench/utils.py @@ -67,3 +67,6 @@ def clean_datetime(df: pd.DataFrame, col: str) -> pd.DataFrame: f"{percentage_removed:.2f}%" ) return df + +def get_relbench_cache_dir() -> str: + return os.getenv("RELBENCH_CACHE_DIR") or pooch.os_cache("relbench") \ No newline at end of file From 537be6ec8105aab502dd79cf45922d3642e7d260 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 22 Nov 2025 01:49:31 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- relbench/datasets/__init__.py | 2 +- relbench/tasks/__init__.py | 4 ++-- relbench/utils.py | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/relbench/datasets/__init__.py b/relbench/datasets/__init__.py index 3ba32836..79c37b31 100644 --- a/relbench/datasets/__init__.py +++ b/relbench/datasets/__init__.py @@ -29,7 +29,7 @@ path=pooch.os_cache("relbench"), base_url="https://relbench.stanford.edu/download/", registry=hashes, - env="RELBENCH_CACHE_DIR" + env="RELBENCH_CACHE_DIR", ) diff --git a/relbench/tasks/__init__.py b/relbench/tasks/__init__.py index 3855f55b..a7417f6e 100644 --- a/relbench/tasks/__init__.py +++ b/relbench/tasks/__init__.py @@ -1,9 +1,9 @@ import json +import os import pkgutil from collections import defaultdict from functools import lru_cache from typing import List -import os import pooch @@ -32,7 +32,7 @@ path=pooch.os_cache("relbench"), 
base_url="https://relbench.stanford.edu/download/", registry=hashes, - env="RELBENCH_CACHE_DIR" + env="RELBENCH_CACHE_DIR", ) diff --git a/relbench/utils.py b/relbench/utils.py index 92957355..acc75139 100644 --- a/relbench/utils.py +++ b/relbench/utils.py @@ -68,5 +68,6 @@ def clean_datetime(df: pd.DataFrame, col: str) -> pd.DataFrame: ) return df + def get_relbench_cache_dir() -> str: - return os.getenv("RELBENCH_CACHE_DIR") or pooch.os_cache("relbench") \ No newline at end of file + return os.getenv("RELBENCH_CACHE_DIR") or pooch.os_cache("relbench")