This directory contains the official Flax examples.
Each example is designed to be self-contained and easily forkable, while reproducing relevant results in different areas of machine learning.
As discussed in #231, we decided to use a standard structure for all examples, including the simplest ones (such as MNIST). This makes every example a bit more verbose, but once you know one example, you know the structure of all of them. The unit and integration tests are also very useful when you fork an example.
Some of the examples below have a link "🕹Interactive🕹" that lets you run them directly in Colab.
Image classification
- MNIST - 🕹Interactive🕹: Convolutional neural network for MNIST classification (featuring simple code).
- ImageNet - 🕹Interactive🕹: ResNet-50 on ImageNet with weight decay (featuring multi-host SPMD, custom preprocessing, checkpointing, dynamic scaling, mixed precision).
Reinforcement learning
- Proximal Policy Optimization: Learning to play Atari games (featuring single-host SPMD, RL setup).
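The core of PPO is its clipped surrogate objective. As a hedged illustration (the function name and the default `clip_eps=0.2` are assumptions, and NumPy stands in for `jax.numpy`, which offers the same functions), the per-batch policy loss can be written as:

```python
import numpy as np


def ppo_clipped_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Clipped surrogate policy loss from PPO (to be minimized).

    All arrays have shape (batch,); clip_eps is the clipping range epsilon.
    """
    # Probability ratio r = pi_new(a|s) / pi_old(a|s), computed in log space.
    ratio = np.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # PPO maximizes the elementwise minimum; negate it to obtain a loss.
    return -np.mean(np.minimum(unclipped, clipped))


loss = ppo_clipped_loss(
    new_log_probs=np.array([0.1, -0.3]),
    old_log_probs=np.array([0.0, 0.0]),
    advantages=np.array([1.0, -0.5]),
)
print(loss)
```

Taking the minimum means the policy gains nothing from pushing the ratio outside `[1 - eps, 1 + eps]`, which keeps each update close to the policy that collected the data.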
Natural language processing
- Sequence to sequence for number addition: LSTM-based model (featuring simple code, LSTM state handling, on-the-fly data generation).
- Part-of-speech tagging: Simple transformer encoder model using the Universal Dependencies dataset.
- Sentiment classification: with an LSTM model.
- Transformer encoder/decoder model trained on WMT: Translating English/German (featuring multi-host SPMD, dynamic bucketing, attention cache, packed sequences, recipe for TPU training on GCP).
- Transformer encoder trained on the One Billion Word Benchmark: for autoregressive language modeling, based on the WMT example above.
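On-the-fly data generation, as in the number-addition sequence-to-sequence example, means no dataset has to be downloaded: every batch can be sampled fresh. A minimal sketch, assuming a hypothetical `"a+b"` query and `"=sum"` answer format (the example's actual encoding may differ), before the strings are turned into character IDs for the LSTM:

```python
import random


def make_addition_batch(batch_size, max_digits=3, rng=None):
    """Generate (query, answer) string pairs such as ("12+345", "=357").

    The string format is an illustrative assumption.
    """
    if rng is None:
        rng = random.Random()
    max_value = 10 ** max_digits - 1
    queries, answers = [], []
    for _ in range(batch_size):
        a = rng.randint(0, max_value)
        b = rng.randint(0, max_value)
        queries.append(f"{a}+{b}")
        answers.append(f"={a + b}")
    return queries, answers


queries, answers = make_addition_batch(4, rng=random.Random(0))
for q, a in zip(queries, answers):
    print(q, a)
```

Passing a seeded `random.Random` keeps evaluation batches reproducible while training batches stay effectively unlimited.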
Generative models
- Variational auto-encoder: Trained on binarized MNIST (featuring simple code, vmap).
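The `vmap` feature refers to `jax.vmap`, which turns a function written for a single example into a batched one. As a sketch (illustrative, not copied from the example), the Gaussian KL term of a VAE loss can be written per example and then vectorized:

```python
import jax
import jax.numpy as jnp


@jax.vmap
def kl_divergence(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)) for a single example.

    jax.vmap maps this per-example function over a leading batch axis.
    """
    return -0.5 * jnp.sum(1.0 + logvar - jnp.square(mean) - jnp.exp(logvar))


mean = jnp.zeros((8, 20))    # batch of 8, latent size 20
logvar = jnp.zeros((8, 20))
kl = kl_divergence(mean, logvar)
print(kl.shape)  # (8,)
```

Writing the math for one example and letting `vmap` add the batch dimension avoids manual axis bookkeeping in the loss.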
Graph modeling
- Graph Neural Networks: Molecular predictions on ogbg-molpcba from the Open Graph Benchmark.
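Graph neural networks are built from message-passing steps. As a hedged sketch using only `jax.ops.segment_sum` (the Flax example may instead build on a dedicated graph library), one sum-aggregation step over an edge list looks like this; the residual update rule is an illustrative choice, not the example's model:

```python
import jax
import jax.numpy as jnp


def message_passing_step(node_features, senders, receivers):
    """One round of sum-aggregation message passing.

    node_features: (num_nodes, feature_dim)
    senders, receivers: (num_edges,) integer node indices; edge i goes
    from senders[i] to receivers[i].
    """
    num_nodes = node_features.shape[0]
    messages = node_features[senders]  # each edge carries its sender's features
    # Sum all incoming messages for each receiving node.
    aggregated = jax.ops.segment_sum(messages, receivers, num_segments=num_nodes)
    return node_features + aggregated  # simple residual update


nodes = jnp.array([[1.0], [2.0], [3.0]])
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 1])
out = message_passing_step(nodes, senders, receivers)
print(out)  # [[1.] [6.] [5.]]
```

Passing `num_segments` explicitly keeps the output shape static, which `jax.jit` requires.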
The following code bases use Flax and provide training frameworks and a wealth of examples, in many cases with pre-trained weights:
- Hugging Face Transformers is a very popular library for building, training, and deploying state-of-the-art machine learning models. These models can be applied to text, images, and audio. After organizing the JAX/Flax community week, they now have over 5,000 Flax/JAX models in their repository.
- Scenic is a codebase/library for computer vision research and beyond. Scenic focuses mainly on attention-based models and has been used successfully to develop classification, segmentation, and detection models for multiple modalities, including images, video, audio, and multimodal combinations of them.
- Big Vision is a codebase designed for training large-scale vision models on Cloud TPU VMs or GPU machines. It is based on the JAX/Flax libraries and uses tf.data and TensorFlow Datasets for scalable and reproducible input pipelines. It is the original codebase of ViT, MLP-Mixer, LiT, UViM, and many more models.
- T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of sequence models (starting with language) at many scales.
In addition to the curated list of official Flax examples, there is a growing community of people using Flax to build new types of machine learning models. We are happy to showcase any example built by the community here! If you want to submit your own example, we suggest that you start by forking one of the official Flax examples and build from there.
| Link | Author | Task type | Reference |
|---|---|---|---|
| matthias-wright/flaxmodels | @matthias-wright | Various | GPT-2, ResNet, StyleGAN-2, VGG, ... |
| DarshanDeshpande/jax-models | @DarshanDeshpande | Various | Segformer, Swin Transformer, ... also some stand-alone layers |
| google/vision_transformer | @andsteing | Image classification, image/text | https://arxiv.org/abs/2010.11929, https://arxiv.org/abs/2105.01601, https://arxiv.org/abs/2111.07991, ... |
| JAX-RL | @henry-prior | Reinforcement learning | N/A |
| DCGAN Colab | @bkkaggle | Image Synthesis | https://arxiv.org/abs/1511.06434 |
| BigBird Fine-tuning | @vasudevgupta7 | Question-Answering | https://arxiv.org/abs/2007.14062 |
| jax-resnet | @n2cholas | Various ResNet implementations | torch.hub |
Most of our examples in this directory follow a structure that we found to work well with Flax projects, and we strive to make the examples easy to explore and easy to fork. In particular (taken from #231):
- README: contains links to the paper, the command line, and TensorBoard metrics
- Focus: an example is about a single model/dataset
- Configs: we use `ml_collections.ConfigDict` stored under `configs/`
- Tests: the executable `main.py` loads `train.py`, which has unit tests in `train_test.py`
- Data: is read from TensorFlow Datasets
- Standalone: every directory is self-contained
- Requirements: are pinned in `requirements.txt`
- Boilerplate: is reduced by using `clu`
- Interactive: the example can be explored with a Colab