Unified API for loading 29 vision datasets and extracting embeddings from any PyTorch model. Handles torchvision, HuggingFace, and custom ImageFolder datasets through a single interface. Ships with 12 built-in embedding models (CLIP, DINOv2, DINOv3) and supports async extraction on secondary GPUs.
pip install embedata # core (torchvision datasets)
pip install embedata[hf] # + HuggingFace datasets (ImageNet, CUB, RESISC45)

From source:
git clone https://github.com/ayghri/embedata.git && cd embedata
pip install -e ".[hf]"

from embedata import list_datasets, get_dataset, get_datasets, get_dataloaders
print(list_datasets()) # all 29 datasets
train_ds = get_dataset("cifar10", split="train", root_dir="./data")
train_ds, val_ds = get_datasets("cifar10", root_dir="./data")
train_loader, val_loader = get_dataloaders("cifar10", batch_size=128, root_dir="./data")

Most datasets auto-download on first use. Datasets that require manual setup (Kaggle downloads, frame extraction, etc.) document their steps via get_spec("dataset_name").notes.
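Name-based calls such as get_dataset("cifar10") and get_spec("dataset_name").notes are commonly backed by a name-keyed registry. The sketch below is a hypothetical illustration of that pattern and is not embedata's implementation: ToySpec, toy_register, toy_list_datasets, and toy_get_spec are invented names. Only the decorator shape (a name plus notes, wrapping a function that takes a transform and a data path and returns a train/val pair) follows the @register example in this README.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class ToySpec:
    """Hypothetical record kept for each registered dataset."""
    name: str
    loader: Callable  # (transform, data_path) -> (train_ds, val_ds)
    notes: str = ""


_TOY_REGISTRY: Dict[str, ToySpec] = {}


def toy_register(name: str, notes: str = "") -> Callable:
    """Decorator that stores a loader function under `name`."""
    def decorator(fn: Callable) -> Callable:
        if name in _TOY_REGISTRY:
            raise ValueError(f"dataset {name!r} is already registered")
        _TOY_REGISTRY[name] = ToySpec(name=name, loader=fn, notes=notes)
        return fn  # leave the decorated function usable as-is
    return decorator


def toy_list_datasets() -> List[str]:
    """Return registered names in sorted order."""
    return sorted(_TOY_REGISTRY)


def toy_get_spec(name: str) -> ToySpec:
    """Look up a spec, failing with a readable message for unknown names."""
    try:
        return _TOY_REGISTRY[name]
    except KeyError:
        known = toy_list_datasets()
        raise KeyError(f"unknown dataset {name!r}; known: {known}") from None


@toy_register("toy_digits", notes="No setup needed.")
def _toy_digits(transform, data_path):
    # Stand-in datasets: plain lists of (image, label) pairs.
    train = [(transform(i), i % 2) for i in range(4)]
    val = [(transform(i), i % 2) for i in range(2)]
    return train, val


spec = toy_get_spec("toy_digits")
train_ds, val_ds = spec.loader(lambda x: x * 10, "./data")
```

Keeping the notes next to the loader is what would let a single lookup return both the data and its setup instructions.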
Pass any PyTorch model that maps images to feature vectors. Embeddings are saved as .npy files:
import torch
from embedata import get_dataloaders, extract
model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14")
model.eval().to("cuda:0")
train_loader, val_loader = get_dataloaders("cifar10", batch_size=256, root_dir="./data")
extract(train_loader, model, device=torch.device("cuda:0"),
        save_dir="./representations/cifar10/dinov2", suffix="train")
extract(val_loader, model, device=torch.device("cuda:0"),
        save_dir="./representations/cifar10/dinov2", suffix="val")

Load them back as a PyTorch Dataset:
from embedata import load_embeddings
ds = load_embeddings("cifar10", "dinov2", repr_dir="./representations", split="train")
feat, label = ds[0]

For on-the-fly extraction without disk I/O, EmbeddingDataLoader runs a model on a secondary device in a background thread:
from embedata import EmbeddingDataLoader, get_dataset
dataset = get_dataset("cifar10", split="train", root_dir="./data")
loader = EmbeddingDataLoader(dataset, model, device="cuda:1", batch_size=256, prefetch=2)
for embeddings, labels in loader:
    ...

The package includes 12 pre-registered models accessible via load_model: eight CLIP variants (clipRN50, clipRN101, clipRN50x4, clipRN50x16, clipRN50x64, clipvitB32, clipvitB16, clipvitL14), DINOv2 ViT-g/14 (dinov2), and three DINOv3 variants (dinov3s, dinov3b, dinov3l).
from embedata import list_models, load_model
model, preprocess = load_model("clipvitL14", device="cuda:0", models_dir="./models")

The returned preprocess is the model's image transform -- pass it to get_dataloaders or get_dataset. For DINOv2, preprocess is None; use get_default_transforms() instead.
CLIP models require pip install git+https://github.com/openai/CLIP.git. DINOv3 models require pip install transformers.
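For intuition about the background-thread extraction that EmbeddingDataLoader performs, here is a minimal producer/consumer sketch using only the standard library. It is an illustration under stated assumptions, not embedata's code: prefetch_embeddings and the embed callable are hypothetical stand-ins, and the model is replaced by a plain Python function so the example runs without a GPU.

```python
import queue
import threading
from typing import Callable, Iterable, Iterator, List, Tuple

_DONE = object()  # sentinel marking the end of the stream


def prefetch_embeddings(
    batches: Iterable[Tuple[List[float], List[int]]],
    embed: Callable[[List[float]], List[float]],
    prefetch: int = 2,
) -> Iterator[Tuple[List[float], List[int]]]:
    """Yield (embeddings, labels) while a worker thread embeds ahead."""
    buffer: queue.Queue = queue.Queue(maxsize=prefetch)

    def worker() -> None:
        try:
            for inputs, labels in batches:
                # put() blocks when the buffer is full, so the queue never
                # holds more than `prefetch` finished batches.
                buffer.put((embed(inputs), labels))
            buffer.put(_DONE)
        except Exception as exc:  # forward errors to the consumer
            buffer.put(exc)

    # Daemon thread: it will not keep the process alive if the consumer
    # stops iterating early.
    threading.Thread(target=worker, daemon=True).start()

    while True:
        item = buffer.get()
        if item is _DONE:
            return
        if isinstance(item, Exception):
            raise item
        yield item


toy_batches = [([1.0, 2.0], [0, 1]), ([3.0], [1])]
double = lambda xs: [2 * x for x in xs]  # stand-in for a model forward pass
results = list(prefetch_embeddings(toy_batches, double, prefetch=2))
```

In a real pipeline the embed step would presumably be a forward pass on the secondary device. The bounded queue is what keeps memory use predictable when the consumer is slower than the producer.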
Register your own dataset loader with the @register decorator. The function receives a torchvision transform and a data path, and returns (train_dataset, val_dataset):
from embedata import register
@register("my_dataset", notes="Setup instructions shown by get_spec()")
def _my_dataset(transform, data_path):
    train_ds = ...
    val_ds = ...
    return train_ds, val_ds

Three datasets (ImageNet, CUB, RESISC45) load via HuggingFace and support streaming mode (requires pip install embedata[hf]). ImageNet-1k is gated and requires huggingface-cli login.
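HuggingFace streaming datasets are iterable rather than indexable, so previewing a few samples means iterating instead of calling ds[0]. The sketch below assumes embedata's streaming mode yields samples the same way. fake_stream and take are hypothetical stand-ins, so the example runs without downloading anything.

```python
from itertools import islice
from typing import Iterable, Iterator, List, Tuple


def fake_stream() -> Iterator[Tuple[str, int]]:
    """Stand-in for a streamed dataset: yields (image, label) pairs lazily."""
    for i in range(1_000_000):
        yield f"image_{i}", i % 1000


def take(stream: Iterable, n: int) -> List:
    """Materialize only the first n samples of a stream."""
    return list(islice(stream, n))


# Only three items are ever generated; the rest of the stream is untouched.
preview = take(fake_stream(), 3)
```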
train_ds, val_ds = get_datasets("imagenet", streaming=True)

Some datasets require manual download and preparation before use. Python prepare scripts are bundled with the package:
python -m embedata.prepare.birdsnap --root_dir ROOT_DIR
python -m embedata.prepare.fer2013 --root_dir ROOT_DIR
python -m embedata.prepare.ucf101 --root_dir ROOT_DIR [--download]
python -m embedata.prepare.hatefulmemes --root_dir ROOT_DIR

Shell scripts for cars, eurosat, and sun397 are included under embedata/prepare/. All scripts expect raw data under ROOT_DIR/datasets/{dataset_name}/ and write prepared splits to the same location.
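Before running a prepare script, it can help to confirm that the raw download sits where the scripts expect it. The helpers below are a hypothetical sketch and are not part of embedata. They only encode the ROOT_DIR/datasets/{dataset_name}/ layout described above.

```python
from pathlib import Path


def raw_data_dir(root_dir: str, dataset_name: str) -> Path:
    """Build ROOT_DIR/datasets/{dataset_name}/ as described above."""
    return Path(root_dir) / "datasets" / dataset_name


def check_raw_data(root_dir: str, dataset_name: str) -> Path:
    """Return the raw-data directory, or raise if it is missing or empty."""
    path = raw_data_dir(root_dir, dataset_name)
    if not path.is_dir():
        raise FileNotFoundError(f"expected raw data under {path}")
    if not any(path.iterdir()):
        raise FileNotFoundError(f"{path} exists but is empty")
    return path
```

For example, calling check_raw_data("./data", "fer2013") before python -m embedata.prepare.fer2013 --root_dir ./data would report a misplaced download with a clear path in the error message.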
| Dataset | Train | Val/Test | Classes | Size | Source | Notes |
|---|---|---|---|---|---|---|
| aircraft | 6,667 | 3,333 | 100 | variable | FGVCAircraft | trainval / test |
| birdsnap | ~25,000 | ~24,829 | 500 | variable | ImageFolder | manual download + prepare |
| caltech101 | ~3,060 | ~5,587 | 101 | variable | Caltech101 | 30/class for train |
| cars | 8,144 | 8,041 | 196 | variable | StanfordCars | Kaggle download |
| cifar10 | 50,000 | 10,000 | 10 | 32x32 | CIFAR10 | auto-download |
| cifar100 | 50,000 | 10,000 | 100 | 32x32 | CIFAR100 | auto-download |
| clevr | 70,000 | 15,000 | 11 | 320x240 | CLEVRClassification | count 0-10 |
| country211 | 42,200 | 21,100 | 211 | variable | Country211 | train+valid / test |
| cub | 5,994 | 5,794 | 200 | variable | HuggingFace | CUB-200-2011 |
| dtd | 3,760 | 1,880 | 47 | variable | DTD | train+val / test |
| eurosat | 10,000 | 5,000 | 10 | 64x64 | EuroSAT | 1k+500 per class |
| fashionmnist | 60,000 | 10,000 | 10 | 28x28 | FashionMNIST | auto-download |
| fer2013 | 28,709 | 3,589 | 7 | 48x48 | FER2013 | Kaggle + prepare |
| flowers | 2,040 | 6,149 | 102 | variable | Flowers102 | train+val / test |
| food101 | 75,750 | 25,250 | 101 | variable | Food101 | auto-download |
| gtsrb | 26,640 | 12,630 | 43 | variable | GTSRB | auto-download |
| hatefulmemes | ~8,500 | ~500 | 2 | variable | ImageFolder | Kaggle + prepare |
| imagenet | 1,281,167 | 50,000 | 1,000 | variable | HuggingFace | gated, HF login |
| imagenette | 9,469 | 3,925 | 10 | variable | Imagenette | ImageNet subset |
| kinetics700 | varies | varies | 700 | variable | ImageFolder | frame extraction |
| kitti | varies | varies | 4 | variable | ImageFolder | manual prep |
| mnist | 60,000 | 10,000 | 10 | 28x28 | MNIST | auto-download |
| pcam | 294,912 | 32,768 | 2 | 96x96 | PCAM | train+val / test |
| pets | 3,680 | 3,669 | 37 | variable | OxfordIIITPet | auto-download |
| resisc45 | 25,200 | 6,300 | 45 | 256x256 | HuggingFace | remote sensing |
| sst | 7,792 | 1,821 | 2 | variable | RenderedSST2 | train+val / test |
| stl10 | 5,000 | 8,000 | 10 | 96x96 | STL10 | auto-download |
| sun397 | ~19,850 | ~19,850 | 397 | variable | SUN397 | manual + prepare |
| ucf101 | varies | varies | 101 | variable | ImageFolder | frame extraction + prepare |

Counts reflect splits as loaded by embedata (some merge train+val for training). "variable" means images have different native resolutions -- all are resized by the transform (default 224x224). Datasets marked "auto-download" are fetched on first use.
MIT