saucesaft/recurrl-jax


recurrl-jax

Python JAX MuJoCo License: MIT

A JAX-based recurrent reinforcement learning library. Supports LSTM, GRU, and GTrXL (Gated Transformer-XL) sequence models with PPO and A2C. Built for use with MuJoCo/MJX environments and designed around asymmetric actor-critic training.
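As a rough illustration of what asymmetric actor-critic means here, the sketch below gives the actor only the policy observation while the critic also receives privileged observations. It does not use the recurrl-jax API: `init_linear`, `linear`, and `actor_critic` are hypothetical names, the networks are single linear layers, and feeding the critic a concatenation of both observation streams is an assumption.

```python
import jax
import jax.numpy as jnp


def init_linear(key, in_dim, out_dim):
    """Small random linear layer; a stand-in for a real network (hypothetical)."""
    return {
        "w": jax.random.normal(key, (in_dim, out_dim)) * 0.1,
        "b": jnp.zeros((out_dim,)),
    }


def linear(p, x):
    return x @ p["w"] + p["b"]


def actor_critic(params, policy_obs, privileged_obs):
    """Actor sees only policy_obs; critic sees policy_obs and privileged_obs."""
    action_mean = linear(params["actor"], policy_obs)
    # Assumption: the critic consumes both streams concatenated on the last axis.
    critic_in = jnp.concatenate([policy_obs, privileged_obs], axis=-1)
    value = linear(params["critic"], critic_in).squeeze(-1)
    return action_mean, value


batch, policy_dim, priv_dim, action_dim = 8, 6, 4, 2
k_actor, k_critic, k_obs, k_priv = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "actor": init_linear(k_actor, policy_dim, action_dim),
    "critic": init_linear(k_critic, policy_dim + priv_dim, 1),
}
policy_obs = jax.random.normal(k_obs, (batch, policy_dim))
privileged_obs = jax.random.normal(k_priv, (batch, priv_dim))
action_mean, value = actor_critic(params, policy_obs, privileged_obs)
```

Because the actor never reads the privileged stream, the policy can be deployed with policy observations alone, while the critic uses the extra simulator state only during training.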

features

  • sequence models: MLP (non-recurrent), multi-layer LSTM, GRU, and GTrXL (experimental) with correct episode boundary resets
  • agents: PPO (with adaptive LR, minibatch BPTT) and A2C
  • asymmetric actor-critic: policy obs and privileged obs split at the model level
  • domain randomization: per-environment batched MJX model randomization
  • training: configurable via Hydra, Orbax checkpointing, and video recording
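The episode boundary resets mentioned for the sequence models can be sketched as a masked carry inside `jax.lax.scan`. This is a minimal sketch rather than the library's implementation: `simple_rnn_cell` is a toy tanh cell standing in for an LSTM or GRU, `unroll_with_resets` is a hypothetical name, and the convention that `dones[t]` marks the first step of a new episode is an assumption.

```python
import jax
import jax.numpy as jnp


def simple_rnn_cell(params, h, x):
    """Toy tanh cell standing in for an LSTM/GRU cell (hypothetical)."""
    return jnp.tanh(x @ params["w_x"] + h @ params["w_h"] + params["b"])


def unroll_with_resets(params, h0, xs, dones):
    """Unroll over time, zeroing the carry at steps flagged as a new episode.

    xs:    [T, B, obs_dim] observations
    dones: [T, B], 1.0 where step t is the first step of a new episode
    """

    def step(h, inputs):
        x, done = inputs
        # Drop the hidden state for environments that just started a new episode.
        h = jnp.where(done[:, None] > 0, jnp.zeros_like(h), h)
        h = simple_rnn_cell(params, h, x)
        return h, h

    h_last, hs = jax.lax.scan(step, h0, (xs, dones))
    return h_last, hs


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
obs_dim, hidden, T, B = 3, 4, 5, 2
params = {
    "w_x": jax.random.normal(k1, (obs_dim, hidden)) * 0.1,
    "w_h": jax.random.normal(k2, (hidden, hidden)) * 0.1,
    "b": jnp.zeros((hidden,)),
}
xs = jax.random.normal(k3, (T, B, obs_dim))
h0 = jnp.ones((B, hidden))
# Environment 0 starts a new episode at t=2; environment 1 never resets.
dones = jnp.zeros((T, B)).at[2, 0].set(1.0)
h_last, hs = unroll_with_resets(params, h0, xs, dones)
```

Applying the mask inside the scan means a single rollout buffer can span several episodes without hidden state leaking from one episode into the next, which matters when minibatches are cut for BPTT.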

examples

LEAP hand reorientation

Dexterous in-hand reorientation with a 16-DOF LEAP hand in MJX.

cd examples/leap_hand
uv run python train.py

ANYmal locomotion

Quadrupedal locomotion for the ANYmal C robot.

cd examples/anymal
uv run python train.py

pendulum

Classic swing-up task.

cd examples/pendulum
uv run python train.py

Config for the LEAP hand example lives in examples/leap_hand/config/.

library structure

recurrl_jax/
├── agents/          # PPO, A2C
├── model_fns/       # factory functions for actor, critic, repr, seq models
├── models/
│   ├── actor_critic.py
│   ├── rnns/        # LSTM, GRU (multi-layer)
│   └── transformers/  # GTrXL
├── trainers/        # Trainer, BaseTrainer
└── utils/           # wrappers, logging, video, quat math, running stats

background

recurrl-jax was originally developed during an internship at JRL Lab (CNRS-AIST JRL, UMI3218/RL), where it was used for dexterous manipulation research with the LEAP hand.

acknowledgments

  • MuJoCo Playground — Google DeepMind. The LEAP hand mesh assets and XML scene descriptions are sourced from this repository.
  • Brax — Freeman et al., Google DeepMind. The batched MJX domain randomization approach follows patterns established in Brax.
  • subho406 — their Recurrent PPO with JAX was a reference for the recurrent PPO implementation.
  • GTrXL — Parisotto et al., "Stabilizing Transformers for Reinforcement Learning".
  • CleanRL — reference for the PPO implementation structure.

About

Highly customizable (and fast) recurrent RL library for JAX/MJX
