Stars
TPU inference for vLLM, with unified JAX and PyTorch support.
A simple, performant, and scalable JAX LLM!
Home for "How To Scale Your Model", a short blog-style textbook about scaling LLMs on TPUs
Minimal yet performant LLM examples in pure JAX
Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2023
Optax is a gradient processing and optimization library for JAX.
Orbax provides common checkpointing and persistence utilities for JAX users
Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, fine-tuning, evaluating, and serving LLMs in JAX/Flax.
TensorFlow Recommenders is a library for building recommender system models using TensorFlow.
Flax is a neural network library for JAX that is designed for flexibility.
Simple next-token prediction for RLHF
A repo for distributed training of language models with Reinforcement Learning from Human Feedback (RLHF)
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
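The three transformations named above compose freely; a small sketch with an assumed toy function (not from any of the listed repos):

```python
import jax
import jax.numpy as jnp

# Toy scalar function chosen for illustration.
def f(x):
    return x ** 2 + 2.0 * x

df = jax.grad(f)    # differentiate: df(x) = 2x + 2
vf = jax.vmap(f)    # vectorize over a batch axis
jf = jax.jit(f)     # JIT-compile to XLA for CPU/GPU/TPU

g = float(df(3.0))                     # 8.0
batch = vf(jnp.array([0.0, 1.0, 2.0])) # [0.0, 3.0, 8.0]
y = float(jf(2.0))                     # 8.0
```

Because each transformation returns an ordinary function, stacks like `jax.jit(jax.vmap(jax.grad(f)))` are the idiomatic way to get compiled, batched gradients.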
Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry lead…
Models and examples built with TensorFlow
Reference models and tools for Cloud TPUs.
Testing framework for Deep Learning models (Tensorflow and PyTorch) on Google Cloud hardware accelerators (TPU and GPU)
Notebooks for learning deep learning