TorchGW – Fast Sampled Gromov-Wasserstein Optimal Transport

TorchGW — Fast Gromov-Wasserstein optimal transport solver in PyTorch


Note

Source code: github.com/chansigit/torchgw — clone, star, or open issues on GitHub.

TorchGW is a scalable solver for Gromov-Wasserstein optimal transport, implemented in pure PyTorch with GPU-accelerated Triton fused Sinkhorn kernels.

It aligns two point clouds by matching their internal distance structures, even when the clouds live in spaces of different dimension. This makes it well suited to manifold alignment, single-cell multi-omics integration, and cross-domain graph matching.
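Concretely, for distance matrices C1 (source) and C2 (target) and a coupling T, the square-loss Gromov-Wasserstein objective sums (C1[i,k] - C2[j,l])² · T[i,j] · T[k,l] over all index pairs. Only within-cloud distances appear, which is why the two clouds may have different dimensions. The NumPy sketch below evaluates that textbook objective for a given T. It is not TorchGW's sampled solver, and the function names are illustrative. It uses the standard expansion of the square that avoids the O(n²m²) quadruple sum:

```python
import numpy as np

def pairwise_dist(X):
    """Euclidean distance matrix of an (n, d) point cloud."""
    diff = X[:, None, :] - X[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))

def gw_cost(C1, C2, T):
    """Square-loss GW objective sum_{i,j,k,l} (C1[i,k] - C2[j,l])**2 T[i,j] T[k,l].

    Expanding the square gives two marginal terms and one cross term,
    costing O(n^2 m + n m^2) instead of O(n^2 m^2).
    """
    p, q = T.sum(axis=1), T.sum(axis=0)   # marginals of the coupling
    term1 = p @ (C1 ** 2) @ p             # sum_{i,k} C1[i,k]^2 p_i p_k
    term2 = q @ (C2 ** 2) @ q             # sum_{j,l} C2[j,l]^2 q_j q_l
    cross = np.sum(T * (C1 @ T @ C2.T))   # sum C1[i,k] C2[j,l] T[i,j] T[k,l]
    return term1 + term2 - 2.0 * cross

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))                      # source cloud in 2-D
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))      # random orthogonal 3x3 matrix
Y = np.hstack([X, np.zeros((50, 1))]) @ Q.T       # same points, embedded in 3-D and rotated

C1, C2 = pairwise_dist(X), pairwise_dist(Y)
T_id = np.eye(50) / 50                  # correct one-to-one correspondence, uniform weights
T_rand = np.full((50, 50), 1 / 2500)    # independent (product) coupling

print(gw_cost(C1, C2, T_id))    # ~0 up to rounding: distances are preserved exactly
print(gw_cost(C1, C2, T_rand))  # clearly positive: distance structures are mismatched
```

The 2-D and 3-D clouds have no common coordinate system, yet the correct correspondence has near-zero cost because the objective compares only distances.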

Key features:

  • Up to 175x faster than POT in the project's benchmarks (e.g., spiral dataset, 4000 × 5000 points: about 1 s vs 183 s)

  • Triton fused Sinkhorn — single-pass online logsumexp, no intermediate N×K matrices

  • Mixed precision — float32 Sinkhorn + float64 output, zero quality loss

  • Smart early stopping — cost plateau detection, not just transport plan norm

  • Differentiable — exact gradients via implicit differentiation at the Sinkhorn fixed point

  • No POT dependency at runtime — pure PyTorch + scipy + scikit-learn
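The "single-pass online logsumexp" idea can be illustrated without Triton. The NumPy sketch below runs log-domain Sinkhorn and computes each row's log-sum-exp by scanning the target points in tiles. It keeps a running maximum and a rescaled running sum per row, so the full n × m cost matrix is never stored. This is not TorchGW's fused kernel. The sketch assumes a squared-Euclidean cost computed on the fly from coordinates, whereas a GW solver's inner Sinkhorn uses the current linearized GW cost. The names `online_row_lse`, `sinkhorn_log`, and `block` are hypothetical.

```python
import numpy as np

def online_row_lse(X, Y, g, eps, block=64):
    """Row-wise log-sum-exp of (g[j] - ||x_i - y_j||^2) / eps over all j.

    Y is scanned in blocks; only an n x block tile of the cost exists at a time.
    """
    run_max = np.full(X.shape[0], -np.inf)   # running max per row
    run_sum = np.zeros(X.shape[0])           # running sum of exp(z - run_max)
    for s in range(0, Y.shape[0], block):
        Yb, gb = Y[s:s + block], g[s:s + block]
        C = ((X[:, None, :] - Yb[None, :, :]) ** 2).sum(axis=-1)  # n x block tile
        z = (gb[None, :] - C) / eps
        new_max = np.maximum(run_max, z.max(axis=1))
        # Rescale the old partial sum to the new max, then add this tile's terms.
        run_sum = run_sum * np.exp(run_max - new_max) + np.exp(z - new_max[:, None]).sum(axis=1)
        run_max = new_max
    return run_max + np.log(run_sum)

def sinkhorn_log(X, Y, a, b, eps=1.0, n_iter=300, block=64):
    """Log-domain Sinkhorn returning dual potentials f, g such that
    P[i, j] = exp((f[i] + g[j] - C[i, j]) / eps) has marginals a and b."""
    f, g = np.zeros(len(a)), np.zeros(len(b))
    for _ in range(n_iter):
        f = eps * np.log(a) - eps * online_row_lse(X, Y, g, eps, block)  # match row sums
        g = eps * np.log(b) - eps * online_row_lse(Y, X, f, eps, block)  # match column sums
    return f, g

rng = np.random.default_rng(0)
X, Y = 0.3 * rng.normal(size=(30, 2)), 0.3 * rng.normal(size=(40, 2))
a, b = np.full(30, 1 / 30), np.full(40, 1 / 40)
f, g = sinkhorn_log(X, Y, a, b)
```

Working with the potentials f and g in the log domain keeps small entropic regularizations numerically stable. The running-max rescaling is what lets each row be reduced in one pass over the tiles.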
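The exact early-stopping criterion TorchGW uses is not spelled out here. The sketch below shows one plausible plateau rule, assumed purely for illustration. It stops once the best cost in the last few iterations barely improves on the best cost before them. Unlike a test on the change in the transport plan's norm, this rule can fire even when a sampled plan keeps fluctuating between iterations. `cost_plateaued`, `window`, and `rtol` are hypothetical names.

```python
def cost_plateaued(costs, window=5, rtol=1e-4):
    """True once the best cost in the last `window` iterations improves on the
    best cost before them by at most rtol (relative to the earlier best)."""
    if len(costs) <= window:
        return False
    best_before = min(costs[:-window])
    best_recent = min(costs[-window:])
    return best_before - best_recent <= rtol * abs(best_before)

# Simulated outer loop whose cost decays geometrically toward 1.0
costs = []
for k in range(200):
    costs.append(1.0 + 0.5 ** k)   # stand-in for the GW cost at iteration k
    if cost_plateaued(costs):
        break
```

Comparing best-so-far values rather than consecutive costs makes the rule tolerant of noisy, non-monotone cost traces.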

What’s New in v0.4.1

  • Exact gradients via implicit differentiation at the Sinkhorn fixed point, fixing a correctness bug in which the previous backward pass produced gradients with errors of up to 30x

  • New grad_mode parameter: "implicit" (default, exact) or "unrolled"

  • Full theoretical derivation on the Algorithm page

  • See Changelog for details
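The difference between the two grad_mode options can be seen on a scalar fixed-point problem x* = f(x*, θ). Implicit differentiation uses dx*/dθ = (1 - ∂f/∂x)⁻¹ ∂f/∂θ evaluated at the fixed point. Unrolled differentiation instead backpropagates through a finite number of iterations. The plain-Python sketch below is a generic illustration of that contrast, not TorchGW's backward pass, which applies the same idea to the Sinkhorn fixed point. All names are hypothetical.

```python
import math

def f(x, theta):
    # |df/dx| <= 0.5, so x = f(x, theta) is a contraction with a unique fixed point.
    return 0.5 * math.cos(x) + theta

def solve(theta, n_iter=100):
    x = 0.0
    for _ in range(n_iter):
        x = f(x, theta)
    return x

def grad_implicit(theta):
    """dx*/dtheta from the implicit function theorem at the fixed point."""
    x = solve(theta)
    df_dx = -0.5 * math.sin(x)
    df_dtheta = 1.0
    return df_dtheta / (1.0 - df_dx)

def grad_unrolled(theta, n_iter):
    """Differentiate through n_iter iterations, carrying dx/dtheta (forward mode)."""
    x, dx = 0.0, 0.0
    for _ in range(n_iter):
        # Chain rule on x_new = f(x, theta); the right side uses the old x.
        x, dx = f(x, theta), -0.5 * math.sin(x) * dx + 1.0
    return dx

theta, h = 0.3, 1e-5
finite_diff = (solve(theta + h) - solve(theta - h)) / (2 * h)
print(grad_implicit(theta), finite_diff)       # agree closely
print(grad_unrolled(theta, 3))                 # truncated unrolling: noticeably off
print(grad_unrolled(theta, 100))               # converges to the implicit value
```

The implicit gradient needs only the converged solution, not the iteration history, so its cost and accuracy do not depend on how many forward iterations were run.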

Installation

pip install torchgw

Or for development:

git clone https://github.com/chansigit/torchgw.git
cd torchgw && pip install -e ".[dev]"

Dependencies: numpy, scipy, scikit-learn, torch>=2.0, joblib. Triton (ships with PyTorch 2.0+) enables GPU kernel fusion automatically.
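How TorchGW decides whether to use its Triton kernels is not described here. A quick way to see what your environment provides is to check whether `torch` and `triton` are importable and whether CUDA is visible. The sketch uses only `importlib.util.find_spec` and, if torch is present, `torch.cuda.is_available()`. The helper name is hypothetical.

```python
import importlib.util

def acceleration_report():
    """Report which optional GPU pieces are importable in this environment."""
    def installed(name):
        # find_spec returns None for a missing top-level module without importing it
        return importlib.util.find_spec(name) is not None

    report = {"torch": installed("torch"), "triton": installed("triton")}
    if report["torch"]:
        import torch
        report["cuda"] = torch.cuda.is_available()
    return report

print(acceleration_report())
```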

Quick Example

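from torchgw import sampled_gw

# X_source: (n, d1) source points; X_target: (m, d2) target points (d1 may differ from d2)
T = sampled_gw(X_source, X_target, distance_mode="landmark", mixed_precision=True)
# T has shape (n, m); T[i, j] is the coupling weight between source point i and target point j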
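Two common ways to use a coupling are a hard assignment (each source point's most strongly coupled target) and a barycentric projection (each source point mapped to the T-weighted average of its target points). The sketch below shows both on a hand-made toy coupling. It assumes T can be converted to a NumPy array; the return type of `sampled_gw` is not stated here, but if it is a torch tensor, `T.detach().cpu().numpy()` does the conversion. The helper names are hypothetical.

```python
import numpy as np

def hard_matches(T):
    """Most strongly coupled target for each source point, plus the share of that
    row's mass it receives (1.0 would mean a deterministic match)."""
    T = np.asarray(T, dtype=np.float64)
    return T.argmax(axis=1), T.max(axis=1) / T.sum(axis=1)

def barycentric_projection(T, X_target):
    """Map each source point into the target space as the T-weighted
    average of the target points it is coupled to."""
    T = np.asarray(T, dtype=np.float64)
    return (T @ X_target) / T.sum(axis=1, keepdims=True)

# Toy 3 x 3 coupling with uniform marginals (every row and column sums to 1/3)
T = np.array([[0.1, 0.8, 0.1],
              [0.1, 0.1, 0.8],
              [0.8, 0.1, 0.1]]) / 3
X_target = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])

match, share = hard_matches(T)                 # match -> [1, 2, 0]; share about 0.8 per row
X_proj = barycentric_projection(T, X_target)   # row 0 -> [0.8, 0.1]
```

Dividing by the row sums keeps the projection correct whatever source weights the coupling carries.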

Citation

If you use TorchGW in your research, please cite:

@software{torchgw,
  author = {Sijie Chen},
  title = {TorchGW: Fast Sampled Gromov-Wasserstein Optimal Transport},
  url = {https://github.com/chansigit/torchgw},
  version = {0.4.1},
  year = {2026},
}

License

Free for academic and non-commercial use. Commercial use requires a separate license. See LICENSE and COMMERCIAL_LICENSE.md for details.

Copyright (c) 2026 The Board of Trustees of the Leland Stanford Junior University. For commercial licensing inquiries, contact Stanford OTL: otl@stanford.edu