Warning
This repo has been merged into mlx-vis and is no longer maintained separately. Please use pip install mlx-vis instead.
t-SNE in pure MLX for Apple Silicon. Entire pipeline runs on Metal GPU.
12x faster than sklearn on 70K points. Uses FFT-accelerated interpolation (FIt-SNE) for O(n) repulsive force computation at scale.
animation.4.mp4
| | tsne-mlx | umap-mlx | pacmap-mlx |
|---|---|---|---|
| Algorithm | t-SNE (van der Maaten 2008) | UMAP (McInnes 2018) | PaCMAP (Wang 2021) |
| Speedup | 12x vs sklearn | 30x vs umap-learn | 13x vs PaCMAP |
| 70K time | 4.9s | 3.4s | 2.3s |
| Best for | Local cluster separation, publication-quality plots | General purpose, fast, preserves local + some global | Global structure, balanced local/global trade-off |
| Repulsive | FFT interpolation (N >= 16K) / exact compiled (N < 16K) | Negative sampling SGD | Near/mid-near/far pair scheduling |
All three: pure MLX + numpy, no scipy/PyTorch, Metal GPU, pip install -e .
git clone https://github.com/hanxiao/tsne-mlx.git && cd tsne-mlx
uv venv .venv && source .venv/bin/activate
uv pip install -e .

from tsne_mlx import TSNE
import numpy as np
X = np.random.randn(70000, 784).astype(np.float32)
Y = TSNE(n_components=2, perplexity=30, n_iter=750).fit_transform(X)

Parameters:
- n_components: output dimensions (default 2)
- perplexity: effective number of neighbors (default 30)
- learning_rate: gradient descent learning rate (default 200)
- n_iter: optimization iterations (default 1000)
- early_exaggeration: P multiplier during early phase (default 12)
- early_exaggeration_iter: early exaggeration duration (default 250)
- random_state: seed for reproducibility
- verbose: print progress every N iterations (0 = silent)
- pca_dim: PCA preprocessing dimension (default 50, None to disable)
- epoch_callback: optional callable(epoch, Y_numpy) for animation snapshots
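
As an illustration of the epoch_callback hook, here is a minimal sketch of a snapshot recorder. Only the callable(epoch, Y_numpy) signature comes from the parameter list above; the helper name make_snapshot_recorder, its every argument, and the assumption that the callback fires once per iteration with a 0-based epoch are hypothetical. The loop drives the callback by hand so the example runs without tsne-mlx installed.

```python
import numpy as np

def make_snapshot_recorder(every=10):
    """Return (callback, frames). Hypothetical helper, not part of tsne-mlx."""
    frames = []

    def on_epoch(epoch, Y):
        # Copy the embedding so later updates cannot overwrite the snapshot.
        if epoch % every == 0:
            frames.append((epoch, np.array(Y, copy=True)))

    return on_epoch, frames

on_epoch, frames = make_snapshot_recorder(every=50)

# With tsne-mlx installed, the recorder would presumably be passed as:
#   Y = TSNE(n_iter=750, epoch_callback=on_epoch).fit_transform(X)
# Here the callback is called manually (assumed: once per epoch, 0-based).
for epoch in range(200):
    on_epoch(epoch, np.zeros((5, 2), dtype=np.float32))
```

Each entry in frames is an (epoch, embedding) pair that can be handed to a plotting library to render animation frames.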
| N | iters | sklearn | tsne-mlx | speedup |
|---|---|---|---|---|
| 10000 | 1000 | 10.3s | 6.2s | 1.7x |
| 70000 | 750 | 60.7s | 4.9s | 12x |
For N < 16K, uses mx.compile-d exact repulsive forces via matmul on GPU.
For N >= 16K, switches to FFT-accelerated interpolation (FIt-SNE): Lagrange
polynomial scatter onto a regular grid, batched 2D FFT convolution with the
Cauchy kernel (circulant embedding), and interpolation back. This reduces
repulsive force computation from O(n^2) to O(n) per iteration, bringing
each epoch down to ~2ms on GPU for 70K points.
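
To make the exact path concrete, here is a NumPy sketch of all-pairs repulsive forces computed with matmuls. It is not the library's mx.compile-d MLX code: the function name repulsive_exact is made up for illustration, and the formula is the standard t-SNE repulsive term (Student-t weights w_ij = 1 / (1 + ||y_i - y_j||^2), normalized by Z = sum of w_ij over i != j), without the constant gradient factor.

```python
import numpy as np

def repulsive_exact(Y):
    """Repulsive term sum_j (w_ij^2 / Z) * (y_i - y_j) for every i, via matmul."""
    sq = (Y * Y).sum(axis=1)
    # Pairwise squared distances from the identity |a-b|^2 = |a|^2 + |b|^2 - 2ab.
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * (Y @ Y.T), 0.0)
    w = 1.0 / (1.0 + d2)          # unnormalized Student-t similarities
    np.fill_diagonal(w, 0.0)      # exclude self-pairs from Z
    Z = w.sum()
    w2 = w * w
    # sum_j w2_ij * (y_i - y_j) = rowsum(w2)_i * y_i - (w2 @ Y)_i
    return (w2.sum(axis=1, keepdims=True) * Y - w2 @ Y) / Z

rng = np.random.default_rng(0)
Y = rng.normal(size=(200, 2))
F = repulsive_exact(Y)
```

The N x N intermediates are what make this path quadratic in memory and time, which is why a different method is needed for large N.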
Fashion-MNIST 70K (784 dims, 10 classes, 750 iterations):
- PCA to 50 dims (when input dim > 50)
- Chunked brute-force KNN on Metal GPU (3 * perplexity neighbors)
- Vectorized binary search for per-point bandwidth (all N points on GPU)
- Symmetrized sparse P matrix via GPU argsort + searchsorted
- Gradient descent with momentum (0.5/0.8) and adaptive gains:
  - Attractive: sparse KNN edges weighted by joint probability
  - Repulsive (N < 16K): compiled exact all-pairs via matmul trick on GPU
  - Repulsive (N >= 16K): FFT interpolation (FIt-SNE) on GPU
- PCA initialization (first 2 components, scaled to 1e-4)
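
The bandwidth search step in the list above can be sketched in NumPy as a bisection that runs on all rows at once. This is an assumed reconstruction of the standard t-SNE perplexity calibration, not the library's Metal code: the name calibrate_bandwidths, the n_steps count, and the random stand-in for KNN distances are all illustrative. The neighbor count 90 follows the 3 * perplexity rule stated above.

```python
import numpy as np

def calibrate_bandwidths(d2, perplexity, n_steps=50):
    """Find per-row precision beta so each row's entropy equals log(perplexity)."""
    n = d2.shape[0]
    target = np.log(perplexity)
    d2 = d2 - d2.min(axis=1, keepdims=True)   # shift for numerical stability
    lo = np.zeros(n)
    hi = np.full(n, np.inf)
    beta = np.ones(n)
    for _ in range(n_steps):
        P = np.exp(-d2 * beta[:, None])
        S = P.sum(axis=1)
        H = np.log(S) + beta * (d2 * P).sum(axis=1) / S   # Shannon entropy (nats)
        too_flat = H > target                 # entropy too high: raise beta
        lo = np.where(too_flat, beta, lo)
        hi = np.where(too_flat, hi, beta)
        # Double beta until an upper bound exists, then bisect.
        beta = np.where(np.isinf(hi), beta * 2.0, 0.5 * (lo + hi))
    P = np.exp(-d2 * beta[:, None])
    return beta, P / P.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
# Stand-in for squared KNN distances: 500 points, 3 * 30 = 90 neighbors each.
knn_d2 = np.sort(rng.uniform(0.0, 4.0, size=(500, 90)) ** 2, axis=1)
beta, P = calibrate_bandwidths(knn_d2, perplexity=30.0)
```

Because every row takes the same number of steps, the loop body contains no per-point branching, which is what makes this form suit a GPU.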
The FFT approach follows Linderman et al. (2019).
Instead of computing all O(n^2) pairwise repulsive interactions, points are
interpolated onto a uniform grid using Lagrange polynomials (3 nodes per cell),
the kernel convolution is performed via FFT in O(M^2 log M) for an M x M grid
(M is typically ~150-300 per side), and results are interpolated back. The entire
scatter-FFT-gather pipeline runs on Metal GPU using mx.fft.rfft2.
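
The scatter-FFT-gather idea can be sketched in one dimension with NumPy. This is a simplified illustration, not the library's implementation, which works on a 2D grid with mx.fft.rfft2. The function names, the interval count of 100, and the uniform test data are assumptions; the 3 Lagrange nodes per cell, the Cauchy kernel, and the circulant embedding follow the description above.

```python
import numpy as np

def lagrange_weights(t, nodes):
    """Lagrange basis values L_k(t) for each point; shape (n, p)."""
    W = np.ones((t.shape[0], nodes.shape[0]))
    for k in range(nodes.shape[0]):
        for l in range(nodes.shape[0]):
            if l != k:
                W[:, k] *= (t - nodes[l]) / (nodes[k] - nodes[l])
    return W

def cauchy_sum_fft(y, v, n_intervals=100, p=3):
    """Approximate phi_i = sum_j v_j / (1 + (y_i - y_j)^2) in 1D."""
    ymin, ymax = y.min(), y.max()
    h = (ymax - ymin) / n_intervals
    box = np.minimum(((y - ymin) / h).astype(int), n_intervals - 1)
    t = (y - ymin) / h - box                  # position inside the cell, in [0, 1]
    nodes = (np.arange(p) + 0.5) / p          # equispaced nodes within a cell
    W = lagrange_weights(t, nodes)
    idx = box[:, None] * p + np.arange(p)[None, :]
    M = n_intervals * p
    # Scatter: spread each point's charge onto its cell's p grid nodes.
    w = np.zeros(M)
    np.add.at(w, idx, W * v[:, None])
    # Kernel between grid nodes is Toeplitz; embed it in a circulant of size 2M.
    K = 1.0 / (1.0 + (np.arange(M) * (h / p)) ** 2)
    c = np.concatenate([K, [0.0], K[:0:-1]])
    u = np.fft.irfft(np.fft.rfft(c) * np.fft.rfft(w, n=2 * M), n=2 * M)[:M]
    # Gather: interpolate node potentials back to the points.
    return (W * u[idx]).sum(axis=1)

rng = np.random.default_rng(0)
y = rng.uniform(-10.0, 10.0, size=500)
v = np.ones_like(y)
approx = cauchy_sum_fft(y, v)
exact = (v[None, :] / (1.0 + (y[:, None] - y[None, :]) ** 2)).sum(axis=1)
rel_err = np.max(np.abs(approx - exact) / exact)
```

The cost of the FFT step depends only on the grid size, not on the number of points, so scatter and gather (both linear in n) dominate as n grows.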
Dependencies: mlx and numpy only. No scipy, no PyTorch.
The demo video above was generated by fashion_mnist_anim.py.
MIT