Machine Learning - JAX
Libraries
- Neural Network Libraries
  - Flax - Centered on flexibility and clarity.
  - Flax NNX - An evolution of Flax by the same team.
  - Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
  - Objax - Has an object-oriented design similar to PyTorch.
  - Elegy - A high-level API for deep learning in JAX. Supports Flax, Haiku, and Optax.
  - Trax - "Batteries included" deep learning library focused on providing solutions for common workloads.
  - Jraph - Lightweight graph neural network library.
  - Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
  - HuggingFace Transformers - Ecosystem of pretr…
- Levanter
Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX.
- EasyLM
LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax.
- NumPyro
Probabilistic programming based on the Pyro library.
- Chex
Utilities to write and test reliable JAX code.
- Optax
Gradient processing and optimization library.
- RLax
Library for implementing reinforcement learning agents.
- JAX, M.D.
Accelerated, differentiable molecular dynamics.
- Coax
Turn RL papers into code, the easy way.
- Distrax
Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.
- cvxpylayers
Construct differentiable convex optimization layers.
- TensorLy
Tensor learning made simple.
- NetKet
Machine Learning toolbox for Quantum Physics.
- Fortuna
AWS library for Uncertainty Quantification in Deep Learning.
- BlackJAX
Library of samplers for JAX.
- Dynamax
Probabilistic state space models.
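Optax is described above as a gradient processing and optimization library. As a rough illustration of what "gradient processing" means, here is a minimal sketch in core JAX only (no Optax) that clips gradients by their global norm before a plain SGD update. The toy linear model, parameter names, and hyperparameters are hypothetical, and none of this is Optax's API.

```python
import jax
import jax.numpy as jnp


def global_norm(tree):
    # L2 norm over every leaf of a PyTree, treated as one flat vector.
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))


def clip_by_global_norm(grads, max_norm):
    # Rescale all gradients by one shared factor so their joint norm is <= max_norm.
    norm = global_norm(grads)
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


def loss_fn(params, x, y):
    # Hypothetical toy model: one linear layer with mean squared error.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def sgd_step(params, x, y, lr=0.1, max_norm=1.0):
    grads = jax.grad(loss_fn)(params, x, y)
    grads = clip_by_global_norm(grads, max_norm)
    # Plain SGD update, applied leaf by leaf.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


params = {"w": jnp.zeros((3,)), "b": jnp.array(0.0)}
x = jnp.ones((8, 3))
y = jnp.full((8,), 2.0)
new_params = sgd_step(params, x, y)
```

Libraries such as Optax package transforms like this as composable pieces, so clipping, momentum, and learning-rate schedules can be chained instead of hand-written.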
New Libraries
This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large userbase yet.
- Neural Network Libraries
  - FedJAX - Federated learning in JAX, built on Optax and Haiku.
  - Equivariant MLP - Construct equivariant neural network layers.
  - jax-resnet - Implementations and checkpoints for ResNet variants in Flax.
  - jax-raft - JAX/Flax port of the RAFT optical flow estimator.
  - Parallax - Immutable Torch Modules for JAX.
- Nonlinear Optimization
  - Optimistix - Root finding, minimisation, fixed points, and least squares.
  - JAXopt - Hardware-accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
- Brain Dynamics Programming Ecosystem
  - [BrainPy](https://github…
- jax-unirep
Library implementing the UniRep model for protein machine learning applications.
- flowjax
Distributions and normalizing flows built as Equinox modules.
- flaxdiff
Framework and library for building and training diffusion models in multi-node, multi-device distributed settings (TPUs).
- jax-flows
Normalizing flows in JAX.
- sklearn-jax-kernels
scikit-learn kernel matrices using JAX.
- jax-cosmo
Differentiable cosmology library.
- efax
Exponential Families in JAX.
- mpi4jax
Combine MPI operations with your JAX code on CPUs and GPUs.
- imax
Image augmentations and transformations.
- FlaxVision
Flax version of TorchVision.
- Oryx
Probabilistic programming language based on program transformations.
- Optimal Transport Tools
Toolbox that bundles utilities to solve optimal transport problems.
- delta PV
A photovoltaic simulator with automatic differentiation.
- jaxlie
Lie theory library for rigid body transformations and optimization.
- BRAX
Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
- flaxmodels
Pretrained models for Jax/Flax.
- CR.Sparse
XLA accelerated algorithms for sparse representations and compressive sensing.
- exojax
Automatically differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX.
- PIX
PIX is an image processing library in JAX, for JAX.
- bayex
Bayesian Optimization powered by JAX.
- JaxDF
Framework for differentiable simulators with arbitrary discretizations.
- tree-math
Convert functions that operate on arrays into functions that operate on PyTrees.
- jax-models
Implementations of research papers originally without code or code written with frameworks other than JAX.
- PGMax
A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.
- EvoJAX
Hardware-Accelerated Neuroevolution
- evosax
JAX-Based Evolution Strategies
- SymJAX
Symbolic CPU/GPU/TPU programming.
- mcx
Express & compile probabilistic programs for performant inference.
- Einshape
DSL-based reshaping library for JAX and other frameworks.
- ALX
Open-source library for distributed matrix factorization using Alternating Least Squares, more info in ALX: Large Scale Matrix Factorization on TPUs.
- Diffrax
Numerical differential equation solvers in JAX.
- tinygp
The tiniest of Gaussian process libraries in JAX.
- gymnax
Reinforcement Learning Environments with the well-known gym API.
- Mctx
Monte Carlo tree search algorithms in native JAX.
- KFAC-JAX
Second Order Optimization with Approximate Curvature for NNs.
- TF2JAX
Convert functions/graphs to JAX functions.
- jwave
A library for differentiable acoustic simulations
- GPJax
Gaussian processes in JAX.
- Jumanji
A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX.
- Eqxvision
Equinox version of Torchvision.
- JAXFit
Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper).
- econpizza
Solve macroeconomic models with heterogeneous agents using JAX.
- SPU
A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation).
- jax-tqdm
Add a tqdm progress bar to JAX scans and loops.
- safejax
Serialize JAX, Flax, Haiku, or Objax model params with 🤗 safetensors.
- Kernex
Differentiable stencil decorators in JAX.
- MaxText
A simple, performant and scalable JAX LLM written in pure Python/JAX and targeting Google Cloud TPUs.
- Pax
A JAX-based machine learning framework for training large-scale models.
- Praxis
The layer library for Pax with a goal to be usable by other JAX-based ML projects.
- purejaxrl
Vectorisable, end-to-end RL algorithms in JAX.
- Lorax
Automatically apply LoRA to JAX models (Flax, Haiku, etc.)
- SCICO
Scientific computational imaging in JAX.
- Spyx
Spiking Neural Networks in JAX for machine learning on neuromorphic hardware.
- OTT-JAX
Optimal transport tools in JAX.
- QDax
Quality Diversity optimization in JAX.
- JAX Toolbox
Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine.
- Pgx
Vectorized board game environments for RL with an AlphaZero example.
- EasyDeL
EasyDeL 🔮 is an open-source library for faster, more optimized training and serving of models (Llama, MPT, Mixtral, Falcon, etc.) in JAX.
- XLB
A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning.
- dynamiqs
High-performance and differentiable simulations of quantum systems with JAX.
- foragax
Agent-Based modelling framework in JAX.
- tmmax
Vectorized calculation of optical properties in thin-film structures using JAX; a Swiss Army knife for thin-film optics research.
- Coreax
Algorithms for finding coresets to compress large datasets while retaining their statistical properties.
- NAVIX
A reimplementation of MiniGrid, a Reinforcement Learning environment, in JAX
- FDTDX
Finite-Difference Time-Domain Electromagnetic Simulations in JAX
- DiffeRT
Differentiable Ray Tracing toolbox for Radio Propagation powered by the JAX ecosystem.
- JAX-in-Cell
Plasma physics simulations using a PIC (Particle-in-Cell) method to self-consistently solve for electron and ion dynamics in electromagnetic fields
- kvax
A FlashAttention implementation for JAX with support for efficient document mask computation and context parallelism.
- astronomix
Differentiable (magneto)hydrodynamics for astrophysics in JAX.
- vivsim
Fluid-structure interaction simulations using Immersed Boundary-Lattice Boltzmann Method.
- MBIRJAX
High-performance tomographic reconstruction.
- torchax
torchax is a library for Jax to interoperate with model code written in PyTorch.
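Several entries above (tree-math in particular) revolve around applying array functions to PyTrees. A minimal sketch of that idea using `jax.flatten_util.ravel_pytree`: flatten every leaf into one vector, apply an ordinary array function, then restore the original structure. `lift_to_pytree` and `normalize` are hypothetical helpers for illustration, not tree-math's API.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def lift_to_pytree(array_fn):
    # Hypothetical helper: wrap a function on flat 1-D arrays so it
    # accepts and returns PyTrees with the same structure.
    def wrapped(tree):
        flat, unravel = ravel_pytree(tree)  # concatenate all leaves into one vector
        return unravel(array_fn(flat))      # restore original nesting and shapes
    return wrapped


def normalize(v):
    # An ordinary array function: scale a vector to unit L2 norm.
    return v / jnp.linalg.norm(v)


normalize_tree = lift_to_pytree(normalize)
params = {"w": jnp.array([[3.0, 0.0]]), "b": jnp.array([4.0])}
unit = normalize_tree(params)  # the joint norm over all leaves is now 1
```

Flattening is simple but copies data; a library can instead dispatch operations leaf by leaf, which avoids materializing one large vector.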
JAX
- Fourier Feature Networks
Official implementation of Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains.
- kalman-jax
Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing.
- jaxns
Nested sampling in JAX.
- Accurate Quantized Training
Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax.
- BNN-HMC
Implementation for the paper What Are Bayesian Neural Network Posteriors Really Like?.
- JAX-DFT
One-dimensional density functional theory (DFT) in JAX, with implementation of Kohn-Sham equations as regularizer: building prior knowledge into machine-learned physics.
- Robust Loss
Reference code for the paper A General and Adaptive Robust Loss Function.
- Symbolic Functionals
Demonstration from Evolving symbolic density functionals.
- TriMap
Official JAX implementation of TriMap: Large-scale Dimensionality Reduction Using Triplets.
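The Fourier Feature Networks entry maps low-dimensional inputs v through γ(v) = [cos(2πBv), sin(2πBv)], with the entries of B drawn from a Gaussian, before feeding them to an MLP. A minimal sketch of that Gaussian mapping; the function names, `scale=10.0`, and feature count are assumptions for illustration rather than the official code's settings.

```python
import jax
import jax.numpy as jnp


def sample_fourier_matrix(key, in_dim, num_features, scale=10.0):
    # B has entries drawn from N(0, scale^2); scale acts as a bandwidth knob.
    return scale * jax.random.normal(key, (num_features, in_dim))


def fourier_features(v, B):
    # gamma(v) = [cos(2*pi*B v), sin(2*pi*B v)] for a batch of inputs v: (n, in_dim).
    proj = 2.0 * jnp.pi * v @ B.T
    return jnp.concatenate([jnp.cos(proj), jnp.sin(proj)], axis=-1)


key = jax.random.PRNGKey(0)
B = sample_fourier_matrix(key, in_dim=2, num_features=256)
grid = jnp.meshgrid(jnp.linspace(0.0, 1.0, 4), jnp.linspace(0.0, 1.0, 4))
coords = jnp.stack(grid, axis=-1).reshape(-1, 2)  # 16 two-dimensional points
feats = fourier_features(coords, B)               # shape (16, 512), input to an MLP
```

Larger `scale` values let the downstream MLP fit higher-frequency detail, at the risk of noisier interpolation.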
Flax
- awesome-jax-flax-llms
Collection of LLMs implemented in JAX & Flax
- DeepSeek-R1-Flax-1.5B-Distill
Flax implementation of DeepSeek-R1 1.5B distilled reasoning LLM.
- Performer
Flax implementation of the Performer (linear transformer via FAVOR+) architecture.
- JaxNeRF
Implementation of NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis with multi-device GPU/TPU support.
- mip-NeRF
Official implementation of Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields.
- RegNeRF
Official implementation of RegNeRF: Regularizing Neural Radiance Fields for View Synthesis from Sparse Inputs.
- Big Transfer (BiT)
Implementation of Big Transfer (BiT): General Visual Representation Learning.
- JAX RL
Implementations of reinforcement learning algorithms.
- gMLP
Implementation of Pay Attention to MLPs.
- MLP Mixer
Minimal implementation of MLP-Mixer: An all-MLP Architecture for Vision.
- Distributed Shampoo
Implementation of Second Order Optimization Made Practical.
- NesT
Official implementation of Aggregating Nested Transformers.
- XMC-GAN
Official implementation of Cross-Modal Contrastive Learning for Text-to-Image Generation.
- FNet
Official implementation of FNet: Mixing Tokens with Fourier Transforms.
- GFSA
Official implementation of Learning Graph Structure With A Finite-State Automaton Layer.
- IPA-GNN
Official implementation of Learning to Execute Programs with Instruction Pointer Attention Graph Neural Networks.
- Flax Models
Collection of models and methods implemented in Flax.
- Protein LM
Implements BERT and autoregressive models for proteins, as described in Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences and ProGen: Language Modeling for Protein Generation.
- Slot Attention
Reference implementation for Differentiable Patch Selection for Image Recognition.
- Vision Transformer
Official implementation of An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.
- FID computation
Port of mseitzer/pytorch-fid to Flax.
- ARDM
Official implementation of Autoregressive Diffusion Models.
- D3PM
Official implementation of Structured Denoising Diffusion Models in Discrete State-Spaces.
- Gumbel-max Causal Mechanisms
Code for Learning Generalized Gumbel-max Causal Mechanisms, with extra code in GuyLor/gumbel_max_causal_gadgets_part2.
- Latent Programmer
Code for the ICML 2021 paper Latent Programmer: Discrete Latent Codes for Program Synthesis.
- SNeRG
Official implementation of Baking Neural Radiance Fields for Real-Time View Synthesis.
- Spin-weighted Spherical CNNs
Adaptation of Spin-Weighted Spherical CNNs.
- VDVAE
Adaptation of Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images, original code at openai/vdvae.
- MUSIQ
Checkpoints and model inference code for the ICCV 2021 paper MUSIQ: Multi-scale Image Quality Transformer
- AQuaDem
Official implementation of Continuous Control with Action Quantization from Demonstrations.
- Combiner
Official implementation of Combiner: Full Attention Transformer with Sparse Computation Cost.
- Dreamfields
Official implementation of the ICLR 2022 paper Progressive Distillation for Fast Sampling of Diffusion Models.
- GIFT
Official implementation of Gradual Domain Adaptation in the Wild: When Intermediate Distributions are Absent.
- Light Field Neural Rendering
Official implementation of Light Field Neural Rendering.
- Sharpened Cosine Similarity in JAX by Raphael Pisoni
A JAX/Flax implementation of the Sharpened Cosine Similarity layer.
- GNNs for Solving Combinatorial Optimization Problems
A JAX + Flax implementation of Combinatorial Optimization with Physics-Inspired Graph Neural Networks.
- DETR
Flax implementation of DETR: End-to-end Object Detection with Transformers using Sinkhorn solver and parallel bipartite matching.
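The Gumbel-max Causal Mechanisms entry builds on the Gumbel-max trick: adding independent Gumbel noise to logits and taking the argmax yields an exact sample from softmax(logits). A minimal sketch in plain JAX (not the project's code); the probabilities and sample count are arbitrary.

```python
import jax
import jax.numpy as jnp


def gumbel_max_sample(key, logits, num_samples):
    # argmax(logits + Gumbel noise) is distributed as softmax(logits).
    noise = jax.random.gumbel(key, (num_samples,) + logits.shape)
    return jnp.argmax(logits + noise, axis=-1)


logits = jnp.log(jnp.array([0.2, 0.5, 0.3]))
samples = gumbel_max_sample(jax.random.PRNGKey(0), logits, 100_000)
freqs = jnp.bincount(samples, length=3) / samples.shape[0]  # close to [0.2, 0.5, 0.3]
```

Because the noise is drawn separately from the logits, the same noise can be reused with different logits, which is what makes Gumbel-max attractive for counterfactual reasoning.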
Haiku
- AlphaFold
Implementation of the inference pipeline of AlphaFold v2.0, presented in Highly accurate protein structure prediction with AlphaFold.
- Bootstrap Your Own Latent
Implementation for the paper Bootstrap your own latent: A new approach to self-supervised Learning.
- Gated Linear Networks
GLNs are a family of backpropagation-free neural networks.
- Glassy Dynamics
Open source implementation of the paper Unveiling the predictive power of static structure in glassy systems.
- MMV
Code for the models in Self-Supervised MultiModal Versatile Networks.
- Normalizer-Free Networks
Official Haiku implementation of NFNets.
- NuX
Normalizing flows with JAX.
- OGB-LSC
This repository contains DeepMind's entry to the PCQM4M-LSC (quantum chemistry) and MAG240M-LSC (academic graph) tracks of the OGB Large-Scale Challenge (OGB-LSC).
- Two Player Auction Learning
JAX implementation of the paper Auction learning as a two-player game.
- WikiGraphs
Baseline code to reproduce results in WikiGraphs: A Wikipedia Text - Knowledge Graph Paired Dataset.
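Bootstrap Your Own Latent (BYOL) trains an online network to predict a target network's projection. The target's parameters ξ track the online parameters θ through an exponential moving average, ξ ← τξ + (1 − τ)θ, and the loss is a normalized MSE equal to 2 − 2·cos(prediction, target). A minimal sketch of those two pieces in plain JAX, not DeepMind's Haiku code; the parameter dictionaries and `tau` values are hypothetical.

```python
import jax
import jax.numpy as jnp


def ema_update(target_params, online_params, tau=0.99):
    # BYOL-style target update: xi <- tau * xi + (1 - tau) * theta, leaf by leaf.
    return jax.tree_util.tree_map(
        lambda xi, theta: tau * xi + (1.0 - tau) * theta, target_params, online_params
    )


def byol_loss(online_pred, target_proj):
    # No gradients should flow into the target branch.
    target_proj = jax.lax.stop_gradient(target_proj)
    p = online_pred / jnp.linalg.norm(online_pred, axis=-1, keepdims=True)
    z = target_proj / jnp.linalg.norm(target_proj, axis=-1, keepdims=True)
    # Normalized MSE: 2 - 2 * cosine similarity, averaged over the batch.
    return jnp.mean(2.0 - 2.0 * jnp.sum(p * z, axis=-1))


online = {"w": jnp.ones((2, 2))}
target = {"w": jnp.zeros((2, 2))}
target = ema_update(target, online, tau=0.9)  # every entry moves 10% toward online
```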
Trax
- Reformer
Implementation of the Reformer (efficient transformer) architecture.
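The Reformer's efficient attention groups similar queries and keys with angular locality-sensitive hashing: fix a random matrix R of shape (d, b/2) and hash x to argmax([xR; −xR]) for b buckets. A minimal sketch of that hashing step alone, in plain JAX rather than Trax; `lsh_hash` and the sizes are illustrative.

```python
import jax
import jax.numpy as jnp


def lsh_hash(key, x, n_buckets):
    # Angular LSH: project with random R of shape (d, n_buckets // 2),
    # then take the argmax over the concatenation [xR, -xR].
    d = x.shape[-1]
    R = jax.random.normal(key, (d, n_buckets // 2))
    rotated = x @ R
    return jnp.argmax(jnp.concatenate([rotated, -rotated], axis=-1), axis=-1)


key_data, key_hash = jax.random.split(jax.random.PRNGKey(0))
vectors = jax.random.normal(key_data, (6, 16))
buckets = lsh_hash(key_hash, vectors, n_buckets=8)
# Positive scaling does not change a bucket: the hash depends only on direction.
scaled = lsh_hash(key_hash, 3.0 * vectors, n_buckets=8)
```

Attention is then computed only within (and near) each bucket, which is where the memory savings come from.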
NumPyro
- lqg
Official implementation of Bayesian inverse optimal control for linear-quadratic Gaussian problems from the paper Putting perception into action with inverse optimal control for continuous psychophysics
Equinox
- Sampling Path Candidates with Machine Learning
Official tutorial and implementation from the paper Towards Generative Ray Path Sampling for Faster Point-to-Point Ray Tracing.
Videos
- NeurIPS 2020: JAX Ecosystem Meetup
JAX, its use at DeepMind, and discussion between engineers, scientists, and JAX core team.
- Introduction to JAX
Simple neural network from scratch in JAX.
- JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas
JAX's core design, how it's powering new research, and how you can start using it.
- Bayesian Programming with JAX + NumPyro — Andy Kitchen
Introduction to Bayesian modelling using NumPyro.
- JAX: Accelerated machine-learning research via composable function transformations in Python | Neur…
JAX intro presentation in Program Transformations for Machine Learning workshop.
- JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury
Presentation of TPU host access with demo.
- Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020
Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in Deep Implicit Layers.
- Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey
A four-part YouTube tutorial series with Colab notebooks that starts with JAX fundamentals and moves up to training with a data-parallel approach on a v3-32 TPU Pod slice.
- JAX, Flax & Transformers 🤗
3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics.
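The "Solving y=mx+b" series starts from fitting a line with JAX. A minimal single-device sketch of that starting point, using `jax.grad` and `jax.jit` on synthetic data; the data, learning rate, and step count are assumptions, and the TPU Pod data-parallel part of the series is not shown.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    m, b = params
    return m * x + b


def mse(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def update(params, x, y, lr=0.1):
    grads = jax.grad(mse)(params, x, y)  # gradients w.r.t. (m, b)
    return tuple(p - lr * g for p, g in zip(params, grads))


x = jnp.linspace(-1.0, 1.0, 50)
y = 3.0 * x + 0.5  # synthetic data with m = 3, b = 0.5
params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(500):
    params = update(params, x, y)
```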
Papers
This section contains papers focused on JAX (e.g. JAX-based library whitepapers, research on JAX, etc). Papers implemented in JAX are listed in the Models/Projects section.
- **Compiling machine learning programs via high-level tracing**. Roy Frostig, Matthew James Johnson,…
White paper describing an early version of JAX, detailing how computation is traced and compiled.
- **JAX, M.D.: A Framework for Differentiable Physics**. Samuel S. Schoenholz, Ekin D. Cubuk. _NeurIP…
Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.
- **Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization**. Pranav…
Uses JAX's JIT and vmap to achieve faster differentially private SGD than existing libraries.
- **XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python**. Mohammadmehdi Ata…
White paper describing the XLB library: benchmarks, validations, and more details about the library.
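The differentially private SGD paper above relies on `vmap` to compute per-example gradients efficiently. A minimal sketch of one DP-SGD step in that style: per-example gradients via `jax.vmap(jax.grad(...))`, per-example clipping to `max_norm`, Gaussian noise scaled by `noise_multiplier * max_norm`, then averaging. This is not the paper's code; the linear model and default hyperparameters are hypothetical, and no privacy accounting is shown.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    # Hypothetical per-example loss: squared error of a linear model.
    return (jnp.dot(x, w) - y) ** 2


def clip(g, max_norm):
    # Scale one example's gradient so its L2 norm is at most max_norm.
    return g * jnp.minimum(1.0, max_norm / (jnp.linalg.norm(g) + 1e-12))


@jax.jit
def dp_sgd_step(w, xs, ys, key, lr=0.1, max_norm=1.0, noise_multiplier=1.1):
    # vmap turns the single-example gradient into a batch of per-example gradients.
    per_example = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(w, xs, ys)
    clipped = jax.vmap(clip, in_axes=(0, None))(per_example, max_norm)
    # Sum, add Gaussian noise calibrated to the clipping norm, then average.
    noise = noise_multiplier * max_norm * jax.random.normal(key, w.shape)
    noisy_mean = (jnp.sum(clipped, axis=0) + noise) / xs.shape[0]
    return w - lr * noisy_mean


w = jnp.zeros(3)
xs = jnp.ones((32, 3))
ys = jnp.full((32,), 5.0)
new_w = dp_sgd_step(w, xs, ys, jax.random.PRNGKey(0))
```

Without `vmap`, per-example gradients usually require a Python loop or batch size 1, which is the slowdown the paper targets.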
Tutorials and Blog Posts
- Using JAX to accelerate our research by David Budden and Matteo Hessel
Describes the state of JAX and the JAX ecosystem at DeepMind.
- Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange
Neural network building blocks from scratch with the basic JAX operators.
- Learn JAX: From Linear Regression to Neural Networks by Rito Ghosh
A gentle introduction to JAX that uses it to implement linear regression, logistic regression, and neural network models, and applies them to real-world problems.
- Tutorial: image classification with JAX and Flax Linen by 8bitmp3
Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits.
- Plugging Into JAX by Nick Doiron
Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge.
- Meta-Learning in 50 Lines of JAX by Eric Jang
Introduction to both JAX and Meta-Learning.
- Normalizing Flows in 100 Lines of JAX by Eric Jang
Concise implementation of RealNVP.
- Differentiable Path Tracing on the GPU/TPU by Eric Jang
Tutorial on implementing path tracing.
- Ensemble networks by Mat Kelcey
Ensemble nets are a method of representing an ensemble of models as one single logical model.
- Out of distribution (OOD) detection by Mat Kelcey
Implements different methods for OOD detection.
- Understanding Autodiff with JAX by Srihari Radhakrishna
Understand how autodiff works using JAX.
- From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke
Showcases how to go from a PyTorch-like style of coding to a more functional style of coding.
- Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey
Tutorial demonstrating the infrastructure required to provide custom ops in JAX.
- Evolving Neural Networks in JAX by Robert Tjarko Lange
Explores how JAX can power the next generation of scalable neuroevolution algorithms.
- Exploring hyperparameter meta-loss landscapes with JAX by Luke Metz
Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies.
- Deterministic ADVI in JAX by Martin Ingram
Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX.
- Evolved channel selection by Mat Kelcey
Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss.
- Introduction to JAX by Kevin Murphy
Colab that introduces various aspects of the language and applies them to simple ML problems.
- Writing an MCMC sampler in JAX by Jeremie Coullon
Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks.
- How to add a progress bar to JAX scans and loops by Jeremie Coullon
Tutorial on how to add a progress bar to compiled loops in JAX using the host_callback module.
- Get started with JAX by Aleksa Gordić
A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku.
- Writing a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit
A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax and Optax.
- Implementing NeRF in JAX by Soumik Rakshit and Saurav Maheshkar
A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields in JAX.
- Deep Learning tutorials with JAX+Flax by Phillip Lippe
A series of notebooks explaining various deep learning concepts, from basics (e.g. intro to JAX/Flax, activation functions) to recent advances (e.g., Vision Transformers, SimCLR), with translations to PyTorch.
- Achieving 4000x Speedups with PureJaxRL
A blog post on how JAX can massively speed up RL training through vectorisation.
- Simple PDE solver + Constrained Optimization with JAX by Philip Mocz
A simple example of solving the advection-diffusion equations with JAX and using it in a constrained optimization problem to find initial conditions that yield a desired result.
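Several tutorials above (for example, writing an MCMC sampler in JAX) build samplers around `jax.lax.scan`. A minimal sketch of a random-walk Metropolis sampler in that style, assuming a hypothetical standard-normal target; `step_size`, the chain length, and the function names are illustrative and not taken from any tutorial.

```python
import jax
import jax.numpy as jnp


def log_prob(x):
    # Hypothetical target: standard normal log-density, up to a constant.
    return -0.5 * x ** 2


def rw_metropolis(key, x0, num_steps, step_size=1.0):
    def step(x, key):
        key_prop, key_accept = jax.random.split(key)
        proposal = x + step_size * jax.random.normal(key_prop)
        # Accept with probability min(1, p(proposal) / p(x)), in log space.
        log_ratio = log_prob(proposal) - log_prob(x)
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        x_new = jnp.where(accept, proposal, x)
        return x_new, x_new  # (next carry, value recorded for this step)

    keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(step, x0, keys)
    return samples


samples = rw_metropolis(jax.random.PRNGKey(0), jnp.zeros(()), 20_000)
kept = samples[2_000:]  # discard burn-in
```

`lax.scan` compiles the loop body once, so the chain runs without Python-level overhead per step; running many chains is then a matter of wrapping `rw_metropolis` in `jax.vmap` over keys.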
Books
- JAX in Action
A hands-on guide to using JAX for deep learning and other mathematically-intensive applications.