This is the official implementation of the paper Emergence of Exploration in Policy Gradient Reinforcement Learning via Retrying.
We argue that exploration matters because we are
- If there were no $\color{#C44E52}{\text{uncertainty}}$, the problem would reduce to pure optimization.
- If there were no chance to $\color{#4678C8}{\textbf{retry}}$, the only rational action would be the current best.
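As a toy reading of the retry intuition, the sketch below compares a safe arm with a risky one when an arm is scored by the expected best reward over several tries. This scoring rule, the arm values, and the function name are assumptions made for illustration; this is not the ReMax objective defined in the paper.

```python
# Toy illustration only: NOT the paper's ReMax objective.
# Assumption: an arm is scored by the expected best reward over n independent tries.


def expected_best_of_n(p_success: float, n_tries: int) -> float:
    """Expected maximum of n i.i.d. Bernoulli(p) rewards, i.e. 1 - (1 - p)^n."""
    return 1.0 - (1.0 - p_success) ** n_tries


SAFE_REWARD = 0.5  # hypothetical arm that always pays 0.5 (best-of-n is still 0.5)
RISKY_P = 0.3      # hypothetical arm that pays 1 with probability 0.3, else 0

for n in (1, 2, 4):
    risky = expected_best_of_n(RISKY_P, n)
    preferred = "risky" if risky > SAFE_REWARD else "safe"
    print(f"n={n}: safe={SAFE_REWARD:.2f} risky={risky:.2f} -> prefer {preferred}")
```

With a single try the safe arm scores higher (0.5 against 0.3). With two tries the risky arm already scores 1 - 0.7^2 = 0.51, so under this scoring rule trying the less certain arm becomes the better choice once retries are available.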
We turn this intuition into an objective for RL, ReMax, where we assume
- bandit/: Code for illustrative bandit experiments.
- agents/: RL code for MinAtar, Atari, and Craftax.
Notably, all RL code is implemented as single-file JAX, so it is fast and easy to understand and modify.
Our method, ReMax PPO (RePPO), is implemented in reppo.py in each environment directory.
Please make sure a GPU-compatible build of JAX is installed in your environment.
uv sync
For Atari, for compatibility with envpool, we recommend building the Docker image from agents/atari/Dockerfile.
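After installation, a quick check along the following lines can show which backend JAX picked up. This is a minimal sketch rather than part of the repository; the helper name `describe_jax_backend` is hypothetical.

```python
import importlib.util


def describe_jax_backend() -> str:
    """Describe the devices JAX can see, or report that JAX is missing."""
    # Avoid an ImportError so the check also runs in environments without JAX.
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax

    platforms = [device.platform for device in jax.devices()]
    return f"backend={jax.default_backend()} devices={platforms}"


print(describe_jax_backend())
```

If the reported backend is `cpu`, a CPU-only JAX build was most likely installed and the experiments will not use the GPU.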
In bandit/, we implement the bandit experiments in the paper.
python plot_binary_bandit.py # Binary bandit plot (Figure 1 (left))
python plot_scaled_bernoulli_bandit.py # Bernoulli bandit plot (Figure 1 (center))
python plot_fixed_binary_bandit.py # Fixed binary bandit plot (Figure 1 (right))
python plot_bandit_with_posterior.py --family beta # for Beta-Bernoulli regret plot (Figure 2 (left))
python plot_bandit_with_posterior.py --family gaussian # for Gaussian-Gaussian regret plot (Figure 2 (right))
In agents/, we implement the algorithms used in the paper.
- minatar/: MinAtar experiments, using the pgx implementation.
- atari/: Atari experiments (based on purejaxql).
- craftax/: Craftax experiments.
In sh/, run:
./run_minatar.sh # for MinAtar
./run_atari.sh # for Atari
./run_craftax.sh # for Craftax