
Infatoshi/raw-jax


Raw JAX

Companion code for the book Raw JAX by Elliot Arledge.

Every file runs standalone. No comments in the code -- the book explains everything.

Setup

uv init && uv add jax 'jax[cuda12]' numpy

Or with pip:

pip install 'jax[cuda12]' numpy
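After installing, it is worth confirming that JAX actually picked up the GPU. This is a minimal sketch that uses only `jax.default_backend()` and `jax.devices()`. If JAX cannot initialize a GPU backend it typically falls back to CPU with a warning, so the backend is printed explicitly rather than assumed.

```python
import jax
import jax.numpy as jnp

# Which backend did JAX choose? On a working CUDA install this should be "gpu".
backend = jax.default_backend()
devices = jax.devices()
print("backend:", backend)
print("devices:", devices)

# Run a tiny computation to confirm the backend actually executes work.
x = jnp.arange(4.0)
total = (x * 2).sum()
print("sum:", total)
```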

Chapters

| Dir | Topic | Files |
|-----|-------|-------|
| ch01/ | NumPy, But on a GPU | scalars, jnp vs np, device placement, basic ops |
| ch02/ | Shapes and Reshaping | reshape, flatten, expand_dims, squeeze |
| ch03/ | Indexing and Immutability | .at[].set(), masking, jnp.where |
| ch04/ | Broadcasting and Einsum | broadcasting rules, gotchas, einsum, batched einsum |
| ch05/ | Transposing and Permuting | transpose, NCHW/NHWC, stack/concat/split |
| ch06/ | Random Numbers and Purity | PRNG keys, splitting, functional params |
| ch07/ | JIT Compilation and XLA | speedup, tracing, control flow |
| ch08/ | Automatic Differentiation | grad, value_and_grad, gradient descent, jacobian |
| ch09/ | vmap | basics, vmap+grad, nested vmap |
| ch10/ | Neural Network from Scratch | MLP, training loop |
| ch11/ | Pallas and CUDA | pallas add, fused multiply-add |
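Several of these chapters build on each other. The sketch below shows how the ideas from ch06 through ch09 compose in a few lines: explicit PRNG key splitting, parameters passed as plain arrays into a pure loss, a `jax.jit`-compiled `value_and_grad` step, and `vmap` over a per-example gradient. It is not code from the repository. The toy linear-regression data, `true_w`, `LR`, and the function names are hypothetical and were chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# ch06: explicit PRNG keys, split so each random draw gets its own key.
key = jax.random.key(0)
key, w_key, x_key = jax.random.split(key, 3)

# Hypothetical noiseless toy data: y = X @ true_w.
true_w = jnp.array([2.0, -3.0])
X = jax.random.normal(x_key, (64, 2))
y = X @ true_w

# ch06/ch08: parameters are plain arrays passed to a pure loss function.
w = jax.random.normal(w_key, (2,))

def loss(w, X, y):
    return jnp.mean((X @ w - y) ** 2)

LR = 0.1  # hypothetical learning rate

# ch07 + ch08: one gradient-descent step, compiled once by jit.
@jax.jit
def step(w, X, y):
    value, g = jax.value_and_grad(loss)(w, X, y)
    return w - LR * g, value

for _ in range(200):
    w, value = step(w, X, y)

# ch09: vmap a single-example gradient over the batch axis of X and y.
def example_loss(w, x, t):
    return (x @ w - t) ** 2

per_example_grads = jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))(w, X, y)
print("learned w:", w)
print("per-example grads shape:", per_example_grads.shape)
```

`in_axes=(None, 0, 0)` broadcasts the shared parameters `w` and maps over rows of `X` and entries of `y`. That is why the result has one gradient row per example.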

Running

Every file has assertions. Run any file directly:

uv run python ch01/basic_operations.py
uv run python ch10/training_loop.py
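For illustration, a standalone file in this style computes something and then asserts on the result, so a clean exit means the checks passed. The file below is hypothetical and is not one of the repository's files. It touches the ch03 ideas: immutable updates with `.at[].set()` and masking with `jnp.where`.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# JAX arrays are immutable: .at[].set() returns a new array and leaves x unchanged.
y = x.at[2].set(7.0)
assert float(x[2]) == 0.0
assert float(y[2]) == 7.0

# jnp.where selects elementwise between y and a fill value, keeping the shape.
z = jnp.where(y > 1.0, y, -1.0)
assert z.tolist() == [-1.0, -1.0, 7.0, -1.0, -1.0]

print("all assertions passed")
```

Saved as, say, `check_at_set.py` (a hypothetical name), it would run the same way: `uv run python check_at_set.py`.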

Requirements

  • Python 3.10+
  • NVIDIA GPU with CUDA 12+
  • JAX 0.9+
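A quick way to check the Python and JAX minimums before running the chapters is sketched below. The thresholds come from the list above. The `major_minor` helper is a hypothetical convenience that reads only the first two components of a version string.

```python
import sys
import jax

# Minimums taken from the Requirements list.
MIN_PYTHON = (3, 10)
MIN_JAX = (0, 9)

def major_minor(version: str) -> tuple[int, int]:
    # Hypothetical helper: "0.9.1" -> (0, 9); ignores the patch component and anything after it.
    major, minor = version.split(".")[:2]
    return int(major), int(minor)

checks = {
    "python": sys.version_info[:2] >= MIN_PYTHON,
    "jax": major_minor(jax.__version__) >= MIN_JAX,
}
for name, ok in checks.items():
    print(f"{name}: {'ok' if ok else 'too old'}")
```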

License

MIT
