ayghri/embedata

embedata

Unified API for loading 29 vision datasets and extracting embeddings from any PyTorch model. Handles torchvision, HuggingFace, and custom ImageFolder datasets through a single interface. Ships with 12 built-in embedding models (CLIP, DINOv2, DINOv3) and supports async extraction on secondary GPUs.

Install

pip install embedata              # core (torchvision datasets)
pip install "embedata[hf]"        # + HuggingFace datasets (ImageNet, CUB, RESISC45); quotes keep zsh from expanding the brackets

From source:

git clone https://github.com/ayghri/embedata.git && cd embedata
pip install -e ".[hf]"

Quick start

from embedata import list_datasets, get_dataset, get_datasets, get_dataloaders

print(list_datasets())  # all 29 datasets

train_ds = get_dataset("cifar10", split="train", root_dir="./data")
train_ds, val_ds = get_datasets("cifar10", root_dir="./data")
train_loader, val_loader = get_dataloaders("cifar10", batch_size=128, root_dir="./data")

Most datasets auto-download on first use. Datasets that require manual setup (Kaggle downloads, frame extraction, etc.) document their steps via get_spec("dataset_name").notes.

Extracting embeddings

Pass any PyTorch model that maps images to feature vectors. Embeddings are saved as .npy files:

import torch
from embedata import get_dataloaders, extract

model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14")
model.eval().to("cuda:0")

train_loader, val_loader = get_dataloaders("cifar10", batch_size=256, root_dir="./data")

extract(train_loader, model, device=torch.device("cuda:0"),
        save_dir="./representations/cifar10/dinov2", suffix="train")
extract(val_loader, model, device=torch.device("cuda:0"),
        save_dir="./representations/cifar10/dinov2", suffix="val")

Load them back as a PyTorch Dataset:

from embedata import load_embeddings

ds = load_embeddings("cifar10", "dinov2", repr_dir="./representations", split="train")
feat, label = ds[0]

For on-the-fly extraction without disk I/O, EmbeddingDataLoader runs a model on a secondary device in a background thread:

from embedata import EmbeddingDataLoader, get_dataset

dataset = get_dataset("cifar10", split="train", root_dir="./data")

# `model` is the embedding model loaded in the extraction example above
loader = EmbeddingDataLoader(dataset, model, device="cuda:1", batch_size=256, prefetch=2)

for embeddings, labels in loader:
    ...
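
To make the background-thread idea concrete, here is a minimal standard-library sketch of prefetching. It is not embedata's implementation: `prefetch_iter` is a hypothetical helper, and reading `prefetch` as "the maximum number of finished batches waiting in a queue" is an assumption about what such a parameter typically means.

```python
import queue
import threading

_DONE = object()  # marks the end of the stream


def prefetch_iter(batches, transform, prefetch=2):
    """Yield transform(batch) for each batch, computed in a background thread.

    Hypothetical helper: at most `prefetch` finished results wait in the queue.
    """
    results = queue.Queue(maxsize=prefetch)

    def worker():
        try:
            for batch in batches:
                # put() blocks while the queue is full, which bounds memory use
                results.put(("ok", transform(batch)))
        except Exception as exc:
            results.put(("error", exc))  # forward the failure to the consumer
        finally:
            results.put(_DONE)

    threading.Thread(target=worker, daemon=True).start()

    while True:
        item = results.get()
        if item is _DONE:
            return
        kind, payload = item
        if kind == "error":
            raise payload
        yield payload


# Stand-in for "run the model on a batch": multiply every value by 10.
fake_batches = [[1, 2], [3, 4], [5, 6]]
embedded = list(prefetch_iter(fake_batches, lambda b: [x * 10 for x in b]))
print(embedded)  # [[10, 20], [30, 40], [50, 60]]
```

A single worker thread keeps the batches in order. The bounded queue lets the producer run ahead of the consumer by at most `prefetch` batches.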

Built-in models

The package includes 12 pre-registered models, accessible via load_model: eight CLIP variants (clipRN50, clipRN101, clipRN50x4, clipRN50x16, clipRN50x64, clipvitB32, clipvitB16, clipvitL14), DINOv2 ViT-g/14 (dinov2), and three DINOv3 variants (dinov3s, dinov3b, dinov3l).

from embedata import list_models, load_model

model, preprocess = load_model("clipvitL14", device="cuda:0", models_dir="./models")

The returned preprocess is the model's image transform; pass it to get_dataloaders or get_dataset. For DINOv2, preprocess is None; use get_default_transforms() instead.

CLIP models require pip install git+https://github.com/openai/CLIP.git. DINOv3 models require pip install transformers.

Custom datasets

Register your own dataset loader with the @register decorator. The function receives a torchvision transform and a data path, and returns (train_dataset, val_dataset):

from embedata import register

@register("my_dataset", notes="Setup instructions shown by get_spec()")
def _my_dataset(transform, data_path):
    train_ds = ...
    val_ds = ...
    return train_ds, val_ds
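
For readers unfamiliar with decorator-based registries, the pattern can be sketched in a few lines of plain Python. This is a stand-in, not embedata's actual code: the `Spec` class, the `_REGISTRY` dict, and the local `register` and `get_spec` functions are hypothetical and only mirror the names used in this README.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class Spec:
    """Hypothetical record kept for each registered dataset."""
    name: str
    loader: Callable
    notes: str = ""


_REGISTRY: Dict[str, Spec] = {}


def register(name, notes=""):
    """Stand-in decorator: store the loader function under `name`."""
    def decorator(fn):
        if name in _REGISTRY:
            raise ValueError(f"dataset {name!r} is already registered")
        _REGISTRY[name] = Spec(name=name, loader=fn, notes=notes)
        return fn  # the function itself is returned unchanged
    return decorator


def get_spec(name):
    """Stand-in lookup: return the Spec registered under `name`."""
    return _REGISTRY[name]


@register("toy", notes="No setup needed.")
def _toy(transform, data_path):
    # A real loader would read files under data_path; this one fakes the data.
    train = [transform(x) for x in range(4)]
    val = [transform(x) for x in range(4, 6)]
    return train, val


train_ds, val_ds = get_spec("toy").loader(lambda x: x * 2, "./data")
print(get_spec("toy").notes, len(train_ds), len(val_ds))  # No setup needed. 4 2
```

Registering at import time means a dataset becomes available by name as soon as the module defining its loader is imported.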

HuggingFace streaming

Three datasets (ImageNet, CUB, RESISC45) load via HuggingFace and support streaming mode (requires pip install embedata[hf]). ImageNet-1k is gated and requires huggingface-cli login.

from embedata import get_datasets
train_ds, val_ds = get_datasets("imagenet", streaming=True)

Dataset preparation

Some datasets require manual download and preparation before use. Python prepare scripts are bundled with the package:

python -m embedata.prepare.birdsnap     --root_dir ROOT_DIR
python -m embedata.prepare.fer2013      --root_dir ROOT_DIR
python -m embedata.prepare.ucf101       --root_dir ROOT_DIR [--download]
python -m embedata.prepare.hatefulmemes --root_dir ROOT_DIR

Shell scripts for cars, eurosat, and sun397 are included under embedata/prepare/. All scripts expect raw data under ROOT_DIR/datasets/{dataset_name}/ and write prepared splits to the same location.
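
Before running a prepare script, it can help to confirm that the raw data sits where the scripts expect it. The sketch below uses only the ROOT_DIR/datasets/{dataset_name}/ layout described above. `dataset_dir` and `missing_raw_data` are hypothetical helpers, not part of embedata.

```python
import tempfile
from pathlib import Path


def dataset_dir(root_dir, dataset_name):
    """Return ROOT_DIR/datasets/{dataset_name} as a Path."""
    return Path(root_dir) / "datasets" / dataset_name


def missing_raw_data(root_dir, dataset_names):
    """List the datasets whose directory does not exist under root_dir."""
    return [n for n in dataset_names if not dataset_dir(root_dir, n).is_dir()]


# Demonstration in a throwaway directory: only fer2013 has been "downloaded".
with tempfile.TemporaryDirectory() as root:
    dataset_dir(root, "fer2013").mkdir(parents=True)
    missing = missing_raw_data(root, ["fer2013", "ucf101"])
    print(missing)  # ['ucf101']
```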

Available datasets

| Dataset | Train | Val/Test | Classes | Size | Source | Notes |
|---|---|---|---|---|---|---|
| aircraft | 6,667 | 3,333 | 100 | variable | FGVCAircraft | trainval / test |
| birdsnap | ~25,000 | ~24,829 | 500 | variable | ImageFolder | manual download + prepare |
| caltech101 | ~3,060 | ~5,587 | 101 | variable | Caltech101 | 30/class for train |
| cars | 8,144 | 8,041 | 196 | variable | StanfordCars | Kaggle download |
| cifar10 | 50,000 | 10,000 | 10 | 32x32 | CIFAR10 | auto-download |
| cifar100 | 50,000 | 10,000 | 100 | 32x32 | CIFAR100 | auto-download |
| clevr | 70,000 | 15,000 | 11 | 320x240 | CLEVRClassification | count 0-10 |
| country211 | 42,200 | 21,100 | 211 | variable | Country211 | train+valid / test |
| cub | 5,994 | 5,794 | 200 | variable | HuggingFace | CUB-200-2011 |
| dtd | 3,760 | 1,880 | 47 | variable | DTD | train+val / test |
| eurosat | 10,000 | 5,000 | 10 | 64x64 | EuroSAT | 1k+500 per class |
| fashionmnist | 60,000 | 10,000 | 10 | 28x28 | FashionMNIST | auto-download |
| fer2013 | 28,709 | 3,589 | 7 | 48x48 | FER2013 | Kaggle + prepare |
| flowers | 2,040 | 6,149 | 102 | variable | Flowers102 | train+val / test |
| food101 | 75,750 | 25,250 | 101 | variable | Food101 | auto-download |
| gtsrb | 26,640 | 12,630 | 43 | variable | GTSRB | auto-download |
| hatefulmemes | ~8,500 | ~500 | 2 | variable | ImageFolder | Kaggle + prepare |
| imagenet | 1,281,167 | 50,000 | 1,000 | variable | HuggingFace | gated, HF login |
| imagenette | 9,469 | 3,925 | 10 | variable | Imagenette | ImageNet subset |
| kinetics700 | varies | varies | 700 | variable | ImageFolder | frame extraction |
| kitti | varies | varies | 4 | variable | ImageFolder | manual prep |
| mnist | 60,000 | 10,000 | 10 | 28x28 | MNIST | auto-download |
| pcam | 294,912 | 32,768 | 2 | 96x96 | PCAM | train+val / test |
| pets | 3,680 | 3,669 | 37 | variable | OxfordIIITPet | auto-download |
| resisc45 | 25,200 | 6,300 | 45 | 256x256 | HuggingFace | remote sensing |
| sst | 7,792 | 1,821 | 2 | variable | RenderedSST2 | train+val / test |
| stl10 | 5,000 | 8,000 | 10 | 96x96 | STL10 | auto-download |
| sun397 | ~19,850 | ~19,850 | 397 | variable | SUN397 | manual + prepare |
| ucf101 | varies | varies | 101 | variable | ImageFolder | frame extraction + prepare |

Counts reflect the splits as embedata loads them (some datasets merge train and val into the training split). "variable" means the images have different native resolutions; the transform resizes all of them (224x224 by default). Datasets marked "auto-download" are fetched on first use.
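
As a quick sanity check on the merged counts, the Train column for three of the "train+val / test" rows can be reproduced from per-class split sizes. The per-class figures below are an assumption taken from the upstream dataset descriptions as commonly documented, not from embedata itself. `merged_train_size` is a hypothetical helper.

```python
# Per-class image counts in the upstream train and val splits (assumed values).
SPLITS = {
    "country211": {"classes": 211, "train": 150, "val": 50},
    "dtd": {"classes": 47, "train": 40, "val": 40},
    "flowers": {"classes": 102, "train": 10, "val": 10},
}


def merged_train_size(split):
    """Size of the training set when train and val are merged."""
    return split["classes"] * (split["train"] + split["val"])


sizes = {name: merged_train_size(s) for name, s in SPLITS.items()}
for name, size in sizes.items():
    print(f"{name}: {size:,}")
# country211: 42,200
# dtd: 3,760
# flowers: 2,040
```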

License

MIT
