arman-hk/gpt-2

gpt-2

JAX is all you need 💜

This is a compact, functional implementation of gpt-2 using only jax and jax.numpy.
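To give a flavour of what "only jax and jax.numpy" can look like, here is a minimal sketch of one GPT-2 building block: pre-attention layer norm followed by causal self-attention. It is not taken from this repository. The function names, the single-head simplification (GPT-2 itself uses multiple heads), and the toy shapes are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, gain, bias, eps=1e-5):
    # Normalize over the feature axis, then apply a learned scale and shift.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return gain * (x - mean) / jnp.sqrt(var + eps) + bias


def causal_attention(x, w_qkv, w_out):
    # x: (T, D). Single head for brevity (an assumption, not GPT-2's full layout).
    T, D = x.shape
    q, k, v = jnp.split(x @ w_qkv, 3, axis=-1)
    scores = (q @ k.T) / jnp.sqrt(D)
    # Lower-triangular mask: position t may only attend to positions <= t.
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return (weights @ v) @ w_out


# Toy usage with hypothetical sizes: 5 tokens, 8 features.
T, D = 5, 8
k_x, k_qkv, k_out = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (T, D))
w_qkv = jax.random.normal(k_qkv, (D, 3 * D)) * 0.02
w_out = jax.random.normal(k_out, (D, D)) * 0.02

h = layer_norm(x, jnp.ones((D,)), jnp.zeros((D,)))
out = causal_attention(h, w_qkv, w_out)

# Changing the last token must not change earlier outputs (causality check).
h_changed = h.at[-1].set(0.0)
out_changed = causal_attention(h_changed, w_qkv, w_out)
```

Every function here takes its weights as arguments and returns a new array, so the block can be passed straight to `jax.jit` or `jax.grad`.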

Structure:

That's basically it.

Ideas and philosophy:

  • JAX functional style: immutable pytrees, pure functions, and explicit state threading. (will expand on this later)
  • Education: I did this project to better understand JAX programming, to implement gpt-2 from scratch, and to keep the code as clear and compact as possible so that others can use it as an educational resource.
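The functional style in the first bullet can be sketched roughly as follows. This is a toy linear model, not the repository's code. `init_params`, `forward`, `loss_fn`, and `sgd_step` are hypothetical names, and plain SGD is an assumption chosen to keep the example short.

```python
import jax
import jax.numpy as jnp


def init_params(key, d_in, d_out):
    # Parameters live in a plain dict (a pytree); nothing is stored on an object.
    return {
        "w": jax.random.normal(key, (d_in, d_out)) * 0.02,
        "b": jnp.zeros((d_out,)),
    }


def forward(params, x):
    # Pure function: the output depends only on the arguments.
    return x @ params["w"] + params["b"]


def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)


@jax.jit
def sgd_step(params, x, y, lr):
    # Returns a new pytree; the input params are never mutated.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


# Explicit state threading: PRNG keys and params are passed in and handed back.
key = jax.random.PRNGKey(0)
key, init_key, data_key = jax.random.split(key, 3)
params = init_params(init_key, 4, 2)
x = jax.random.normal(data_key, (8, 4))
y = jnp.ones((8, 2))

new_params = sgd_step(params, x, y, 0.1)
```

After the step, `params` still holds the original values and `new_params` holds the updated ones, which is what makes the training loop easy to reason about and to transform with `jit`, `grad`, or `vmap`.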

WIP:

  • train with TPU
  • optimize and experiment
  • have fun and see where this goes

About

A gpt-2-small in pure JAX ˆᵕˆ
