torchax: Running PyTorch on TPU via JAX

torchax is a backend for PyTorch that allows users to run PyTorch programs on Google Cloud TPUs. It also provides graph-level interoperability between PyTorch and JAX.

With torchax, you can:

  • Run PyTorch code on TPUs with minimal code changes.
  • Call JAX functions from PyTorch, passing in jax.Arrays (see the sketch after this list).
  • Call PyTorch functions from JAX, passing in torch.Tensors.
  • Use JAX features like jax.grad, optax, and GSPMD to train PyTorch models.
  • Use a PyTorch model as a feature extractor with a JAX model.
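As a taste of the interop, once torchax is enabled, the jax.Array backing each tensor is one method call away, so any JAX function can consume it directly. A minimal sketch (the 'jax' device and the .jax() accessor are introduced in the sections below):

import jax.numpy as jnp
import torch
import torchax

torchax.enable_globally()

x = torch.randn(4, 4, device='jax')  # a torch.Tensor backed by a jax.Array
y = jnp.tanh(x.jax())                # pass the backing array to a JAX function
print(type(y))                       # jax.Array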

Install

First, install the CPU version of PyTorch:

# On Linux
pip install torch --index-url https://download.pytorch.org/whl/cpu

# On Mac
pip install torch

Next, install JAX for your desired accelerator:

# On Google Cloud TPU
pip install -U "jax[tpu]"

# On GPU machines
pip install -U "jax[cuda12]"

# On Linux CPU machines or Macs (see the note below)
pip install -U jax

Note: For Apple devices, you can install the Metal version of JAX for hardware acceleration.

Finally, install torchax:

# Install from PyPI
pip install torchax

# Or, install torchax from source.
pip install git+https://github.com/google/torchax
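With everything installed, a quick smoke test (a minimal sketch) confirms that JAX sees your accelerator and that torchax can place tensors on it:

import jax
import torch
import torchax

print(jax.devices())  # should list your TPU, GPU, or CPU devices

torchax.enable_globally()
x = torch.ones(2, 2, device='jax')
print(x)              # a torch.Tensor backed by a jax.Array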

Running a Model

To execute a model with torchax, start with any torch.nn.Module. Here’s an example with a small three-layer MLP:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()

# Execute this model using torch.
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))

To execute this model with torchax, we need to enable torchax to capture PyTorch ops:

import torchax
torchax.enable_globally()

Then, we can move the model and inputs to the 'jax' device:

inputs = torch.randn(3, 3, 28, 28, device='jax')
m = MyModel().to('jax')
res = m(inputs)
print(type(res))  # torchax.tensor.Tensor
print(res.jax())  # the underlying jax.Array

torchax.tensor.Tensor is a torch.Tensor subclass that holds a jax.Array; calling res.jax() returns that underlying array. So although the code reads as standard PyTorch, every operation actually runs on JAX.

How It Works

torchax uses a torch.Tensor subclass, torchax.tensor.Tensor, which holds a jax.Array and overrides the __torch_dispatch__ method. When a PyTorch operation is executed within the torchax environment (enabled by torchax.enable_globally()), the implementation of that operation is swapped with its JAX equivalent.

When a model is instantiated, tensor constructors like torch.rand create torchax.tensor.Tensor objects containing jax.Arrays. Subsequent operations extract the jax.Array, call the corresponding JAX implementation, and wrap the result back into a torchax.tensor.Tensor.
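The core mechanism is PyTorch's __torch_dispatch__ hook. The toy wrapper below is not torchax's actual code, just a minimal sketch of the same pattern: intercept every ATen op, unwrap the payload, run a substitute implementation, and rewrap the result:

import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    # Toy wrapper subclass; torchax.tensor.Tensor wraps a jax.Array instead.
    @staticmethod
    def __new__(cls, elem):
        # Metadata mirrors `elem`; the real data lives in self.elem.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, requires_grad=False)

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Every ATen op on a LoggingTensor lands here. torchax does the
        # analogous unwrap/rewrap with jax.Arrays, calling a registered
        # JAX implementation instead of the original op.
        print(f'intercepted {func}')
        unwrap = lambda t: t.elem if isinstance(t, LoggingTensor) else t
        wrap = lambda t: LoggingTensor(t) if isinstance(t, torch.Tensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(wrap, out)

x = LoggingTensor(torch.randn(2, 2))
y = x + x   # prints: intercepted aten.add.Tensor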

For more details, see the How It Works and Ops Registry documentation.

Executing with jax.jit

While torchax can run models in eager mode, jax.jit gives much better performance. torchax exposes it as torchax.interop.jax_jit, a decorator that takes a function operating on torch.Tensors and returns a compiled, faster version of it.

To use jax_jit, you first need a functional version of your model, one that takes the parameters as an explicit input:

def model_func(param, inputs):
  return torch.func.functional_call(m, param, inputs)

Here we use torch.func.functional_call from PyTorch to replace the model weights with param and then call the model. This is roughly equivalent to:

def model_func(param, inputs):
  m.load_state_dict(param)
  return m(*inputs)

Now, we can apply jax_jit to model_func:

from torchax.interop import jax_jit

model_func_jitted = jax_jit(model_func)
state_dict = m.state_dict()  # the parameters already live on the 'jax' device
print(model_func_jitted(state_dict, inputs))

See more examples at eager_mode.py and the examples folder.

To simplify the idiom of creating a functional model and calling it with its parameters, torchax also provides the JittableModule helper class, which lets us rewrite the above as:

from torchax.interop import JittableModule

m_jitted = JittableModule(m)
res = m_jitted(inputs)

The first time m_jitted is called, jax.jit traces and compiles the function for the given input shapes. Subsequent calls with the same input shapes are fast because the compiled computation is cached.
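Conversely, as with any jax.jit-compiled function, a call with new input shapes triggers a fresh compilation. For example, reusing m_jitted from above with two hypothetical batches:

x_small = torch.randn(3, 3, 28, 28, device='jax')
x_large = torch.randn(8, 3, 28, 28, device='jax')

m_jitted(x_small)  # first call: traces and compiles for batch size 3
m_jitted(x_small)  # same shapes: reuses the cached computation, fast
m_jitted(x_large)  # new batch size: triggers another compilation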

Saving and Loading Checkpoints

You can save and load your training state using torchax.save_checkpoint and torchax.load_checkpoint. The state can be a dictionary containing the model's weights, optimizer state, and any other relevant information.

import torchax
import torch
import optax

torchax.enable_globally()

# Assume the model and optimizer are defined
model = MyModel().to('jax')
optimizer = optax.adam(1e-3)

weights = dict(model.named_parameters())
buffers = dict(model.named_buffers())
opt_state = optimizer.init(weights)
epoch = 10

state = {
    'weights': weights,
    'buffers': buffers,
    'opt_state': opt_state,
    'epoch': epoch,
}

# Save checkpoint
torchax.save_checkpoint(state, '/path/to/checkpoint.pt')

# Load checkpoint
loaded_state = torchax.load_checkpoint('/path/to/checkpoint.pt')

# Restore state
model.load_state_dict({**loaded_state['weights'], **loaded_state['buffers']})
opt_state = loaded_state['opt_state']
epoch = loaded_state['epoch']

Citation

@software{torchax,
  author = {Han Qi and Chun-nien Chan and Will Cromar and Manfei Bai and Kevin Gleason},
  title = {torchax: PyTorch on TPU and JAX interoperability},
  url = {https://github.com/pytorch/xla/tree/master/torchax},
  version = {0.0.4},
  date = {2025-02-24},
}

Maintainers & Contributors

This library is maintained by a team within Google Cloud. It has benefited from many contributions from both inside and outside the team.

Thank you to our recent contributors:

Han Qi (qihqi), PyTorch/XLA
Manfei Bai (manfeibai), PyTorch/XLA
Will Cromar (will-cromar), Meta
Milad Mohammadi (miladm), PyTorch/XLA
Siyuan Liu (lsy323), PyTorch/XLA
Bhavya Bahl (bhavya01), PyTorch/XLA
Pei Zhang (zpcore), PyTorch/XLA
Yifei Teng (tengyifei), PyTorch/XLA
Chunnien Chan (chunnienc), Google, ODML
Alban Desmaison (albanD), Meta, PyTorch
Simon Teo (simonteozw), Google (20%)
David Huang (dvhg), Google (20%)
Barni Seetharaman (barney-s), Google (20%)
Anish Karthik (anishfish2), Google (20%)
Yao Gu (guyao), Google (20%)
Yenkai Wang (yenkwang), Google (20%)
Greg Shikhman (commander), Google (20%)
Matin Akhlaghinia (matinehAkhlaghinia), Google (20%)
Tracy Chen (tracych477), Google (20%)
Matthias Guenther (mrguenther), Google (20%)
WenXin Dong (wenxindongwork), Google (20%)
Kevin Gleason (GleasonK), Google, StableHLO
Nupur Baghel (nupurbaghel), Google (20%)
Gwen Mittertreiner (gmittert), Google (20%)
Zeev Melumian (zmelumian), Lightricks
Vyom Sharma (vyom1611), Google (20%)
Shitong Wang (ShitongWang), Adobe
Rémi Doreau (ayshiff), Google (20%)
Lance Wang (wang2yn84), Google, CoreML
Hossein Sarshar (hosseinsarshar), Google (20%)
Daniel Vega-Myhre (danielvegamyhre), Google (20%)
Tianqi Fan (tqfan28), Google (20%)
Jim Lin (jimlinntu), Google (20%)
Fanhai Lu (FanhaiLu1), Google Cloud
DeWitt Clinton (dewitt), Google PyTorch
Aman Gupta (aman2930), Google (20%)

A special thank you to @albanD for the initial inspiration for torchax.
