JAX
Simple, extensible implementations of some meta-learning algorithms in JAX
A collection of extensions for meta-learning in JAX
Model-Agnostic Meta-Learning (MAML) implemented in Flax, the neural network library for JAX.
Scenic: A JAX Library for Computer Vision Research and Beyond
The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well a…
PIX is an image processing library in JAX, for JAX.
Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/
Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡
A differentiable cosmology library in JAX
v-objective diffusion inference code for JAX.
Swarm training framework using Haiku + JAX + Ray for layer-parallel Transformer language models on unreliable, heterogeneous nodes
Implementation of Model-Agnostic Meta-Learning (MAML) in JAX
Implementation of The Annotated S4 (https://srush.github.io/annotated-s4)
Bayes-Newton: a Gaussian process library in JAX, with a unifying view of approximate Bayesian inference as variants of Newton's method.
JAXChem is a JAX-based deep learning library for complex and versatile chemical modeling
EfficientNet, MobileNetV3, MobileNetV2, MixNet, etc. in JAX with Flax Linen and Objax
Implementation of the specific Transformer architecture from PaLM (Scaling Language Modeling with Pathways) in JAX, using the Equinox framework
Probabilistic programming and nested sampling in JAX
A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations