CLI tool for running JAX training on Google Cloud Spot TPUs with automatic preemption recovery. Provisions TPUs, uploads code, runs training, and seamlessly retries when Spot instances get preempted.
```bash
pip install spotax
```

Requires Python 3.10+ and the gcloud CLI.
```bash
# Verify prerequisites
spotax setup

# Auto-fix issues (SSH keys, OS Login)
spotax setup --fix
```

```bash
spotax run train.py --tpu v5litepod-1 --zone us-central1-a
```

SpotJAX will:
- Create a GCS bucket for checkpoints
- Provision a Spot TPU
- SSH into all nodes and upload your code via rsync
- Run `spotax_setup.sh` if present (custom pre-install steps)
- Install `requirements.txt` dependencies
- Run your script with checkpoint/distributed env vars injected
- On preemption: clean up, provision a new TPU, and resume from last checkpoint
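The recovery steps above amount to a retry loop. The sketch below is a hypothetical, simplified model of that control flow, not SpotJAX's source. `provision`, `run_job` and `cleanup` are stand-in callables, `run_job` is assumed to report `"completed"` or `"preempted"`, and `max_retries` is assumed to count restarts after the first attempt, mirroring the `--max-retries` option.

```python
from typing import Callable, TypeVar

Handle = TypeVar("Handle")


def run_with_retries(
    provision: Callable[[], Handle],
    run_job: Callable[[Handle, bool], str],
    cleanup: Callable[[Handle], None],
    max_retries: int = 5,
) -> int:
    """Hypothetical model of the recovery loop; returns the number of restarts."""
    restarts = 0
    while True:
        handle = provision()  # stand-in for creating a Spot TPU
        outcome = run_job(handle, restarts > 0)  # second arg: is this a restart?
        if outcome == "completed":
            return restarts
        if outcome != "preempted":
            raise RuntimeError(f"job failed with outcome {outcome!r}")
        cleanup(handle)  # tear down the preempted TPU before re-provisioning
        if restarts >= max_retries:
            raise RuntimeError(f"gave up after {max_retries} restarts")
        restarts += 1


if __name__ == "__main__":
    outcomes = iter(["preempted", "preempted", "completed"])
    cleaned = []
    n = run_with_retries(
        provision=lambda: object(),
        run_job=lambda handle, is_restart: next(outcomes),
        cleanup=cleaned.append,
    )
    print(n, len(cleaned))  # 2 2
```

Because the retried job sees `is_restart=True`, it can resume from a checkpoint instead of starting over; the loop itself knows nothing about training state.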
```
your-project/
  train.py            # Your training script
  data.py             # Data loading (optional)
  spotax_utils.py     # Checkpoint & distributed utilities (copy from examples/)
  requirements.txt    # Dependencies installed on TPU VMs
  spotax_setup.sh     # Pre-install script (optional)
```
SpotJAX handles infrastructure recovery automatically: on preemption, it provisions a new TPU and reruns your script. Without checkpointing, however, training would restart from step 0 on every retry. `spotax_utils.py` bridges this gap by saving model state to GCS and restoring it on retry, so training resumes where it left off.
Copy this file from examples/ into your project. It provides checkpoint management and distributed training setup with no runtime dependency on the spotax package.
```python
from spotax_utils import CheckpointManager, get_config, setup_distributed

config = get_config()
setup_distributed()  # Initialize JAX distributed runtime

ckpt = CheckpointManager(config.checkpoint_dir, save_interval_steps=1000)
# initial_state, train_step, batch and max_steps come from your own training code
state, start_step = ckpt.restore_or_init(initial_state)

for step in range(start_step, max_steps):
    state = train_step(state, batch)
    ckpt.save(step, state)
    if ckpt.reached_preemption(step):
        break  # Orbax already saved a checkpoint; the orchestrator will retry

ckpt.close()
```

How checkpointing works:
- SpotJAX enables GCP's autocheckpoint. On preemption, GCP sends SIGTERM to the VM.
- Orbax detects the SIGTERM and, on the next `ckpt.save()` call, writes a checkpoint even if that step falls outside `save_interval_steps`.
- `reached_preemption()` detects this across all hosts and returns `True`, so your script exits cleanly.
- The orchestrator then provisions a new TPU and reruns your script.
- `restore_or_init()` picks up from the last checkpoint.
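To make the resume semantics concrete, here is a toy, single-process model of the same flow in plain Python. It is not how `spotax_utils` or Orbax are implemented: checkpoints are JSON files in a local directory instead of GCS, preemption is a flag set by a SIGTERM-style handler instead of a multi-host sync, and `restore_or_init` is assumed to resume at the step after the last saved one. `ToyCheckpointManager` and `train` are hypothetical names.

```python
import json
import os
import tempfile


class ToyCheckpointManager:
    """Local-disk stand-in mimicking restore_or_init / save / reached_preemption."""

    def __init__(self, directory: str, save_interval_steps: int):
        self.directory = directory
        self.save_interval_steps = save_interval_steps
        self.preempted = False
        os.makedirs(directory, exist_ok=True)

    def on_sigterm(self, signum=None, frame=None):
        # Could be registered via signal.signal(signal.SIGTERM, ...);
        # here it only marks the run as preempted.
        self.preempted = True

    def _path(self, step: int) -> str:
        return os.path.join(self.directory, f"{step}.json")

    def latest_step(self):
        steps = [int(n[:-5]) for n in os.listdir(self.directory) if n.endswith(".json")]
        return max(steps, default=None)

    def restore_or_init(self, initial_state):
        step = self.latest_step()
        if step is None:
            return initial_state, 0  # fresh run
        with open(self._path(step)) as f:
            return json.load(f), step + 1  # resume after the last saved step

    def save(self, step: int, state) -> None:
        # Save on the regular interval, or unconditionally once preempted.
        if self.preempted or step % self.save_interval_steps == 0:
            with open(self._path(step), "w") as f:
                json.dump(state, f)

    def reached_preemption(self, step: int) -> bool:
        return self.preempted


def train(ckpt_dir: str, max_steps: int, preempt_at=None):
    """Counts steps; returns (final_state, step_where_it_stopped)."""
    ckpt = ToyCheckpointManager(ckpt_dir, save_interval_steps=10)
    state, start_step = ckpt.restore_or_init({"count": 0})
    for step in range(start_step, max_steps):
        state = {"count": state["count"] + 1}  # stand-in for train_step
        if step == preempt_at:
            ckpt.on_sigterm()  # simulate SIGTERM arriving during this step
        ckpt.save(step, state)
        if ckpt.reached_preemption(step):
            return state, step
    return state, max_steps


if __name__ == "__main__":
    ckpt_dir = tempfile.mkdtemp()
    print(train(ckpt_dir, max_steps=50, preempt_at=23))  # ({'count': 24}, 23)
    print(train(ckpt_dir, max_steps=50))  # ({'count': 50}, 50)
```

The resumed run ends with the same count as an uninterrupted run. Without the save triggered at preemption, the retry would restart from the interval checkpoint at step 20 and redo steps 21 to 23.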
`requirements.txt` contains standard pip requirements. SpotJAX installs them on each TPU VM using uv with the JAX TPU releases index. Include `jax[tpu]` and any other dependencies your script needs:

```
jax[tpu]
flax
optax
orbax-checkpoint
grain
```
`spotax_setup.sh` runs before `requirements.txt` is installed. Use it for things pip can't handle: system packages, building from source, or patching libraries. The venv is already activated when it runs.
SpotJAX injects these environment variables into your training script (read them via `get_config()`):

| Variable | Description |
|---|---|
| `SPOT_CHECKPOINT_DIR` | GCS path for checkpoints (`gs://bucket/job-id/ckpt`) |
| `SPOT_LOG_DIR` | GCS path for logs |
| `SPOT_JOB_ID` | Unique job identifier |
| `SPOT_IS_RESTART` | `"true"` if resuming after preemption |
JAX auto-discovers TPU topology (coordinator address, process count, process ID) from TPU metadata — no manual configuration needed for multi-node.
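If you need the `SPOT_*` environment variables outside `get_config()` (for example in a small helper script), a minimal reader is sketched below. `read_spot_env` and `SpotEnv` are hypothetical names, not part of `spotax_utils`; treating unset variables as `None`, and any `SPOT_IS_RESTART` value other than `"true"` as "not a restart", are assumptions.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class SpotEnv:
    checkpoint_dir: Optional[str]  # SPOT_CHECKPOINT_DIR, e.g. gs://bucket/job-id/ckpt
    log_dir: Optional[str]  # SPOT_LOG_DIR
    job_id: Optional[str]  # SPOT_JOB_ID
    is_restart: bool  # True when SPOT_IS_RESTART == "true"


def read_spot_env(environ: Optional[Mapping[str, str]] = None) -> SpotEnv:
    """Hypothetical reader; unset variables become None / False (assumption)."""
    env = os.environ if environ is None else environ
    return SpotEnv(
        checkpoint_dir=env.get("SPOT_CHECKPOINT_DIR"),
        log_dir=env.get("SPOT_LOG_DIR"),
        job_id=env.get("SPOT_JOB_ID"),
        is_restart=env.get("SPOT_IS_RESTART", "").strip().lower() == "true",
    )


if __name__ == "__main__":
    cfg = read_spot_env()
    if cfg.is_restart:
        print(f"Job {cfg.job_id} resuming from {cfg.checkpoint_dir}")
```

Accepting an explicit mapping keeps the function testable and lets the same script run locally, where none of the variables are set.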
```bash
spotax run <script> [options]
```

| Option | Default | Description |
|---|---|---|
| `--tpu`, `-t` | `v5litepod-1` | TPU type |
| `--zone`, `-z` | `us-central1-a` | GCP zone |
| `--project`, `-p` | auto-detect | GCP project ID |
| `--bucket`, `-b` | `spotax-{project}` | GCS bucket for checkpoints |
| `--name`, `-n` | timestamp | Job name |
| `--max-retries` | `5` | Max restart attempts |
| `--stream-worker`, `-w` | `0` | Worker index to stream logs from |
| `--code-dir`, `-c` | script's parent dir | Directory to upload |
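Several defaults in the table are derived from other values. The sketch below models that resolution step; it is hypothetical (`RunOptions` and `resolved` are not SpotJAX's API). The timestamp format for the default job name is an assumption, and project auto-detection is left out, so `project` must be passed explicitly.

```python
from dataclasses import dataclass, replace
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional


@dataclass(frozen=True)
class RunOptions:
    """Hypothetical mirror of the `spotax run` options and their documented defaults."""

    script: str
    project: str  # the CLI auto-detects this; here it must be given
    tpu: str = "v5litepod-1"
    zone: str = "us-central1-a"
    bucket: Optional[str] = None
    name: Optional[str] = None
    max_retries: int = 5
    stream_worker: int = 0
    code_dir: Optional[str] = None

    def resolved(self) -> "RunOptions":
        """Fill in the defaults that depend on other options."""
        return replace(
            self,
            bucket=self.bucket or f"spotax-{self.project}",
            # Assumed timestamp format; the real CLI's format may differ.
            name=self.name or datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S"),
            code_dir=self.code_dir or str(Path(self.script).resolve().parent),
        )


if __name__ == "__main__":
    opts = RunOptions(script="train.py", project="my-project").resolved()
    print(opts.bucket)  # spotax-my-project
```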
- ImageNet EfficientNet - Train EfficientNet-B2 on ImageNet-1K with ArrayRecord data pipeline
- MaxText Qwen3 SFT - Fine-tune Qwen3 on GSM8K math problems using MaxText
- Python 3.10+
- GCP project with TPU API enabled and Spot TPU quota
- gcloud CLI authenticated with Application Default Credentials
Apache 2.0