Single-file minimal RL implementations in Flax NNX, inspired by minimalRL.
```bash
pip install -r requirements.txt  # Python >= 3.12
wandb login  # for experiment logging

# for running ppo_gtrxl_atari.py:
pip install ale_py "gymnasium[ale]"

# for running ppo_gtrxl_memorygym.py:
pip install git+https://github.com/jinPrelude/endless-memory-gym.git
```

- minimal implementations for the LunarLander environment.

| Algorithm | Lines | Command | Training time (MacBook Air M2) | Environment |
|---|---|---|---|---|
| PPO | 228 | `python ppo.py` | ~40 sec | LunarLander-v3 |
| A2C | 180 | `python a2c.py` | ~100 sec | LunarLander-v3 |
| Impala (Cleanba-style) | 263 | `python impala.py` | ~100 sec | LunarLander-v3 |

- implementation files that build on the core algorithms with more advanced extensions, including:
  - recurrent policies (LSTM)
  - transformer models (TrXL, GTrXL)
  - harder tasks (Atari, MemoryGym)

| Algorithm | Lines | Command | Environment |
|---|---|---|---|
| PPO_LSTM | 278 | `python ppo_lstm.py` | LunarLander-v3 |
| PPO_TrXL | 669 | `python ppo_trxl.py` | LunarLander-v3 |
| PPO_GTrXL | 692 | `python ppo_gtrxl.py`<br>`python ppo_gtrxl_atari.py`<br>`python ppo_gtrxl_memorygym.py` | LunarLander-v3<br>ALE/Breakout-v5<br>MemoryGym |
| Impala_LSTM | 294 | `python impala_lstm.py` | LunarLander-v3 |

If you'd like to see a specific algorithm implemented, feel free to open an issue.
- Training failed with `gamma=0.97`; setting it to `0.99` was critical for learning.
- Increasing the hidden dimension from 128 to 256 improved both convergence speed and final performance.
- For A2C, updating the actor with `V` instead of `G - V` (the advantage) caused training to fail.
- TrXL appears to be highly sensitive to hyperparameter tuning. For example, increasing `trxl_dim` from 128 to 256 (and `trxl-num-heads` from 2 to 4) caused training to fail.
- In contrast, GTrXL was more stable and still trained well when `trxl_dim` was increased to 256.
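
The `gamma` and A2C notes both come down to how the discounted return `G` is built and then used to weight the policy gradient. Below is a minimal pure-Python sketch, not code from this repository: the function names are hypothetical, it assumes a single trajectory with no episode boundaries (no done masks), and it uses the common `1 / (1 - gamma)` effective-horizon rule of thumb, which the notes themselves do not cite.

```python
# Hypothetical helpers for illustration; not the repository's JAX/Flax NNX code.
# Assumes a single trajectory with no episode boundaries (no done masks).


def discounted_returns(rewards, gamma, bootstrap_value=0.0):
    """Compute G_t = r_t + gamma * G_{t+1}, starting from bootstrap_value."""
    returns = [0.0] * len(rewards)
    g = bootstrap_value
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns


def effective_horizon(gamma):
    """Rule of thumb: rewards beyond ~1 / (1 - gamma) steps contribute little."""
    return 1.0 / (1.0 - gamma)


def actor_loss(log_probs, returns, values):
    """Policy-gradient loss weighted by the advantage A_t = G_t - V(s_t).

    `values` act as a constant baseline here (this loss does not train them).
    Weighting by V(s_t) alone would scale each update by how good the state is,
    not by how much better or worse than expected the sampled action turned out.
    """
    advantages = [g - v for g, v in zip(returns, values)]
    return -sum(lp * a for lp, a in zip(log_probs, advantages)) / len(log_probs)


if __name__ == "__main__":
    # Effective horizon and the weight of a reward 100 steps ahead, per gamma.
    for gamma in (0.97, 0.99):
        print(gamma, round(effective_horizon(gamma), 1), round(gamma**100, 3))
    print(discounted_returns([1.0, 1.0, 1.0], gamma=0.5))  # [1.75, 1.5, 1.0]
    print(actor_loss([-1.0, -2.0], returns=[1.0, 0.0], values=[0.5, 0.5]))  # -0.25
```

Under this rule of thumb, `gamma=0.97` looks roughly 33 steps ahead and `gamma=0.99` roughly 100; whether that horizon gap is why training failed here is an inference, not something the notes establish.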