jinPrelude/minimal-rl-nnx

minimal-rl-nnx

Single-file minimal RL implementations in Flax NNX, inspired by minimalRL.

Quick Start

pip install -r requirements.txt   # Python >= 3.12
wandb login                       # for experiment logging

# for running ppo_gtrxl_atari.py:
pip install ale_py "gymnasium[ale]"   # quotes keep zsh from globbing the brackets

# for running ppo_gtrxl_memorygym.py:
pip install git+https://github.com/jinPrelude/endless-memory-gym.git

Core Algorithms

  • Minimal implementations for the LunarLander environment.
Algorithm               Lines  Command           Training time (MacBook Air M2)  Environment
PPO                     228    python ppo.py     ~40 sec                         LunarLander-v3
A2C                     180    python a2c.py     ~100 sec                        LunarLander-v3
Impala (cleanba style)  263    python impala.py  ~100 sec                        LunarLander-v3

Advanced Implementations

  • Implementation files that build on the core algorithms with more advanced extensions, including:
    • recurrent policy (LSTM)
    • transformer model (TrXL, GTrXL)
    • harder tasks (Atari, MemoryGym)
Algorithm    Lines  Command                        Environment
PPO_LSTM     278    python ppo_lstm.py             LunarLander-v3
PPO_TrXL     669    python ppo_trxl.py             LunarLander-v3
PPO_GTrXL    692    python ppo_gtrxl.py            LunarLander-v3
                    python ppo_gtrxl_atari.py      ALE/Breakout-v5
                    python ppo_gtrxl_memorygym.py  MemoryGym
Impala_LSTM  294    python impala_lstm.py          LunarLander-v3

If you'd like to see a specific algorithm implemented, feel free to open an issue.

Tuning Tips

  • Training failed with gamma=0.97. Setting it to 0.99 was critical for learning.
  • Increasing hidden dim from 128 to 256 improved both convergence speed and final performance.
  • For A2C, updating the actor with V instead of G - V (advantage) caused training to fail.
  • TrXL appears to be highly sensitive to hyperparameters. For example, increasing trxl_dim from 128 to 256 (and trxl-num-heads from 2 to 4) caused training to fail.
  • In contrast, GTrXL was more stable and still trained well when increasing trxl_dim to 256.
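
The gamma and advantage tips above can be illustrated with a small, framework-free sketch. It is not taken from the repository's scripts: the function names (discounted_returns, effective_horizon, advantages), the episode lengths, and the critic values are all hypothetical, and plain Python lists stand in for the JAX arrays a real training loop would use. It shows how the discount factor sets a rough effective horizon of 1 / (1 - gamma), and how the advantage G - V differs from V as a weight on the actor update. Whether these effects fully explain the failures reported above is an assumption, not something the results confirm.

```python
def discounted_returns(rewards, gamma):
    """Return G_t = r_t + gamma * G_{t+1} for every step, scanning backwards."""
    returns = []
    running = 0.0
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def effective_horizon(gamma):
    """Rule-of-thumb horizon 1 / (1 - gamma): how far ahead rewards still count."""
    return 1.0 / (1.0 - gamma)


def advantages(returns, values):
    """A_t = G_t - V_t, the signed weight on each log-probability in the actor loss."""
    return [g - v for g, v in zip(returns, values)]


# Hypothetical episode: a single reward of 100 arriving at step 200.
rewards = [0.0] * 199 + [100.0]
first_step_return = {}
for gamma in (0.97, 0.99):
    first_step_return[gamma] = discounted_returns(rewards, gamma)[0]
    print(
        f"gamma={gamma}: horizon ~{effective_horizon(gamma):.0f} steps, "
        f"G_0={first_step_return[gamma]:.2f}"
    )

# Hypothetical critic outputs for a 3-step episode with reward 1 per step.
returns = discounted_returns([1.0, 1.0, 1.0], gamma=0.99)
values = [3.5, 1.5, 1.0]
adv = advantages(returns, values)
print("weights using V:    ", values)  # all positive
print("weights using G - V:", adv)     # sign differs from step to step
```

In this toy episode, gamma=0.97 shrinks the final reward to well under 1 by the time it reaches the first step, while gamma=0.99 leaves roughly 13, so early actions still receive a usable learning signal. In the second part, weighting by V alone would reinforce every sampled action because V is positive throughout, whereas G - V is negative where the outcome was worse than the critic expected and positive where it was better.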

References

  • Heavily inspired by the philosophy of the minimalRL repository.
  • The Impala implementation closely follows cleanba, with the main change being a migration from Flax Linen to Flax NNX. Their Impala design is outstanding - huge thanks to their codebase!

Performance graph
