JAX is all you need 💜
This is a compact, functional implementation of GPT-2 using only `jax` and `jax.numpy`.
That's basically it.
- JAX functional style: immutable pytrees, pure functions, and explicit state threading. (will expand on this later)
- Education: I did this project to better understand JAX programming, implement GPT-2 from scratch, and keep the code as clear and compact as possible so others can use it as an educational resource.
- train on TPU
- optimize and experiment
- have fun and see where this goes
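
To make the "JAX functional style" point above a bit more concrete, here is a minimal sketch of immutable pytrees, pure functions, and explicit state threading. It is not code from this repo: the toy linear model and the names `init_params`, `loss_fn`, `train_step`, and `LEARNING_RATE` are made up for illustration, and a real GPT-2 parameter tree would be much deeper.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.05  # illustrative value, not tuned


def init_params(key, d_in, d_out):
    # Parameters live in a plain dict (a pytree); there is no model object.
    return {
        "w": jax.random.normal(key, (d_in, d_out)) * 0.02,
        "b": jnp.zeros((d_out,)),
    }


def loss_fn(params, x, y):
    # Pure function: the output depends only on the inputs, no hidden state.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    # Gradients come back as a pytree with the same structure as params.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Nothing is mutated: tree_map builds a brand new pytree.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Randomness is threaded explicitly through PRNG keys.
key = jax.random.PRNGKey(0)
key, init_key, data_key = jax.random.split(key, 3)

params = init_params(init_key, d_in=4, d_out=2)
x = jax.random.normal(data_key, (8, 4))
y = jnp.ones((8, 2))

initial_params = params
initial_loss = loss_fn(params, x, y)

# State threading: each step takes the old params and returns the new ones.
for _ in range(20):
    params, loss = train_step(params, x, y)

final_loss = loss_fn(params, x, y)
```

Since `train_step` returns fresh parameters instead of updating them in place, `initial_params` still holds the untouched starting values after the loop, which makes it easy to compare or checkpoint states.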