@@ -85,6 +85,7 @@ def main(_: Any) -> None:
exploration_scheduler_fn=LinearExplorationScheduler,
epsilon_min=0.05,
epsilon_decay=5e-4,
importance_sampling_exponent=0.2,
optimizer=snt.optimizers.Adam(learning_rate=1e-4),
checkpoint_subpath=checkpoint_dir,
).build()
1 change: 1 addition & 0 deletions examples/flatland/feedforward/decentralised/run_madqn.py
@@ -96,6 +96,7 @@ def main(_: Any) -> None:
exploration_scheduler_fn=LinearExplorationScheduler,
epsilon_min=0.05,
epsilon_decay=1e-4,
importance_sampling_exponent=0.2,
optimizer=snt.optimizers.Adam(learning_rate=1e-4),
checkpoint_subpath=checkpoint_dir,
).build()
@@ -87,6 +87,7 @@ def main(_: Any) -> None:
exploration_scheduler_fn=LinearExplorationScheduler,
epsilon_min=0.05,
epsilon_decay=1e-4,
importance_sampling_exponent=0.2,
optimizer=snt.optimizers.Adam(learning_rate=1e-4),
checkpoint_subpath=checkpoint_dir,
).build()
1 change: 1 addition & 0 deletions examples/smac/feedforward/decentralised/run_madqn.py
@@ -80,6 +80,7 @@ def main(_: Any) -> None:
exploration_scheduler_fn=LinearExplorationScheduler,
epsilon_min=0.05,
epsilon_decay=1e-5,
importance_sampling_exponent=0.2,
optimizer=snt.optimizers.SGD(learning_rate=1e-2),
checkpoint_subpath=checkpoint_dir,
batch_size=512,
17 changes: 14 additions & 3 deletions mava/systems/tf/madqn/builder.py
@@ -34,7 +34,7 @@
from mava.components.tf.modules.stabilising import FingerPrintStabalisation
from mava.systems.tf import executors
from mava.systems.tf.madqn import execution, training
from mava.wrappers import DetailedTrainerStatisticsWithEpsilon
from mava.wrappers import MADQNDetailedTrainerStatistics


@dataclasses.dataclass
@@ -81,6 +81,8 @@ class MADQNConfig:
prefetch_size: int
batch_size: int
n_step: int
max_priority_weight: float
importance_sampling_exponent: Optional[float]
sequence_length: int
period: int
discount: float
@@ -181,9 +183,17 @@ def make_replay_tables(
error_buffer=error_buffer,
)

# Maybe use prioritized sampling.
if self._config.importance_sampling_exponent is not None:
sampler = reverb.selectors.Prioritized(
self._config.importance_sampling_exponent
)
else:
sampler = reverb.selectors.Uniform()

replay_table = reverb.Table(
name=self._config.replay_table_name,
sampler=reverb.selectors.Uniform(),
sampler=sampler,
remover=reverb.selectors.Fifo(),
max_size=self._config.max_replay_size,
rate_limiter=limiter,
@@ -335,6 +345,7 @@ def make_trainer(
counter: Optional[counting.Counter] = None,
logger: Optional[types.NestedLogger] = None,
communication_module: Optional[BaseCommunicationModule] = None,
replay_client: Optional[reverb.TFClient] = None,
) -> core.Trainer:
"""Create a trainer instance.

@@ -390,6 +401,6 @@
checkpoint_subpath=self._config.checkpoint_subpath,
)

trainer = DetailedTrainerStatisticsWithEpsilon(trainer) # type:ignore
trainer = MADQNDetailedTrainerStatistics(trainer) # type:ignore

return trainer
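
For context, here is a standalone sketch (not part of this diff) of how the sampler switch in make_replay_tables maps onto the Reverb API. The table name, size and rate limiter below are illustrative assumptions, not the builder's actual values.

from typing import Optional

import reverb


def make_table(
    importance_sampling_exponent: Optional[float] = None,
    max_replay_size: int = 100_000,
) -> reverb.Table:
    # Prioritized sampling when an exponent is supplied, otherwise uniform.
    if importance_sampling_exponent is not None:
        sampler = reverb.selectors.Prioritized(importance_sampling_exponent)
    else:
        sampler = reverb.selectors.Uniform()

    return reverb.Table(
        name="trainer",  # illustrative table name
        sampler=sampler,
        remover=reverb.selectors.Fifo(),
        max_size=max_replay_size,
        rate_limiter=reverb.rate_limiters.MinSize(1),
    )
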
11 changes: 11 additions & 0 deletions mava/systems/tf/madqn/system.py
@@ -75,6 +75,8 @@ def __init__(
samples_per_insert: Optional[float] = 32.0,
n_step: int = 5,
sequence_length: int = 20,
importance_sampling_exponent: Optional[float] = None,
max_priority_weight: float = 0.9,
period: int = 20,
max_gradient_norm: float = None,
discount: float = 0.99,
@@ -144,6 +146,11 @@ def __init__(
Defaults to 5.
sequence_length (int, optional): recurrent sequence rollout length. Defaults
to 20.
importance_sampling_exponent (float, optional): exponent used to correct for
sampling bias from prioritized replay (usually around 0.2). If None,
uniform sampling is used and no importance correction is applied.
max_priority_weight (float, optional): only used when
importance_sampling_exponent is not None. Interpolates between the maximum
and mean absolute TD-error when computing new priorities for reverb
samples. Defaults to 0.9.
period (int, optional): consecutive starting points for overlapping
rollouts across a sequence. Defaults to 20.
max_gradient_norm (float, optional): maximum allowed norm for gradients
@@ -205,6 +212,7 @@ def __init__(
}
self._num_exectors = num_executors
self._num_caches = num_caches

self._max_executor_steps = max_executor_steps
self._checkpoint_subpath = checkpoint_subpath
self._checkpoint = checkpoint
@@ -235,6 +243,8 @@ def __init__(
samples_per_insert=samples_per_insert,
n_step=n_step,
sequence_length=sequence_length,
importance_sampling_exponent=importance_sampling_exponent,
max_priority_weight=max_priority_weight,
period=period,
max_gradient_norm=max_gradient_norm,
checkpoint=checkpoint,
@@ -383,6 +393,7 @@ def trainer(
return self._builder.make_trainer(
networks=system_networks,
dataset=dataset,
replay_client=replay,
counter=counter,
communication_module=communication_module,
logger=trainer_logger,
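
For reference, the example scripts at the top of this diff enable the new behaviour with a single extra argument. A minimal sketch follows, assuming the system class configured here is MADQN; environment_factory and network_factory are hypothetical placeholders defined elsewhere, and are not shown in this diff.

import sonnet as snt
from mava.systems.tf import madqn
from mava.components.tf.modules.exploration.exploration_scheduling import (
    LinearExplorationScheduler,
)

program = madqn.MADQN(
    environment_factory=environment_factory,  # hypothetical placeholder
    network_factory=network_factory,  # hypothetical placeholder
    exploration_scheduler_fn=LinearExplorationScheduler,
    epsilon_min=0.05,
    epsilon_decay=1e-4,
    # New in this change: prioritized replay with importance sampling.
    importance_sampling_exponent=0.2,  # None (the default) keeps uniform sampling
    max_priority_weight=0.9,  # interpolates max vs. mean |TD-error| priorities
    optimizer=snt.optimizers.Adam(learning_rate=1e-4),
    checkpoint_subpath="~/mava/",
).build()
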
112 changes: 102 additions & 10 deletions mava/systems/tf/madqn/training.py
@@ -30,6 +30,7 @@
from acme.utils import counting, loggers

import mava
from mava.adders import reverb as reverb_adders
from mava import types as mava_types
from mava.components.tf.modules.communication import BaseCommunicationModule
from mava.components.tf.modules.exploration.exploration_scheduling import (
@@ -60,11 +61,15 @@ def __init__(
agent_net_keys: Dict[str, str],
exploration_scheduler: LinearExplorationScheduler,
max_gradient_norm: float = None,
importance_sampling_exponent: Optional[float] = None,
replay_client: Optional[reverb.TFClient] = None,
max_priority_weight: float = 0.9,
fingerprint: bool = False,
counter: counting.Counter = None,
logger: loggers.Logger = None,
checkpoint: bool = True,
checkpoint_subpath: str = "~/mava/",
replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE,
communication_module: Optional[BaseCommunicationModule] = None,
):
"""Initialise MADQN trainer
@@ -131,6 +136,21 @@ def __init__(
# Store the exploration scheduler
self._exploration_scheduler = exploration_scheduler

# Importance sampling hyper-parameters
self._max_priority_weight = max_priority_weight
self._importance_sampling_exponent = importance_sampling_exponent

# Replay client for updating priorities.
self._replay_client = replay_client
self._replay_table_name = replay_table_name

# NOTE We make replay_client optional to make changes to MADQN trainer
# compatible with the other systems that inherit from it (VDN, QMIX etc.)
# TODO Include importance sampling in the other systems so that we can remove
# this check.
if self._importance_sampling_exponent is not None:
assert isinstance(self._replay_client, reverb.Client)

# Dictionary with network keys for each agent.
self.unique_net_keys = set(self._agent_net_keys.values())

@@ -221,6 +241,23 @@ def _update_target_networks(self) -> None:
dest.assign(src)
self._num_steps.assign_add(1)

def _update_sample_priorities(self, keys: tf.Tensor, priorities: tf.Tensor) -> None:
"""Update sample priorities in replay table using importance weights.

Args:
keys (tf.Tensor): Keys of the replay samples.
priorities (tf.Tensor): New priorities for replay samples.
"""
# Maybe update the sample priorities in the replay buffer.
if (
self._importance_sampling_exponent is not None
and self._replay_client is not None
):
self._replay_client.mutate_priorities(
table=self._replay_table_name,
updates=dict(zip(keys.numpy(), priorities.numpy())),
)

def _get_feed(
self,
o_tm1_trans: Dict[str, np.ndarray],
@@ -280,12 +317,7 @@ def step(self) -> None:
self._logger.write(fetches)

@tf.function
def _step(self) -> Dict[str, Dict[str, Any]]:
"""Trainer forward and backward passes."""

# Update the target networks
self._update_target_networks()

def _forward_backward(self) -> Tuple:
# Get data from replay (dropping extras if any). Note there is no
# extra data here because we do not insert any into Reverb.
inputs = next(self._iterator)
@@ -294,8 +326,34 @@ def _step(self) -> Dict[str, Dict[str, Any]]:

self._backward()

# Log losses per agent
return self._q_network_losses
extras = {}

if self._importance_sampling_exponent is not None:
extras.update(
{"keys": self._sample_keys, "priorities": self._sample_priorities}
)

# Return Q-value losses.
fetches = self._q_network_losses

return fetches, extras

def _step(self) -> Dict:
"""Trainer forward and backward passes."""

# Update the target networks
self._update_target_networks()

fetches, extras = self._forward_backward()

# Maybe update priorities.
# NOTE _update_sample_priorities must happen outside of the
# tf.function. That is why _forward_backward() is separated out.
if self._importance_sampling_exponent is not None:
self._update_sample_priorities(extras["keys"], extras["priorities"])

# Log losses and epsilon
return fetches

def _forward(self, inputs: reverb.ReplaySample) -> None:
"""Trainer forward pass
@@ -304,6 +362,14 @@ def _forward(self, inputs: reverb.ReplaySample) -> None:
inputs (Any): input data from the data table (transitions)
"""

# Get info about the samples from reverb.
sample_info = inputs.info
sample_keys = tf.transpose(inputs.info.key)
sample_probs = tf.transpose(sample_info.probability)

# Initialize sample priorities at zero.
sample_priorities = np.zeros(len(inputs.info.key))

# Unpack input data as follows:
# o_tm1 = dictionary of observations one for each agent
# a_tm1 = dictionary of actions taken from obs in o_tm1
@@ -361,7 +427,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None:
q_t_selector = self._q_networks[agent_key](o_t_feed)

# Q-network learning
loss, _ = trfl.double_qlearning(
loss, loss_extras = trfl.double_qlearning(
q_tm1,
a_tm1_feed,
r_t[agent],
@@ -370,12 +436,38 @@ def _forward(self, inputs: reverb.ReplaySample) -> None:
q_t_selector,
)

loss = tf.reduce_mean(loss)
# Maybe do importance sampling.
if self._importance_sampling_exponent is not None:
importance_weights = 1.0 / sample_probs # [B]
importance_weights **= self._importance_sampling_exponent
importance_weights /= tf.reduce_max(importance_weights)

# Reweight loss.
loss *= tf.cast(importance_weights, loss.dtype) # [B]

# Update priorities.
errors = loss_extras.td_error
abs_errors = tf.abs(errors)
mean_priority = tf.reduce_mean(abs_errors, axis=0)
max_priority = tf.reduce_max(abs_errors, axis=0)
sample_priorities += (
self._max_priority_weight * max_priority
+ (1 - self._max_priority_weight) * mean_priority
)

loss = tf.reduce_mean(loss)
q_network_losses[agent] = {"q_value_loss": loss}

# Store losses and tape
self._q_network_losses = q_network_losses
self.tape = tape

# Store sample keys and priorities
self._sample_keys = sample_keys
self._sample_priorities = sample_priorities / len(
self._agents
) # averaged over agents.

def _backward(self) -> None:
"""Trainer backward pass updating network parameters"""

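
To make the reweighting and priority arithmetic in _forward concrete, here is a small numpy-only sketch. The helper names and example numbers are just for illustration; in the trainer the same computation runs on TF tensors before _update_sample_priorities writes the result back via mutate_priorities.

import numpy as np


def importance_weights(sample_probs, exponent):
    """w_i = (1 / p_i) ** exponent, normalised by the batch maximum."""
    weights = (1.0 / sample_probs) ** exponent
    return weights / weights.max()


def new_priorities(td_errors_per_agent, max_priority_weight):
    """Per agent, interpolate between the max and mean |TD-error| over the batch,
    then average over agents (mirroring _forward above)."""
    batch_size = td_errors_per_agent[0].shape[0]
    priorities = np.zeros(batch_size)
    for td_errors in td_errors_per_agent:
        abs_errors = np.abs(td_errors)
        priorities += (
            max_priority_weight * abs_errors.max()
            + (1.0 - max_priority_weight) * abs_errors.mean()
        )
    return priorities / len(td_errors_per_agent)


# Example: a batch of 4 transitions and 2 agents.
probs = np.array([0.4, 0.3, 0.2, 0.1])
weights = importance_weights(probs, exponent=0.2)  # used to reweight the loss
td = [np.array([0.5, 1.0, 0.1, 0.2]), np.array([0.3, 0.6, 0.9, 0.05])]
priorities = new_priorities(td, max_priority_weight=0.9)  # sent to mutate_priorities
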
7 changes: 5 additions & 2 deletions mava/systems/tf/qmix/builder.py
@@ -31,7 +31,7 @@
from mava.components.tf.modules.stabilising import FingerPrintStabalisation
from mava.systems.tf.madqn.builder import MADQNBuilder, MADQNConfig
from mava.systems.tf.qmix import execution, training
from mava.wrappers import DetailedTrainerStatisticsWithEpsilon
from mava.wrappers import MADQNDetailedTrainerStatistics


@dataclasses.dataclass
@@ -119,6 +119,7 @@ def make_trainer(
counter: Optional[counting.Counter] = None,
logger: Optional[types.NestedLogger] = None,
communication_module: Optional[BaseCommunicationModule] = None,
replay_client: Optional[reverb.TFClient] = None,
) -> core.Trainer:
"""Create a trainer instance.

@@ -132,6 +133,8 @@ def make_trainer(
metadata. Defaults to None.
communication_module (BaseCommunicationModule): module to enable
agent communication. Defaults to None.
replay_client (reverb.TFClient): used for updating sample priorities
when importance sampling is enabled. Not yet implemented for QMIX.

Returns:
core.Trainer: system trainer, that uses the collected data from the
@@ -179,6 +182,6 @@ def make_trainer(
checkpoint_subpath=self._config.checkpoint_subpath,
)

trainer = DetailedTrainerStatisticsWithEpsilon(trainer) # type:ignore
trainer = MADQNDetailedTrainerStatistics(trainer) # type:ignore

return trainer
4 changes: 4 additions & 0 deletions mava/systems/tf/qmix/system.py
@@ -71,6 +71,8 @@ def __init__(
samples_per_insert: Optional[float] = 32.0,
n_step: int = 5,
sequence_length: int = 20,
importance_sampling_exponent: Optional[float] = None,
max_priority_weight: float = 0.9,
period: int = 20,
max_gradient_norm: float = None,
discount: float = 0.99,
@@ -258,6 +260,8 @@ def __init__(
samples_per_insert=samples_per_insert,
n_step=n_step,
sequence_length=sequence_length,
importance_sampling_exponent=importance_sampling_exponent,
max_priority_weight=max_priority_weight,
period=period,
max_gradient_norm=max_gradient_norm,
checkpoint=checkpoint,