
DQN wrapped Agents always use uniform replay buffer #230

@Davidb8


In the initialization code of the DQNAgent, line 386 has:

    self._replay_scheme = 'uniform'

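This assignment runs after the RainbowAgent has already set its own replay_scheme, because the Rainbow constructor assigns self._replay_scheme before calling the parent constructor: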

    vmax = float(vmax)
    self._num_atoms = num_atoms
    # If vmin is not specified, set it to -vmax similar to C51.
    vmin = vmin if vmin else -vmax
    self._support = jnp.linspace(vmin, vmax, num_atoms)
    self._replay_scheme = replay_scheme

    super(JaxRainbowAgent, self).__init__(
        num_actions=num_actions,
        observation_shape=observation_shape,
        observation_dtype=observation_dtype,
        stack_size=stack_size,
        network=functools.partial(network, num_atoms=num_atoms),
        gamma=gamma,
        update_horizon=update_horizon,
        min_replay_history=min_replay_history,
        update_period=update_period,
        target_update_period=target_update_period,
        epsilon_fn=epsilon_fn,
        epsilon_train=epsilon_train,
        epsilon_eval=epsilon_eval,
        epsilon_decay_period=epsilon_decay_period,
        optimizer=optimizer,
        seed=seed,
        summary_writer=summary_writer,
        summary_writing_frequency=summary_writing_frequency,
        allow_partial_reload=allow_partial_reload,
    )

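Thus, even when prioritized replay is configured through gin (replay_scheme = 'prioritized'), the DQNAgent constructor overwrites the value with 'uniform', so the agent always ends up using a uniform replay buffer.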
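The ordering problem can be reproduced without Dopamine or gin. The minimal sketch below uses two hypothetical stand-in classes (DQNLikeAgent and RainbowLikeAgent) that mimic only the attribute assignments described above; they are not the real agent classes.

```python
class DQNLikeAgent:
    """Hypothetical stand-in for the DQN base agent's constructor."""

    def __init__(self, num_actions):
        self.num_actions = num_actions
        # Unconditional assignment, mirroring the line quoted from the DQN agent.
        self._replay_scheme = 'uniform'


class RainbowLikeAgent(DQNLikeAgent):
    """Hypothetical stand-in for the Rainbow subclass."""

    def __init__(self, num_actions, replay_scheme='prioritized'):
        # Set before calling the parent constructor, as in the Rainbow excerpt.
        self._replay_scheme = replay_scheme
        super().__init__(num_actions=num_actions)


agent = RainbowLikeAgent(num_actions=4, replay_scheme='prioritized')
print(agent._replay_scheme)  # prints "uniform": the parent overwrote the subclass value
```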

This can be fixed by changing that line in the DQNAgent code from

    self._replay_scheme = 'uniform'

to

    if not hasattr(self, '_replay_scheme'):
      self._replay_scheme = 'uniform'
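A minimal sketch of the proposed guard, again using hypothetical stand-in classes rather than the real Dopamine agents, shows that plain DQN still defaults to 'uniform' while a subclass that sets the scheme first keeps its own value:

```python
class DQNLikeAgent:
    """Hypothetical stand-in for the DQN base agent, with the proposed guard."""

    def __init__(self, num_actions):
        self.num_actions = num_actions
        # Only fall back to 'uniform' if a subclass has not already chosen a scheme.
        if not hasattr(self, '_replay_scheme'):
            self._replay_scheme = 'uniform'


class RainbowLikeAgent(DQNLikeAgent):
    """Hypothetical stand-in for the Rainbow subclass."""

    def __init__(self, num_actions, replay_scheme='prioritized'):
        self._replay_scheme = replay_scheme
        super().__init__(num_actions=num_actions)


plain = DQNLikeAgent(num_actions=4)
rainbow = RainbowLikeAgent(num_actions=4, replay_scheme='prioritized')
print(plain._replay_scheme)    # uniform
print(rainbow._replay_scheme)  # prioritized
```

An alternative would be for the subclass to re-assign self._replay_scheme after calling the parent constructor, but any logic inside the parent constructor that reads the attribute would then still see 'uniform', which the guard above avoids.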

If this is indeed a bug rather than a mistake on my side, I can open a PR that applies this fix.
