Skip to content

royerlab/contrastive-td

Repository files navigation

Contrastive TD

A toolbox for contrastive learning on tracksdata graphs.

Installation

pip install git+https://github.com/royerlab/contrastive-td.git

Quick Start

import torch
import tracksdata as td
from contrastive_td.data import TripletDataset
from contrastive_td.fitting import training_loop

# Load your tracking graph
graph = td.graph.InMemoryGraph("path/to/tracking/data")

# Create triplet dataset
dataset = TripletDataset(
    graph=graph,
    node_feature_key="features",
    edge_ground_truth_key="is_true_link"
)

# Define your embedding model
model = torch.nn.Sequential(
    torch.nn.Linear(input_dim, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64)
)

# Train with your loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
training_loop(
    dataset=dataset,
    model=model,
    epochs=10,
    loss_func=your_triplet_loss,
    opt=optimizer
)

About

A toolbox for contrastive learning on tracksdata graphs.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages