

@jinPrelude

Issue #79

The environment created by RolloutWrapper doesn't respect the num_env_steps argument passed to RolloutWrapper:

Code to reproduce

```python
from gymnax.experimental import RolloutWrapper
import jax

ENV_NUM = 3
manager = RolloutWrapper(None, env_name='CartPole-v1', num_env_steps=100)

rng, rollout_rng = jax.random.split(jax.random.key(0))
rollout_rng = jax.random.split(rollout_rng, ENV_NUM)
obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout(rollout_rng, None)
print(done.shape)  # expected (3, 100), but prints (3, 500)
```

Why did this bug happen?

  1. RolloutWrapper.single_rollout() passes self.env_params.max_steps_in_episode as the length of jax.lax.scan() instead of self.num_env_steps (gymnax/experimental/rollout.py#L94), so every rollout runs for the environment's full episode limit regardless of num_env_steps.
  2. The test (gymnax/tests/wrappers/test_evaluator.py) sets num_env_steps to the same value as the environment's max_steps_in_episode, so it cannot tell the two apart and the bug went undetected.
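The shape mismatch follows directly from how jax.lax.scan works: its length argument fixes the number of iterations, and therefore the leading axis of every stacked output. The minimal pure-JAX sketch below uses a hypothetical toy_rollout as a stand-in for single_rollout (a random walk, not gymnax's actual environment) to show that scanning for max_steps_in_episode yields 500 steps even when 100 were requested.

```python
import jax
import jax.numpy as jnp


def toy_rollout(key, length):
    """Hypothetical stand-in for single_rollout: scan a random-walk 'step'
    for `length` iterations and stack a per-step 'done' flag."""

    def step(carry, _):
        key, pos = carry
        key, sub = jax.random.split(key)
        pos = pos + jax.random.normal(sub)
        done = jnp.abs(pos) > 2.0
        return (key, pos), done

    # `length` alone sets how many steps run and the size of the stacked output.
    _, dones = jax.lax.scan(step, (key, jnp.float32(0.0)), None, length=length)
    return dones


key = jax.random.key(0)
num_env_steps = 100          # what the caller asked for
max_steps_in_episode = 500   # CartPole-v1 episode limit

buggy = toy_rollout(key, max_steps_in_episode)  # length taken from env params
fixed = toy_rollout(key, num_env_steps)         # length taken from the wrapper argument
print(buggy.shape, fixed.shape)  # (500,) (100,)

# Batched over 3 keys, as batch_rollout does via vmap.
keys = jax.random.split(jax.random.key(1), 3)
batched = jax.vmap(lambda k: toy_rollout(k, num_env_steps))(keys)
print(batched.shape)  # (3, 100)
```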

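What does this PR fix?

  1. Replace self.env_params.max_steps_in_episode with self.num_env_steps as the length argument of jax.lax.scan() in RolloutWrapper.single_rollout().
  2. Change the value used in the test (200 -> 150) so that num_env_steps differs from max_steps_in_episode and the test actually exercises this feature.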
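A minimal sketch of the corrected step-count logic together with a regression check in the spirit of the updated test. MiniRolloutWrapper is hypothetical: it is not gymnax's class, it models only how many steps are scanned, and its "environment" just emits random rewards. The 200-step limit and 150-step override mirror the values mentioned above; the fallback to max_steps_in_episode when num_env_steps is None is an assumption of this sketch.

```python
from typing import Optional

import jax
import jax.numpy as jnp


class MiniRolloutWrapper:
    """Hypothetical, stripped-down stand-in for a rollout wrapper.
    Only the step-count logic is modelled; the 'environment' is a toy counter."""

    def __init__(self, max_steps_in_episode: int, num_env_steps: Optional[int] = None):
        self.max_steps_in_episode = max_steps_in_episode
        # Assumed fallback: use the episode limit only when no explicit value is given.
        self.num_env_steps = max_steps_in_episode if num_env_steps is None else num_env_steps

    def single_rollout(self, key):
        def step(carry, _):
            key, t = carry
            key, sub = jax.random.split(key)
            reward = jax.random.uniform(sub)
            done = t + 1 >= self.max_steps_in_episode
            return (key, t + 1), (reward, done)

        # The fix: scan for num_env_steps, not max_steps_in_episode.
        _, (rewards, dones) = jax.lax.scan(
            step, (key, jnp.int32(0)), None, length=self.num_env_steps
        )
        return rewards, dones

    def batch_rollout(self, keys):
        # One rollout per key along the leading axis of `keys`.
        return jax.vmap(self.single_rollout)(keys)


keys = jax.random.split(jax.random.key(0), 3)
# Regression check: 150 differs from the 200-step limit, so a scan that
# wrongly used max_steps_in_episode would give shape (3, 200) and fail.
_, dones = MiniRolloutWrapper(max_steps_in_episode=200, num_env_steps=150).batch_rollout(keys)
assert dones.shape == (3, 150)
# Without num_env_steps, the rollout covers the full episode limit.
_, dones_default = MiniRolloutWrapper(max_steps_in_episode=200).batch_rollout(keys)
assert dones_default.shape == (3, 200)
```

Choosing a step count that differs from the episode limit is the key point of the test change: when the two values are equal, the buggy and the fixed scan lengths produce identical shapes, so the assertion passes either way.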

@jinPrelude jinPrelude changed the title Fix num_env_step not working in RolloutWrapper issue Fix num_env_step not working in RolloutWrapper issue (#79) Jun 23, 2024
@jinPrelude jinPrelude changed the title Fix num_env_step not working in RolloutWrapper issue (#79) Fix num_env_step not working in RolloutWrapper issue Jun 23, 2024