TorchGW – Fast Sampled Gromov-Wasserstein Optimal Transport
Note: Source code is at github.com/chansigit/torchgw; clone, star, or open issues on GitHub.
TorchGW is a scalable solver for Gromov-Wasserstein optimal transport, implemented in pure PyTorch with GPU-accelerated Triton fused Sinkhorn kernels.
It aligns two point clouds by matching their internal distance structures, even when the two clouds live in spaces of different dimension. This makes it well suited to manifold alignment, single-cell multi-omics integration, and cross-domain graph matching.
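"Matching internal distance structures" has a precise meaning. Given intra-domain distance matrices C1 (source) and C2 (target) and a coupling T, the square-loss GW objective sums (C1[i,k] - C2[j,l])^2 * T[i,j] * T[k,l] over all index quadruples. The NumPy sketch below evaluates it in two ways: by brute force, and by the usual decomposition that avoids the 4-index tensor. It is a reference for the objective only, not TorchGW's solver, and the helper names are hypothetical. The usage check shows that a 2-D cloud and a rotated copy embedded in 3-D have zero GW loss under the identity coupling, even though the two clouds have different dimensions.

```python
import numpy as np

def pairwise_dist(X):
    """Euclidean distance matrix within one point cloud, shape (n, n)."""
    diff = X[:, None, :] - X[None, :, :]
    return np.sqrt((diff ** 2).sum(-1))

def gw_loss(C1, C2, T):
    """Square-loss GW objective via the decomposition
    sum (C1_ik - C2_jl)^2 T_ij T_kl = p'C1^2 p + q'C2^2 q - 2 <C1 T C2^T, T>,
    where p and q are the row and column marginals of T."""
    p, q = T.sum(axis=1), T.sum(axis=0)
    const = p @ (C1 ** 2) @ p + q @ (C2 ** 2) @ q
    cross = np.sum((C1 @ T @ C2.T) * T)
    return const - 2.0 * cross

def gw_loss_bruteforce(C1, C2, T):
    """Same objective from the explicit 4-index sum (small inputs only)."""
    diff = C1[:, None, :, None] - C2[None, :, None, :]  # diff[i, j, k, l] = C1[i, k] - C2[j, l]
    return np.einsum("ijkl,ij,kl->", diff ** 2, T, T)
```

Usage: build `X` in 2-D and `Y = np.hstack([X @ R, np.zeros((n, 1))])` for a 2×2 rotation `R`. Then `gw_loss(pairwise_dist(X), pairwise_dist(Y), np.eye(n) / n)` is zero up to rounding, while the independent coupling `np.full((n, n), 1 / n**2)` gives a positive loss.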
Key features:
- Up to 175x faster than POT on typical workloads (spiral, 4000×5000: ~1 s vs 183 s)
- Triton fused Sinkhorn: single-pass online logsumexp, with no intermediate N×K matrices
- Mixed precision: float32 Sinkhorn with float64 output, with no loss in solution quality
- Smart early stopping: detects a plateau in the cost, not just the transport-plan norm
- Differentiable: exact gradients via implicit differentiation at the Sinkhorn fixed point
- No POT dependency at runtime: pure PyTorch + SciPy + scikit-learn
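The "single-pass online logsumexp" idea can be sketched in NumPy. It is not TorchGW's Triton kernel. `streaming_logsumexp` keeps only a running maximum and a rescaled running sum per row while it visits the columns one block at a time. `sinkhorn_f_update` applies it to the standard log-domain Sinkhorn potential update f_i = -eps * log sum_j b_j exp((g_j - C_ij) / eps). The cost columns come from a hypothetical `cost_block(j0, j1)` callback, so only one (n × chunk) block needs to exist at a time. The block size, the callback, and all names here are illustrative assumptions.

```python
import numpy as np

def streaming_logsumexp(blocks):
    """Row-wise logsumexp over an iterable of (n, b) blocks, in one pass.

    Per row, keep a running max m and s = sum(exp(x - m)); when a block raises
    the max, rescale s by exp(old_m - new_m) before adding the block's terms.
    """
    m = s = None
    for block in blocks:
        bmax = block.max(axis=1)
        if m is None:
            m = bmax
            s = np.exp(block - m[:, None]).sum(axis=1)
            continue
        new_m = np.maximum(m, bmax)
        s = s * np.exp(m - new_m) + np.exp(block - new_m[:, None]).sum(axis=1)
        m = new_m
    return m + np.log(s)

def sinkhorn_f_update(cost_block, g, log_b, eps, chunk=128):
    """Log-domain Sinkhorn update of the source potential f.

    cost_block(j0, j1) is assumed to return the (n, j1 - j0) slice of the cost
    matrix, computed on demand, so the full n x m matrix is never stored.
    """
    m = g.shape[0]
    blocks = (
        (g[j0:j0 + chunk] - cost_block(j0, min(j0 + chunk, m))) / eps + log_b[j0:j0 + chunk]
        for j0 in range(0, m, chunk)
    )
    return -eps * streaming_logsumexp(blocks)
```

The rescaling step is what makes a single pass possible. Without it, the maximum of the whole row would have to be known before any exponentials could be summed safely.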
What’s New in v0.4.1
- Exact differentiable gradients via implicit differentiation at the Sinkhorn fixed point. This fixes a correctness bug in which the old backward pass produced gradients with errors of up to 30x.
- New grad_mode parameter: "implicit" (default, exact) or "unrolled".
- Full theory derivation in Algorithm.
See Changelog for details
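The difference between implicit and unrolled gradients can be illustrated on a toy one-dimensional fixed point. This sketch is not TorchGW's Sinkhorn backward. The map f(x, θ) = 0.5·cos(x) + θ is a hypothetical stand-in for one solver sweep, chosen because it is a contraction. The implicit gradient applies the implicit function theorem at x* = f(x*, θ), giving dx*/dθ = (1 - ∂f/∂x)^(-1) · ∂f/∂θ. The unrolled gradient is carried through a fixed number of iterations by the chain rule, so it is only exact once those iterations have converged.

```python
import math

def f(x, theta):
    """Toy contraction map standing in for one solver sweep."""
    return 0.5 * math.cos(x) + theta

def solve(theta, tol=1e-14, max_iter=1000):
    """Fixed-point iteration x <- f(x, theta) from x0 = 0."""
    x = 0.0
    for _ in range(max_iter):
        x_new = f(x, theta)
        if abs(x_new - x) < tol:
            return x_new
        x = x_new
    return x

def grad_implicit(theta):
    # df/dx = -0.5*sin(x) and df/dtheta = 1, evaluated at the fixed point only
    x_star = solve(theta)
    return 1.0 / (1.0 + 0.5 * math.sin(x_star))

def grad_unrolled(theta, n_iter):
    # Forward-mode chain rule through n_iter sweeps: dx_{k+1} = f_x(x_k) * dx_k + f_theta
    x, dx = 0.0, 0.0
    for _ in range(n_iter):
        x, dx = f(x, theta), -0.5 * math.sin(x) * dx + 1.0
    return dx
```

At θ = 1, two unrolled iterations give roughly 0.50, while the implicit gradient (which agrees with a finite-difference check) is roughly 0.68. The implicit version also needs only the converged x*, not the stored iteration history.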
Contents
- Quick Start
- API Reference
- Algorithm
- Benchmark
- Changelog
- TorchGW Optimization Log
  - Summary
  - Phase 1: GPU Sampling + Kernel Fusion Prep
  - Phase 2: Dijkstra Caching
  - Phase 3: Mixed Precision
  - Phase 4: Early Stopping
  - Phase 5: Sync Reduction
  - Phase 6: Triton Fused Sinkhorn
  - Phase 7: Sinkhorn Warm-Start
  - Phase 8: Parallel Preprocessing
  - Bug Fixes During Optimization
  - Remaining Bottlenecks
  - Future Directions
Installation
pip install torchgw
Or for development:
git clone https://github.com/chansigit/torchgw.git
cd torchgw && pip install -e ".[dev]"
Dependencies: numpy, scipy, scikit-learn, torch>=2.0, joblib.
Triton, which is installed alongside PyTorch 2.0+ on Linux GPU builds, enables GPU kernel fusion automatically.
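To check which of these optional pieces your environment actually has, a small standard-library helper like the one below can help. It is a hypothetical convenience, not part of TorchGW. It only inspects the environment and makes no assumption about how TorchGW behaves when Triton is missing.

```python
import importlib.util

def backend_report() -> dict:
    """Report whether torch and triton are importable and whether CUDA is visible."""
    # find_spec checks installation without importing the package
    report = {name: importlib.util.find_spec(name) is not None for name in ("torch", "triton")}
    report["cuda"] = False
    if report["torch"]:
        import torch  # imported only when it is installed
        report["cuda"] = bool(torch.cuda.is_available())
    return report
```

Calling `print(backend_report())` prints a dict such as `{'torch': ..., 'triton': ..., 'cuda': ...}` with boolean values.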
Quick Example
from torchgw import sampled_gw
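# X_source: (n, d1) source points; X_target: (m, d2) target points; d1 and d2 may differ
T = sampled_gw(X_source, X_target, distance_mode="landmark", mixed_precision=True)
# T[i, j] = coupling weight between source point i and target point j (an n × m matrix)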
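Once you have T, two generic ways to use a coupling are a soft barycentric projection of each source point into the target space and a hard one-to-one assignment. The NumPy sketch below shows both. These are standard post-processing steps for any transport plan, not TorchGW functions; the helper names are hypothetical. If `sampled_gw` returns a `torch.Tensor` (an assumption), convert it first with `T.detach().cpu().numpy()`.

```python
import numpy as np

def barycentric_projection(T, Y):
    """Map each source point to the T-weighted average of target points, shape (n, d2)."""
    row_mass = T.sum(axis=1, keepdims=True)
    # guard against source rows that received no mass
    return (T @ Y) / np.maximum(row_mass, np.finfo(T.dtype).tiny)

def hard_matching(T):
    """For each source point, the index of the target point receiving the most mass."""
    return T.argmax(axis=1)
```

The projection places source points directly in the target coordinate system, which is convenient for integration tasks. The argmax matching discards the soft weights and can assign several source points to the same target.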
Source Code & Links
GitHub repository: chansigit/torchgw
Issue tracker: GitHub Issues
Changelog: CHANGELOG.md
PyPI: coming soon
# Clone and install from source
git clone https://github.com/chansigit/torchgw.git
cd torchgw
pip install -e .
Citation
If you use TorchGW in your research, please cite:
@software{torchgw,
author = {Sijie Chen},
title = {TorchGW: Fast Sampled Gromov-Wasserstein Optimal Transport},
url = {https://github.com/chansigit/torchgw},
version = {0.4.1},
year = {2026},
}
License
Free for academic and non-commercial use. Commercial use requires a separate license. See LICENSE and COMMERCIAL_LICENSE.md for details.
Copyright (c) 2026 The Board of Trustees of the Leland Stanford Junior University. For commercial licensing inquiries, contact Stanford OTL: otl@stanford.edu