Companion code for the book Raw JAX by Elliot Arledge.
Every file runs standalone. No comments in the code -- the book explains everything.
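Install with uv:

```bash
uv init && uv add jax 'jax[cuda12]' numpy
```

Or with pip (quote the extra so shells like zsh don't expand the brackets):

```bash
pip install 'jax[cuda12]' numpy
```

| Dir | Topic | Files |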
|---|---|---|
| ch01/ | NumPy, But on a GPU | scalars, jnp vs np, device placement, basic ops |
| ch02/ | Shapes and Reshaping | reshape, flatten, expand_dims, squeeze |
| ch03/ | Indexing and Immutability | .at[].set(), masking, jnp.where |
| ch04/ | Broadcasting and Einsum | broadcasting rules, gotchas, einsum, batched einsum |
| ch05/ | Transposing and Permuting | transpose, NCHW/NHWC, stack/concat/split |
| ch06/ | Random Numbers and Purity | PRNG keys, splitting, functional params |
| ch07/ | JIT Compilation and XLA | speedup, tracing, control flow |
| ch08/ | Automatic Differentiation | grad, value_and_grad, gradient descent, jacobian |
| ch09/ | vmap | basics, vmap+grad, nested vmap |
| ch10/ | Neural Network from Scratch | MLP, training loop |
| ch11/ | Pallas and CUDA | pallas add, fused multiply-add |
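Chapters 7 through 9 cover `jit`, `grad`, and `vmap` one at a time. As a rough preview of how they compose, here is a minimal sketch (not taken from the book) that fits a linear model by gradient descent and ends with assertions, in the spirit of the repo's files. The names `loss`, `grad_loss`, and `batched_loss` are hypothetical, the block is commented (unlike the repo's files), and it runs on CPU as well as GPU.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Mean squared error of a linear model: w is (d,), x is (n, d), y is (n,)
    return jnp.mean((x @ w - y) ** 2)


# grad differentiates with respect to the first argument (w) by default,
# and jit compiles the resulting gradient function with XLA.
grad_loss = jax.jit(jax.grad(loss))

# vmap maps loss over a batch of weight vectors (axis 0 of w); x and y are shared.
batched_loss = jax.jit(jax.vmap(loss, in_axes=(0, None, None)))

# Randomness comes from an explicit PRNG key rather than global state.
x = jax.random.normal(jax.random.key(0), (32, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w

# Plain gradient descent with a fixed step size.
w = jnp.zeros(3)
for _ in range(300):
    w = w - 0.1 * grad_loss(w, x, y)

assert jnp.allclose(w, true_w, atol=1e-3)

# Evaluate the fitted weights and the zero vector in one batched call.
losses = batched_loss(jnp.stack([w, jnp.zeros(3)]), x, y)
assert losses.shape == (2,)
assert losses[0] < losses[1]
```

Swapping `jax.grad` for `jax.value_and_grad` would return the loss alongside the gradient in a single call.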
Every file has assertions. Run any file directly:

```bash
uv run python ch01/basic_operations.py
uv run python ch10/training_loop.py
```

Requirements:

- Python 3.10+
- NVIDIA GPU with CUDA 12+
- JAX 0.9+
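To check an install against these requirements before running the chapters, a script along these lines may help. It is a sketch, not part of the repo: `check_environment` is a hypothetical helper, it assumes `jax.__version__` starts with `major.minor`, and it reports the GPU check rather than asserting it, so it still runs on machines where JAX falls back to the CPU backend.

```python
import sys

import jax
import jax.numpy as jnp


def check_environment():
    # Compare the running interpreter and JAX install against the requirements.
    python_ok = sys.version_info >= (3, 10)
    # Compare versions as integer tuples so that 0.10 sorts after 0.9.
    jax_major_minor = tuple(int(part) for part in jax.__version__.split(".")[:2])
    jax_ok = jax_major_minor >= (0, 9)
    # jax.devices() lists devices of the default backend; with a working CUDA
    # install that backend is "gpu", otherwise JAX falls back to "cpu".
    devices = jax.devices()
    gpu_ok = any(d.platform == "gpu" for d in devices)
    return {
        "python": python_ok,
        "jax": jax_ok,
        "gpu": gpu_ok,
        "backend": jax.default_backend(),
        "devices": [str(d) for d in devices],
    }


if __name__ == "__main__":
    report = check_environment()
    for name, value in report.items():
        print(f"{name}: {value}")
    # A tiny computation confirms arrays land on the default device.
    x = jnp.arange(4.0)
    print("sum:", float(x.sum()), "on", x.devices())
```

A `False` for `gpu` usually means the CPU-only `jax` wheel was installed instead of the `cuda12` extra.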
License: MIT