A JAX-based recurrent reinforcement learning library. Supports LSTM, GRU, and GTrXL (Gated Transformer-XL) sequence models with PPO and A2C. Built for use with MuJoCo/MJX environments and designed around asymmetric actor-critic training.
- sequence models: MLP (non-recurrent), multi-layer LSTM, GRU, and GTrXL (experimental), with hidden-state resets at episode boundaries
- agents: PPO (with adaptive learning rate and minibatch BPTT) and A2C
- asymmetric actor-critic: policy observations and privileged observations are split at the model level
- domain randomization: per-environment batched MJX model randomization
- training: Hydra configuration, Orbax checkpointing, and video recording
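Episode-boundary resets matter because a batched rollout contains many environments that finish at different steps, so the recurrent carry has to be zeroed per environment, not per batch. The sketch below shows one way to do this with a hand-written GRU cell inside `jax.lax.scan`. The parameter layout, the function names (`init_gru_params`, `gru_cell`, `scan_with_resets`), and the flag convention (`starts[t]` is True when step t is the first step of a new episode) are all assumptions for illustration, not recurrl-jax's API.

```python
import jax
import jax.numpy as jnp


def init_gru_params(key, in_dim, hidden_dim):
    """Hypothetical single-layer GRU parameters (not recurrl-jax's layout)."""
    k_x, k_h = jax.random.split(key)
    scale = 1.0 / jnp.sqrt(hidden_dim)
    return {
        # Columns are [update z | reset r | candidate n] gate blocks.
        "w_x": jax.random.uniform(k_x, (in_dim, 3 * hidden_dim), minval=-scale, maxval=scale),
        "w_h": jax.random.uniform(k_h, (hidden_dim, 3 * hidden_dim), minval=-scale, maxval=scale),
        "b": jnp.zeros((3 * hidden_dim,)),
    }


def gru_cell(params, h, x):
    """One GRU step for a batch: h [B, H], x [B, D] -> new h [B, H]."""
    H = h.shape[-1]
    gx = x @ params["w_x"] + params["b"]
    gh = h @ params["w_h"]
    z = jax.nn.sigmoid(gx[:, :H] + gh[:, :H])
    r = jax.nn.sigmoid(gx[:, H:2 * H] + gh[:, H:2 * H])
    n = jnp.tanh(gx[:, 2 * H:] + r * gh[:, 2 * H:])
    return (1.0 - z) * n + z * h


def scan_with_resets(params, h0, xs, starts):
    """Unroll over time, zeroing the carry wherever a new episode starts.

    xs: [T, B, D] observations; starts: [T, B] bool, True where step t is
    the first step of a new episode (assumed flag convention).
    """
    def step(h, inputs):
        x, start = inputs
        # Reset only the environments whose episode begins at this step.
        h = jnp.where(start[:, None], jnp.zeros_like(h), h)
        h = gru_cell(params, h, x)
        return h, h

    h_last, hs = jax.lax.scan(step, h0, (xs, starts))
    return h_last, hs
```

Because the reset happens inside the scan, the same function can be reused during minibatch BPTT to recompute hidden states from stored observations and flags: after a reset, an environment's outputs match a fresh unroll from a zero state.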
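"Adaptive LR" in PPO usually means adjusting the learning rate from the measured KL divergence between the old and updated policies. The sketch below assumes a diagonal Gaussian policy and the common threshold rule used in rsl_rl-style PPO (divide by 1.5 when KL exceeds twice a target, multiply by 1.5 when it falls below half the target). Whether recurrl-jax uses exactly this rule, these default values, or these names (`diag_gaussian_kl`, `adapt_lr`, `desired_kl`) is an assumption.

```python
import jax.numpy as jnp


def diag_gaussian_kl(mu_old, std_old, mu_new, std_new):
    """Mean over the batch of KL(old || new) for diagonal Gaussian policies."""
    kl = (
        jnp.log(std_new / std_old)
        + (std_old**2 + (mu_old - mu_new) ** 2) / (2.0 * std_new**2)
        - 0.5
    )
    # Sum over action dimensions, average over the batch.
    return jnp.mean(jnp.sum(kl, axis=-1))


def adapt_lr(lr, kl, desired_kl=0.01, factor=1.5, min_lr=1e-5, max_lr=1e-2):
    """Hypothetical KL-adaptive rule: shrink LR on overshoot, grow it on tiny steps."""
    lr = jnp.where(kl > 2.0 * desired_kl, jnp.maximum(min_lr, lr / factor), lr)
    lr = jnp.where((kl < 0.5 * desired_kl) & (kl > 0.0), jnp.minimum(max_lr, lr * factor), lr)
    return lr
```

Using `jnp.where` instead of Python `if` keeps the update traceable, so it can sit inside a jitted training step.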
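Splitting observations "at the model level" means the actor and critic read different inputs from the same observation structure. The sketch below uses single linear layers to keep the idea visible: the actor sees only the policy observation, while the critic sees only the privileged one. The dict keys `"state"` and `"privileged_state"` are an assumption borrowed from MuJoCo Playground's observation dicts, and `init_actor_critic` / `actor_critic_apply` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def init_linear(key, in_dim, out_dim):
    return {"w": jax.random.normal(key, (in_dim, out_dim)) / jnp.sqrt(in_dim),
            "b": jnp.zeros((out_dim,))}


def linear(p, x):
    return x @ p["w"] + p["b"]


def init_actor_critic(key, policy_dim, privileged_dim, action_dim):
    k_actor, k_critic = jax.random.split(key)
    return {
        "actor": init_linear(k_actor, policy_dim, action_dim),
        "critic": init_linear(k_critic, privileged_dim, 1),
    }


def actor_critic_apply(params, obs):
    """Actor reads only policy obs; critic reads only privileged obs.

    The "state" / "privileged_state" keys are an assumption, not
    necessarily the keys recurrl-jax expects.
    """
    action_mean = linear(params["actor"], obs["state"])
    value = linear(params["critic"], obs["privileged_state"])[..., 0]
    return action_mean, value
```

Since the actor never reads privileged inputs, the trained policy can be deployed with only the observations available on the real system.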
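Per-environment batched randomization typically means sampling one set of physics parameters per environment by vmapping a sampler over RNG keys, so every randomized field gains a leading `[num_envs]` axis. The sketch below uses a plain dict as a stand-in for model fields; the field names mirror MuJoCo's `geom_friction` and `body_mass`, but the dict is not an `mjx.Model`, and the ranges and function names (`randomize_one`, `randomize_batch`) are hypothetical. Writing the sampled arrays back into a batched MJX model is not shown.

```python
import jax
import jax.numpy as jnp


def randomize_one(base, key):
    """Sample one environment's physics parameters (hypothetical ranges)."""
    k_fric, k_mass = jax.random.split(key)
    # Scale sliding friction (column 0) of every geom by one shared factor.
    fric_scale = jax.random.uniform(k_fric, (), minval=0.8, maxval=1.2)
    friction = base["geom_friction"].at[:, 0].multiply(fric_scale)
    # Scale each body's mass independently.
    mass_scale = jax.random.uniform(k_mass, base["body_mass"].shape, minval=0.9, maxval=1.1)
    return {"geom_friction": friction, "body_mass": base["body_mass"] * mass_scale}


def randomize_batch(base, key, num_envs):
    """Add a leading [num_envs] axis to every field by vmapping over RNG keys."""
    keys = jax.random.split(key, num_envs)
    # in_axes=(None, 0): share the base parameters, map over the keys.
    return jax.vmap(randomize_one, in_axes=(None, 0))(base, keys)
```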
Dexterous in-hand reorientation with a 16-DOF LEAP hand in MJX.
cd examples/leap_hand
uv run python train.py
Quadrupedal locomotion for the ANYmal C robot.
cd examples/anymal
uv run python train.py
Classic pendulum swing-up task.
cd examples/pendulum
uv run python train.py
Config lives in examples/leap_hand/config/.
recurrl_jax/
├── agents/ # PPO, A2C
├── model_fns/ # factory functions for actor, critic, repr, seq models
├── models/
│ ├── actor_critic.py
│ ├── rnns/ # LSTM, GRU (multi-layer)
│ └── transformers/ # GTrXL
├── trainers/ # Trainer, BaseTrainer
└── utils/ # wrappers, logging, video, quat math, running stats
recurrl-jax was originally developed during an internship at JRL Lab (CNRS-AIST JRL, UMI3218/RL), where it was used for dexterous manipulation research with the LEAP hand.
- MuJoCo Playground (Google DeepMind): the LEAP hand mesh assets and XML scene descriptions are sourced from this repository.
- Brax (Freeman et al., Google DeepMind): the batched MJX domain randomization approach follows patterns established in Brax.
- subho406: their Recurrent PPO with JAX served as a reference for the recurrent PPO implementation.
- GTrXL: Parisotto et al., "Stabilizing Transformers for Reinforcement Learning" (ICML 2020).
- CleanRL: a reference for the structure of the PPO implementation.