From dde78219dc66f0598403b7dae6414316921c876e Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 19 Jan 2022 16:30:37 +0200 Subject: [PATCH 01/56] First commit. --- .../recurrent/decentralised/run_madqn.py | 1 + .../feedforward/decentralised/run_madqn.py | 9 +- .../feedforward/decentralised/run_qmix.py | 106 -- .../smac/feedforward/decentralised/run_vdn.py | 104 -- .../decentralised/run_vdn_record.py | 107 -- .../smac/feedforward/decentralised/test.py | 19 + .../tf/architectures/decentralised.py | 23 +- mava/systems/tf/madqn/__init__.py | 11 +- mava/systems/tf/madqn/builder.py | 564 ++++++--- mava/systems/tf/madqn/execution.py | 606 ++++----- mava/systems/tf/madqn/networks.py | 165 ++- mava/systems/tf/madqn/system.py | 720 ++++++----- mava/systems/tf/madqn/training.py | 1101 ++++++++--------- mava/utils/training_utils.py | 11 + mava/wrappers/__init__.py | 2 + mava/wrappers/smac.py | 331 +++++ 16 files changed, 1967 insertions(+), 1913 deletions(-) delete mode 100644 examples/smac/feedforward/decentralised/run_qmix.py delete mode 100644 examples/smac/feedforward/decentralised/run_vdn.py delete mode 100644 examples/smac/feedforward/decentralised/run_vdn_record.py create mode 100644 examples/smac/feedforward/decentralised/test.py create mode 100644 mava/wrappers/smac.py diff --git a/examples/debugging/simple_spread/recurrent/decentralised/run_madqn.py b/examples/debugging/simple_spread/recurrent/decentralised/run_madqn.py index 55e19c5ed..1a9381997 100644 --- a/examples/debugging/simple_spread/recurrent/decentralised/run_madqn.py +++ b/examples/debugging/simple_spread/recurrent/decentralised/run_madqn.py @@ -91,6 +91,7 @@ def main(_: Any) -> None: checkpoint_subpath=checkpoint_dir, trainer_fn=madqn.training.MADQNRecurrentTrainer, executor_fn=madqn.execution.MADQNRecurrentExecutor, + max_replay_size=5000, batch_size=32, ).build() diff --git a/examples/smac/feedforward/decentralised/run_madqn.py b/examples/smac/feedforward/decentralised/run_madqn.py index 317a1298d..5f869514d 100644 --- a/examples/smac/feedforward/decentralised/run_madqn.py +++ b/examples/smac/feedforward/decentralised/run_madqn.py @@ -54,7 +54,7 @@ def main(_: Any) -> None: # Networks. network_factory = lp_utils.partial_kwargs( - madqn.make_default_networks, policy_networks_layer_sizes=[64, 64] + madqn.make_default_networks, value_networks_layer_sizes=[64, 64] ) # Checkpointer appends "Checkpoints" to checkpoint_dir @@ -80,15 +80,14 @@ def main(_: Any) -> None: exploration_scheduler_fn=LinearExplorationTimestepScheduler( epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=50000 ), - importance_sampling_exponent=0.2, optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 ), checkpoint_subpath=checkpoint_dir, - batch_size=512, - executor_variable_update_period=100, + batch_size=256, + executor_variable_update_period=1000, target_update_period=200, - max_gradient_norm=10.0, + max_gradient_norm=20.0, ).build() # launch diff --git a/examples/smac/feedforward/decentralised/run_qmix.py b/examples/smac/feedforward/decentralised/run_qmix.py deleted file mode 100644 index 2ba3d2b0c..000000000 --- a/examples/smac/feedforward/decentralised/run_qmix.py +++ /dev/null @@ -1,106 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import functools -from datetime import datetime -from typing import Any - -import launchpad as lp -import sonnet as snt -from absl import app, flags - -from mava.components.tf.modules.exploration import LinearExplorationTimestepScheduler -from mava.systems.tf import qmix -from mava.utils import lp_utils -from mava.utils.environments import pettingzoo_utils -from mava.utils.loggers import logger_utils - -FLAGS = flags.FLAGS -flags.DEFINE_string( - "map_name", - "3m", - "Starcraft 2 micromanagement map name (str).", -) - -flags.DEFINE_string( - "mava_id", - str(datetime.now()), - "Experiment identifier that can be used to continue experiments.", -) -flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") - - -def main(_: Any) -> None: - """Example running QMIX on SMAC environments.""" - # Environment. - environment_factory = functools.partial( - pettingzoo_utils.make_environment, env_class="smac", env_name=FLAGS.map_name - ) - - # Networks. - network_factory = lp_utils.partial_kwargs(qmix.make_default_networks) - - # Checkpointer appends "Checkpoints" to checkpoint_dir. - checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" - - # Log every [log_every] seconds. - log_every = 10 - logger_factory = functools.partial( - logger_utils.make_logger, - directory=FLAGS.base_dir, - to_terminal=True, - to_tensorboard=True, - time_stamp=FLAGS.mava_id, - time_delta=log_every, - ) - - # distributed program - program = qmix.QMIX( - environment_factory=environment_factory, - network_factory=network_factory, - logger_factory=logger_factory, - num_executors=1, - exploration_scheduler_fn=LinearExplorationTimestepScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=50000 - ), - max_replay_size=1000000, - optimizer=snt.optimizers.RMSProp( - learning_rate=0.0005, epsilon=0.00001, decay=0.99 - ), - checkpoint_subpath=checkpoint_dir, - batch_size=512, - qmix_hidden_dim=32, - num_hypernet_layers=1, - hypernet_hidden_dim=32, - executor_variable_update_period=100, - target_update_period=200, - max_gradient_norm=10.0, - ).build() - - # launch - local_resources = lp_utils.to_device( - program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] - ) - lp.launch( - program, - lp.LaunchType.LOCAL_MULTI_PROCESSING, - terminal="current_terminal", - local_resources=local_resources, - ) - - -if __name__ == "__main__": - app.run(main) diff --git a/examples/smac/feedforward/decentralised/run_vdn.py b/examples/smac/feedforward/decentralised/run_vdn.py deleted file mode 100644 index 3b087e7e8..000000000 --- a/examples/smac/feedforward/decentralised/run_vdn.py +++ /dev/null @@ -1,104 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import functools -from datetime import datetime -from typing import Any - -import launchpad as lp -import sonnet as snt -from absl import app, flags - -from mava.components.tf.modules.exploration import LinearExplorationTimestepScheduler -from mava.systems.tf import vdn -from mava.utils import lp_utils -from mava.utils.environments import pettingzoo_utils -from mava.utils.loggers import logger_utils - -FLAGS = flags.FLAGS -flags.DEFINE_string( - "map_name", - "3m", - "Starcraft 2 micromanagement map name (str).", -) - -flags.DEFINE_string( - "mava_id", - str(datetime.now()), - "Experiment identifier that can be used to continue experiments.", -) -flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") - - -def main(_: Any) -> None: - """Example running VDN on multi-agent Starcraft 2 (SMAC) environment.""" - # environment - environment_factory = functools.partial( - pettingzoo_utils.make_environment, env_class="smac", env_name=FLAGS.map_name - ) - - # Networks. - network_factory = lp_utils.partial_kwargs( - vdn.make_default_networks, policy_networks_layer_sizes=[64, 64] - ) - - # Checkpointer appends "Checkpoints" to checkpoint_dir - checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" - - # Log every [log_every] seconds. - log_every = 10 - logger_factory = functools.partial( - logger_utils.make_logger, - directory=FLAGS.base_dir, - to_terminal=True, - to_tensorboard=True, - time_stamp=FLAGS.mava_id, - time_delta=log_every, - ) - - # distributed program - program = vdn.VDN( - environment_factory=environment_factory, - network_factory=network_factory, - logger_factory=logger_factory, - num_executors=1, - exploration_scheduler_fn=LinearExplorationTimestepScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=50000 - ), - optimizer=snt.optimizers.RMSProp( - learning_rate=0.0005, epsilon=0.00001, decay=0.99 - ), - checkpoint_subpath=checkpoint_dir, - batch_size=512, - executor_variable_update_period=100, - target_update_period=200, - max_gradient_norm=10.0, - ).build() - - # launch - local_resources = lp_utils.to_device( - program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] - ) - lp.launch( - program, - lp.LaunchType.LOCAL_MULTI_PROCESSING, - terminal="current_terminal", - local_resources=local_resources, - ) - - -if __name__ == "__main__": - app.run(main) diff --git a/examples/smac/feedforward/decentralised/run_vdn_record.py b/examples/smac/feedforward/decentralised/run_vdn_record.py deleted file mode 100644 index 90a16507e..000000000 --- a/examples/smac/feedforward/decentralised/run_vdn_record.py +++ /dev/null @@ -1,107 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -import functools -from datetime import datetime -from typing import Any - -import launchpad as lp -import sonnet as snt -from absl import app, flags - -from mava.components.tf.modules.exploration import LinearExplorationTimestepScheduler -from mava.systems.tf import vdn -from mava.utils import lp_utils -from mava.utils.environments import pettingzoo_utils -from mava.utils.loggers import logger_utils -from mava.wrappers.environment_loop_wrappers import MonitorParallelEnvironmentLoop - -FLAGS = flags.FLAGS -flags.DEFINE_string( - "map_name", - "3m", - "Starcraft 2 micromanagement map name (str).", -) - -flags.DEFINE_string( - "mava_id", - str(datetime.now()), - "Experiment identifier that can be used to continue experiments.", -) -flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") - - -def main(_: Any) -> None: - """Example running VDN on SMAC,while recording agents.""" - # environment - environment_factory = functools.partial( - pettingzoo_utils.make_environment, env_class="smac", env_name=FLAGS.map_name - ) - - # Networks. - network_factory = lp_utils.partial_kwargs( - vdn.make_default_networks, policy_networks_layer_sizes=[64, 64] - ) - - # Checkpointer appends "Checkpoints" to checkpoint_dir - checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" - - # Log every [log_every] seconds. - log_every = 10 - logger_factory = functools.partial( - logger_utils.make_logger, - directory=FLAGS.base_dir, - to_terminal=True, - to_tensorboard=True, - time_stamp=FLAGS.mava_id, - time_delta=log_every, - ) - - # distributed program - program = vdn.VDN( - environment_factory=environment_factory, - network_factory=network_factory, - logger_factory=logger_factory, - num_executors=1, - exploration_scheduler_fn=LinearExplorationTimestepScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=50000 - ), - optimizer=snt.optimizers.RMSProp( - learning_rate=0.0005, epsilon=0.00001, decay=0.99 - ), - checkpoint_subpath=checkpoint_dir, - batch_size=512, - executor_variable_update_period=100, - target_update_period=200, - max_gradient_norm=10.0, - eval_loop_fn=MonitorParallelEnvironmentLoop, - eval_loop_fn_kwargs={"path": checkpoint_dir, "record_every": 100}, - ).build() - - # launch - local_resources = lp_utils.to_device( - program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] - ) - lp.launch( - program, - lp.LaunchType.LOCAL_MULTI_PROCESSING, - terminal="current_terminal", - local_resources=local_resources, - ) - - -if __name__ == "__main__": - app.run(main) diff --git a/examples/smac/feedforward/decentralised/test.py b/examples/smac/feedforward/decentralised/test.py new file mode 100644 index 000000000..909f7a37f --- /dev/null +++ b/examples/smac/feedforward/decentralised/test.py @@ -0,0 +1,19 @@ +from smac.env import StarCraft2Env +from mava.wrappers import SMACWrapper + +import numpy as np + +env = StarCraft2Env(map_name="3m") + +env = SMACWrapper(env) + +spec = env.action_spec() +spec = env.observation_spec() + +res = env.reset() + +actions = {"agent_0": 1, "agent_1": 1, "agent_2": 1} + +res = env.step(actions) + +print("Done") \ No newline at end of file diff --git a/mava/components/tf/architectures/decentralised.py b/mava/components/tf/architectures/decentralised.py index 04bbfa019..020ebd0a8 100644 --- a/mava/components/tf/architectures/decentralised.py +++ b/mava/components/tf/architectures/decentralised.py @@ -38,6 +38,8 @@ def __init__( self, 
environment_spec: mava_specs.MAEnvironmentSpec, value_networks: Dict[str, snt.Module], + action_selectors: Dict[str, snt.Module], + observation_networks: Dict[str, snt.Module], agent_net_keys: Dict[str, str], ): self._env_spec = environment_spec @@ -47,6 +49,8 @@ def __init__( self._agent_type_specs = self._env_spec.get_agent_type_specs() self._value_networks = value_networks + self._action_selectors = action_selectors + self._observation_networks = observation_networks self._agent_net_keys = agent_net_keys self._n_agents = len(self._agents) @@ -55,6 +59,7 @@ def __init__( def _create_target_networks(self) -> None: # create target behaviour networks self._target_value_networks = copy.deepcopy(self._value_networks) + self._target_observation_networks = copy.deepcopy(self._observation_networks) def _get_actor_specs(self) -> Dict[str, OLT]: actor_obs_specs = {} @@ -70,6 +75,9 @@ def create_actor_variables(self) -> Dict[str, Dict[str, snt.Module]]: actor_networks: Dict[str, Dict[str, snt.Module]] = { "values": {}, "target_values": {}, + "observations": {}, + "target_observations": {}, + "selectors": {} } # get actor specs @@ -79,16 +87,19 @@ def create_actor_variables(self) -> Dict[str, Dict[str, snt.Module]]: for agent_key in self._agents: agent_net_key = self._agent_net_keys[agent_key] obs_spec = actor_obs_specs[agent_key] - # Create variables for value and policy networks. - tf2_utils.create_variables(self._value_networks[agent_net_key], [obs_spec]) + # Create variables for observation and value networks. + embed = tf2_utils.create_variables(self._observation_networks[agent_net_key], [obs_spec]) + tf2_utils.create_variables(self._value_networks[agent_net_key], [embed]) - # create target value network variables - tf2_utils.create_variables( - self._target_value_networks[agent_net_key], [obs_spec] - ) + # Create target value and observation network variables + embed = tf2_utils.create_variables(self._target_observation_networks[agent_net_key], [obs_spec]) + tf2_utils.create_variables(self._target_value_networks[agent_net_key], [embed]) actor_networks["values"] = self._value_networks actor_networks["target_values"] = self._target_value_networks + actor_networks["selectors"] = self._action_selectors + actor_networks["observations"] = self._observation_networks + actor_networks["target_observations"] = self._target_observation_networks return actor_networks diff --git a/mava/systems/tf/madqn/__init__.py b/mava/systems/tf/madqn/__init__.py index 029bfd943..f156e6be4 100644 --- a/mava/systems/tf/madqn/__init__.py +++ b/mava/systems/tf/madqn/__init__.py @@ -13,6 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Implementations of a MADDPG agent.""" + +from mava.systems.tf.madqn.execution import ( + MADQNFeedForwardExecutor, + MADQNRecurrentExecutor, +) from mava.systems.tf.madqn.networks import make_default_networks from mava.systems.tf.madqn.system import MADQN -from mava.systems.tf.madqn.training import MADQNRecurrentTrainer, MADQNTrainer +from mava.systems.tf.madqn.training import ( + MADQNRecurrentTrainer, + MADQNTrainer, +) diff --git a/mava/systems/tf/madqn/builder.py b/mava/systems/tf/madqn/builder.py index cc54ed8ff..71747811b 100644 --- a/mava/systems/tf/madqn/builder.py +++ b/mava/systems/tf/madqn/builder.py @@ -13,209 +13,275 @@ # See the License for the specific language governing permissions and # limitations under the License. 
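# --- Illustrative sketch, not from this patch --------------------------------
# The builder below keeps online and target copies of the value and
# observation networks and exposes `target_update_period`, `target_averaging`
# and `target_update_rate` in its config. A minimal, self-contained sketch of
# what those two target-update modes usually mean (variable and function names
# here are hypothetical, not mava API):

import tensorflow as tf

def update_target_variables(online_vars, target_vars, steps, *,
                            target_averaging=False,
                            target_update_rate=0.01,
                            target_update_period=100):
    """Update target weights by Polyak averaging or by a periodic hard copy."""
    if target_averaging:
        # Soft update every step: target <- (1 - tau) * target + tau * online.
        tau = target_update_rate
        for src, dest in zip(online_vars, target_vars):
            dest.assign((1.0 - tau) * dest + tau * src)
    elif steps % target_update_period == 0:
        # Hard update: copy the online variables every `target_update_period` steps.
        for src, dest in zip(online_vars, target_vars):
            dest.assign(src)

# Example with bare variables standing in for network weights.
online = [tf.Variable([1.0, 2.0])]
target = [tf.Variable([0.0, 0.0])]
update_target_variables(online, target, steps=100)
# ------------------------------------------------------------------------------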
-"""MADQN system builder implementation.""" +"""MADQN scaled system builder implementation.""" +import copy import dataclasses from typing import Any, Dict, Iterator, List, Optional, Type, Union -import numpy as np import reverb import sonnet as snt +import tensorflow as tf from acme import datasets -from acme.tf import variable_utils -from acme.utils import counting +from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers +from dm_env import specs as dm_specs from mava import adders, core, specs, types from mava.adders import reverb as reverb_adders -from mava.components.tf.modules.communication import BaseCommunicationModule +from mava.systems.tf import executors, variable_utils +from mava.systems.tf.madqn import training +from mava.systems.tf.madqn.execution import MADQNFeedForwardExecutor +from mava.systems.tf.variable_sources import VariableSource as MavaVariableSource +from mava.utils.sort_utils import sort_str_num +from mava.wrappers import NetworkStatisticsActorCritic, ScaledDetailedTrainerStatistics +from mava.utils.builder_utils import initialize_epsilon_schedulers from mava.components.tf.modules.exploration.exploration_scheduling import ( BaseExplorationScheduler, BaseExplorationTimestepScheduler, ConstantScheduler, ) -from mava.components.tf.modules.stabilising import FingerPrintStabalisation -from mava.systems.tf import executors -from mava.systems.tf.madqn import execution, training -from mava.utils.builder_utils import initialize_epsilon_schedulers -from mava.wrappers import DetailedTrainerStatistics + +BoundedArray = dm_specs.BoundedArray +DiscreteArray = dm_specs.DiscreteArray @dataclasses.dataclass class MADQNConfig: - """Configuration options for the MADQN system. + """Configuration options for the MADDPG system. Args: environment_spec: description of the action and observation spaces etc. for each agent in the system. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - target_update_period: number of learner steps to perform before updating - the target networks. + policy_optimizer: optimizer(s) for updating policy networks. + critic_optimizer: optimizer for updating critic networks. + num_executors: number of parallel executors to use. + agent_net_keys: specifies what network each agent uses. + trainer_networks: networks each trainer trains on. + table_network_config: Networks each table (trainer) expects. + network_sampling_setup: List of networks that are randomly + sampled from by the executors at the start of an environment run. + net_keys_to_ids: mapping from net_key to network id. + unique_net_keys: list of unique net_keys. + checkpoint_minute_interval: The number of minutes to wait between + checkpoints. + discount: discount to use for TD updates. + batch_size: batch size for updates. + prefetch_size: size to prefetch from replay. + target_averaging: whether to use polyak averaging for target network updates. + target_update_period: number of steps before target networks are updated. + target_update_rate: update rate when using averaging. executor_variable_update_period: the rate at which executors sync their paramters with the trainer. - max_gradient_norm: value to specify the maximum clipping value for the gradient - norm during optimization. min_replay_size: minimum replay size before updating. max_replay_size: maximum replay size. samples_per_insert: number of samples to take from replay for every insert that is made. - prefetch_size: size to prefetch from replay. 
- batch_size: batch size for updates. n_step: number of steps to include prior to boostrapping. sequence_length: recurrent sequence rollout length. period: consecutive starting points for overlapping rollouts across a sequence. - discount: discount to use for TD updates. + max_gradient_norm: value to specify the maximum clipping value for the gradient + norm during optimization. + logger: logger to use. checkpoint: boolean to indicate whether to checkpoint models. - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. - optimizer: type of optimizer to use for updating the parameters of models. - replay_table_name: string indicating what name to give the replay table. checkpoint_subpath: subdirectory specifying where to store checkpoints. - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. + termination_condition: An optional terminal condition can be provided + that stops the program once the condition is satisfied. Available options + include specifying maximum values for trainer_steps, trainer_walltime, + evaluator_steps, evaluator_episodes, executor_episodes or executor_steps. + E.g. termination_condition = {'trainer_steps': 100000}. + learning_rate_scheduler_fn: dict with two functions/classes (one for the + policy and one for the critic optimizer), that takes in a trainer + step t and returns the current learning rate, + e.g. {"policy": policy_lr_schedule ,"critic": critic_lr_schedule}. + See + examples/debugging/simple_spread/feedforward/decentralised/run_maddpg_lr_schedule.py + for an example. + evaluator_interval: An optional condition that is used to + evaluate/test system performance after [evaluator_interval] + condition has been met. """ environment_spec: specs.MAEnvironmentSpec + optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]] + num_executors: int agent_net_keys: Dict[str, str] - target_update_period: int - executor_variable_update_period: int - max_gradient_norm: Optional[float] - min_replay_size: int - max_replay_size: int - samples_per_insert: Optional[float] - prefetch_size: int - batch_size: int - n_step: int - max_priority_weight: float - importance_sampling_exponent: Optional[float] - sequence_length: int - period: int - discount: float - checkpoint: bool + trainer_networks: Dict[str, List] + table_network_config: Dict[str, List] + network_sampling_setup: List + net_keys_to_ids: Dict[str, int] + unique_net_keys: List[str] checkpoint_minute_interval: int - optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]] - replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE + discount: float = 0.99 + batch_size: int = 256 + prefetch_size: int = 4 + target_averaging: bool = False + target_update_period: int = 100 + target_update_rate: Optional[float] = None + executor_variable_update_period: int = 1000 + min_replay_size: int = 1000 + max_replay_size: int = 1000000 + samples_per_insert: Optional[float] = 32.0 + n_step: int = 5 + sequence_length: int = 20 + period: int = 20 + # bootstrap_n: int = 10 + max_gradient_norm: Optional[float] = None + logger: loggers.Logger = None + counter: counting.Counter = None + checkpoint: bool = True checkpoint_subpath: str = "~/mava/" + termination_condition: Optional[Dict[str, int]] = None evaluator_interval: Optional[dict] = None learning_rate_scheduler_fn: Optional[Any] = None class MADQNBuilder: - """Builder for MADQN which constructs individual components of the system.""" + """Builder for scaled MADDPG which constructs 
individual components of the + system.""" def __init__( self, config: MADQNConfig, - trainer_fn: Type[training.MADQNTrainer] = training.MADQNTrainer, - executor_fn: Type[core.Executor] = execution.MADQNFeedForwardExecutor, + trainer_fn: Union[ + Type[training.MADQNTrainer], + Type[training.MADQNRecurrentTrainer], + ] = training.MADQNTrainer, + executor_fn: Type[core.Executor] = MADQNFeedForwardExecutor, extra_specs: Dict[str, Any] = {}, - replay_stabilisation_fn: Optional[Type[FingerPrintStabalisation]] = None, ): """Initialise the system. - Args: - config (MADQNConfig): system configuration specifying hyperparameters and + config: system configuration specifying hyperparameters and additional information for constructing the system. - trainer_fn (Type[training.MADQNTrainer], optional): Trainer function, of a - correpsonding type to work with the selected system architecture. - Defaults to training.MADQNTrainer. - executor_fn (Type[core.Executor], optional): Executor function, of a - corresponding type to work with the selected system architecture. - Defaults to execution.MADQNFeedForwardExecutor. - extra_specs (Dict[str, Any], optional): defines the specifications of extra - information used by the system. Defaults to {}. - replay_stabilisation_fn : optional function to stabilise experience replay. + trainer_fn: Trainer function, of a correpsonding type to work with + the selected system architecture. + executor_fn: Executor function, of a corresponding type to work with + the selected system architecture. + extra_specs: defines the specifications of extra + information used by the system. """ self._config = config self._extra_specs = extra_specs - self._agents = self._config.environment_spec.get_agent_ids() self._agent_types = self._config.environment_spec.get_agent_types() self._trainer_fn = trainer_fn self._executor_fn = executor_fn - self._replay_stabiliser_fn = replay_stabilisation_fn + + + def covert_specs(self, spec: Dict[str, Any], num_networks: int) -> Dict[str, Any]: + if type(spec) is not dict: + return spec + + agents = sort_str_num(self._config.agent_net_keys.keys())[:num_networks] + converted_spec: Dict[str, Any] = {} + if agents[0] in spec.keys(): + for agent in agents: + converted_spec[agent] = spec[agent] + else: + # For the extras + for key in spec.keys(): + converted_spec[key] = self.covert_specs(spec[key], num_networks) + return converted_spec def make_replay_tables( self, environment_spec: specs.MAEnvironmentSpec, ) -> List[reverb.Table]: - """Create tables to insert data into. - + """ "Create tables to insert data into. Args: - environment_spec (specs.MAEnvironmentSpec): description of the action and + environment_spec: description of the action and observation spaces etc. for each agent in the system. - Raises: NotImplementedError: unknown executor type. - Returns: - List[reverb.Table]: a list of data tables for inserting data. + a list of data tables for inserting data. 
""" - # Select adder if issubclass(self._executor_fn, executors.FeedForwardExecutor): - # Check if we should use fingerprints - if self._replay_stabiliser_fn is not None: - self._extra_specs.update({"fingerprint": np.array([1.0, 1.0])}) - adder_sig = reverb_adders.ParallelNStepTransitionAdder.signature( - environment_spec, self._extra_specs - ) + + def adder_sig_fn( + env_spec: specs.MAEnvironmentSpec, extra_specs: Dict[str, Any] + ) -> Any: + return reverb_adders.ParallelNStepTransitionAdder.signature( + env_spec, extra_specs + ) + elif issubclass(self._executor_fn, executors.RecurrentExecutor): - adder_sig = reverb_adders.ParallelSequenceAdder.signature( - environment_spec, self._config.sequence_length, self._extra_specs - ) + + def adder_sig_fn( + env_spec: specs.MAEnvironmentSpec, extra_specs: Dict[str, Any] + ) -> Any: + return reverb_adders.ParallelSequenceAdder.signature( + env_spec, self._config.sequence_length, extra_specs + ) + else: raise NotImplementedError("Unknown executor type: ", self._executor_fn) if self._config.samples_per_insert is None: # We will take a samples_per_insert ratio of None to mean that there is # no limit, i.e. this only implies a min size limit. - limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size) + def limiter_fn() -> reverb.rate_limiters: + return reverb.rate_limiters.MinSize(self._config.min_replay_size) else: # Create enough of an error buffer to give a 10% tolerance in rate. samples_per_insert_tolerance = 0.1 * self._config.samples_per_insert error_buffer = self._config.min_replay_size * samples_per_insert_tolerance - limiter = reverb.rate_limiters.SampleToInsertRatio( - min_size_to_sample=self._config.min_replay_size, - samples_per_insert=self._config.samples_per_insert, - error_buffer=error_buffer, - ) - # Maybe use prioritized sampling. - if self._config.importance_sampling_exponent is not None: - sampler = reverb.selectors.Prioritized( - self._config.importance_sampling_exponent + def limiter_fn() -> reverb.rate_limiters: + return reverb.rate_limiters.SampleToInsertRatio( + min_size_to_sample=self._config.min_replay_size, + samples_per_insert=self._config.samples_per_insert, + error_buffer=error_buffer, + ) + + # Create table per trainer + replay_tables = [] + for table_key in self._config.table_network_config.keys(): + # TODO (dries): Clean the below coverter code up. 
+ # Convert a Mava spec + num_networks = len(self._config.table_network_config[table_key]) + env_spec = copy.deepcopy(environment_spec) + env_spec._specs = self.covert_specs(env_spec._specs, num_networks) + + env_spec._keys = list(sort_str_num(env_spec._specs.keys())) + if env_spec.extra_specs is not None: + env_spec.extra_specs = self.covert_specs( + env_spec.extra_specs, num_networks + ) + extra_specs = self.covert_specs( + self._extra_specs, + num_networks, ) - else: - sampler = reverb.selectors.Uniform() - - replay_table = reverb.Table( - name=self._config.replay_table_name, - sampler=sampler, - remover=reverb.selectors.Fifo(), - max_size=self._config.max_replay_size, - rate_limiter=limiter, - signature=adder_sig, - ) - return [replay_table] + replay_tables.append( + reverb.Table( + name=table_key, + sampler=reverb.selectors.Uniform(), + remover=reverb.selectors.Fifo(), + max_size=self._config.max_replay_size, + rate_limiter=limiter_fn(), + signature=adder_sig_fn(env_spec, extra_specs), + ) + ) + return replay_tables def make_dataset_iterator( - self, replay_client: reverb.Client + self, + replay_client: reverb.Client, + table_name: str, ) -> Iterator[reverb.ReplaySample]: """Create a dataset iterator to use for training/updating the system. - Args: - replay_client (reverb.Client): Reverb Client which points to the + replay_client: Reverb Client which points to the replay server. - Returns: [type]: dataset iterator. - Yields: - Iterator[reverb.ReplaySample]: data samples from the dataset. + data samples from the dataset. """ sequence_length = ( @@ -224,8 +290,9 @@ def make_dataset_iterator( else None ) + """Create a dataset iterator to use for learning/updating the system.""" dataset = datasets.make_reverb_dataset( - table=self._config.replay_table_name, + table=table_name, server_address=replay_client.server_address, batch_size=self._config.batch_size, prefetch_size=self._config.prefetch_size, @@ -234,45 +301,104 @@ def make_dataset_iterator( return iter(dataset) def make_adder( - self, replay_client: reverb.Client + self, + replay_client: reverb.Client, ) -> Optional[adders.ParallelAdder]: """Create an adder which records data generated by the executor/environment. - Args: - replay_client (reverb.Client): Reverb Client which points to the + replay_client: Reverb Client which points to the replay server. - Raises: NotImplementedError: unknown executor type. - Returns: - Optional[adders.ParallelAdder]: adder which sends data to a replay buffer. + adder which sends data to a replay buffer. 
""" + # Create custom priority functons for the adder + priority_fns = { + table_key: lambda x: 1.0 + for table_key in self._config.table_network_config.keys() + } # Select adder if issubclass(self._executor_fn, executors.FeedForwardExecutor): adder = reverb_adders.ParallelNStepTransitionAdder( - priority_fns=None, + priority_fns=priority_fns, client=replay_client, + net_ids_to_keys=self._config.unique_net_keys, n_step=self._config.n_step, + table_network_config=self._config.table_network_config, discount=self._config.discount, ) elif issubclass(self._executor_fn, executors.RecurrentExecutor): adder = reverb_adders.ParallelSequenceAdder( - priority_fns=None, + priority_fns=priority_fns, client=replay_client, + net_ids_to_keys=self._config.unique_net_keys, sequence_length=self._config.sequence_length, + table_network_config=self._config.table_network_config, period=self._config.period, ) else: raise NotImplementedError("Unknown executor type: ", self._executor_fn) + print("################3", adder) return adder + def create_counter_variables( + self, variables: Dict[str, tf.Variable] + ) -> Dict[str, tf.Variable]: + """Create counter variables. + Args: + variables: dictionary with variable_source + variables in. + Returns: + variables: dictionary with variable_source + variables in. + """ + variables["trainer_steps"] = tf.Variable(0, dtype=tf.int32) + variables["trainer_walltime"] = tf.Variable(0, dtype=tf.float32) + variables["evaluator_steps"] = tf.Variable(0, dtype=tf.int32) + variables["evaluator_episodes"] = tf.Variable(0, dtype=tf.int32) + variables["executor_episodes"] = tf.Variable(0, dtype=tf.int32) + variables["executor_steps"] = tf.Variable(0, dtype=tf.int32) + return variables + + def make_variable_server( + self, + networks: Dict[str, Dict[str, snt.Module]], + ) -> MavaVariableSource: + """Create the variable server. + Args: + networks: dictionary with the + system's networks in. + Returns: + variable_source: A Mava variable source object. + """ + # Create variables + variables = {} + # Network variables + for net_type_key in networks.keys(): + for net_key in networks[net_type_key].keys(): + # Ensure obs and target networks are sonnet modules + variables[f"{net_key}_{net_type_key}"] = tf2_utils.to_sonnet_module( + networks[net_type_key][net_key] + ).variables + + variables = self.create_counter_variables(variables) + + # Create variable source + variable_source = MavaVariableSource( + variables, + self._config.checkpoint, + self._config.checkpoint_subpath, + self._config.checkpoint_minute_interval, + self._config.termination_condition, + ) + return variable_source + def make_executor( self, - q_networks: Dict[str, snt.Module], - action_selectors: Dict[str, Any], + networks: Dict[str, snt.Module], exploration_schedules: Dict[ str, Union[ @@ -282,49 +408,53 @@ def make_executor( ], ], adder: Optional[adders.ParallelAdder] = None, - variable_source: Optional[core.VariableSource] = None, - trainer: Optional[training.MADQNTrainer] = None, - communication_module: Optional[BaseCommunicationModule] = None, + variable_source: Optional[MavaVariableSource] = None, evaluator: bool = False, - seed: Optional[int] = None, ) -> core.Executor: """Create an executor instance. - Args: - q_networks (Dict[str, snt.Module]): q-value networks for each agent in the - system. - action_selectors (Dict[str, Any]): policy action selector method, e.g. - epsilon greedy. - exploration_schedules: epsilon decay scheduler per agent. 
- adder (Optional[adders.ParallelAdder], optional): adder to send data to + networks: dictionary with the system's networks in. + policy_networks: policy networks for each agent in + the system. + adder: adder to send data to a replay buffer. Defaults to None. - variable_source (Optional[core.VariableSource], optional): variables server. + variable_source: variables server. Defaults to None. - trainer (Optional[training.MADQNRecurrentCommTrainer], optional): - system trainer. Defaults to None. - communication_module (BaseCommunicationModule): module for enabling - communication protocols between agents. Defaults to None. - evaluator (bool, optional): boolean indicator if the executor is used for - for evaluation only. Defaults to False. - seed: seed for reproducible sampling. - + evaluator: boolean indicator if the executor is used for + for evaluation only. Returns: - core.Executor: system executor, a collection of agents making up the part + system executor, a collection of agents making up the part of the system generating data by interacting the environment. """ + # Create variables + variables = {} + get_keys = [] + for net_type_key in ["observations", "values"]: + for net_key in networks[net_type_key].keys(): + var_key = f"{net_key}_{net_type_key}" + variables[var_key] = networks[net_type_key][net_key].variables + get_keys.append(var_key) + variables = self.create_counter_variables(variables) + + count_names = [ + "trainer_steps", + "trainer_walltime", + "evaluator_steps", + "evaluator_episodes", + "executor_episodes", + "executor_steps", + ] + get_keys.extend(count_names) + counts = {name: variables[name] for name in count_names} - agent_net_keys = self._config.agent_net_keys evaluator_interval = self._config.evaluator_interval if evaluator else None variable_client = None if variable_source: - # Create policy variables - variables = { - net_key: q_networks[net_key].variables for net_key in q_networks.keys() - } # Get new policy variables variable_client = variable_utils.VariableClient( client=variable_source, - variables={"q_network": variables}, + variables=variables, + get_keys=get_keys, # If we are using evaluator_intervals, # we should always get the latest variables. update_period=0 @@ -334,27 +464,26 @@ def make_executor( # Make sure not to use a random policy after checkpoint restoration by # assigning variables before running the environment loop. - variable_client.update_and_wait() - - # Check if we should use fingerprints - fingerprint = True if self._replay_stabiliser_fn is not None else False + variable_client.get_and_wait() # Pass scheduler and initialize action selectors action_selectors_with_scheduler = initialize_epsilon_schedulers( - exploration_schedules, action_selectors, agent_net_keys, seed=seed + exploration_schedules, networks["selectors"], self._config.agent_net_keys ) - # Create the executor which coordinates the actors. + # Create the actor which defines how we take actions. 
return self._executor_fn( - q_networks=q_networks, + observation_networks=networks["observations"], + value_networks=networks["values"], action_selectors=action_selectors_with_scheduler, - agent_net_keys=agent_net_keys, + counts=counts, + net_keys_to_ids=self._config.net_keys_to_ids, + agent_specs=self._config.environment_spec.get_agent_specs(), + agent_net_keys=self._config.agent_net_keys, + network_sampling_setup=self._config.network_sampling_setup, variable_client=variable_client, adder=adder, - trainer=trainer, - communication_module=communication_module, evaluator=evaluator, - fingerprint=fingerprint, interval=evaluator_interval, ) @@ -362,60 +491,103 @@ def make_trainer( self, networks: Dict[str, Dict[str, snt.Module]], dataset: Iterator[reverb.ReplaySample], - counter: Optional[counting.Counter] = None, + variable_source: MavaVariableSource, + trainer_networks: List[Any], + trainer_table_entry: List[Any], logger: Optional[types.NestedLogger] = None, - communication_module: Optional[BaseCommunicationModule] = None, - replay_client: Optional[reverb.TFClient] = None, ) -> core.Trainer: """Create a trainer instance. - Args: - networks (Dict[str, Dict[str, snt.Module]]): system networks. - dataset (Iterator[reverb.ReplaySample]): dataset iterator to feed data to + networks: system networks. + dataset: dataset iterator to feed data to the trainer networks. - counter (Optional[counting.Counter], optional): a Counter which allows for - recording of counts, e.g. trainer steps. Defaults to None. - logger (Optional[types.NestedLogger], optional): Logger object for logging - metadata.. Defaults to None. - communication_module (BaseCommunicationModule): module to enable - agent communication. Defaults to None. - + variable_source: Source with variables in. + trainer_networks: Set of unique network keys to train on.. + trainer_table_entry: List of networks per agent to train on. + logger: Logger object for logging metadata. Returns: - core.Trainer: system trainer, that uses the collected data from the + system trainer, that uses the collected data from the executors to update the parameters of the agent networks in the system. """ + # This assumes agents are sort_str_num in the other methods + agent_types = self._agent_types + max_gradient_norm = self._config.max_gradient_norm + discount = self._config.discount + target_update_period = self._config.target_update_period + target_averaging = self._config.target_averaging + target_update_rate = self._config.target_update_rate + + # Create variable client + variables = {} + set_keys = [] + get_keys = [] + # TODO (dries): Only add the networks this trainer is working with. + # Not all of them. 
+ for net_type_key in ["observations", "values"]: + for net_key in networks[net_type_key].keys(): + variables[f"{net_key}_{net_type_key}"] = networks[net_type_key][ + net_key + ].variables + if net_key in set(trainer_networks): + set_keys.append(f"{net_key}_{net_type_key}") + else: + get_keys.append(f"{net_key}_{net_type_key}") + + variables = self.create_counter_variables(variables) + + count_names = [ + "trainer_steps", + "trainer_walltime", + "evaluator_steps", + "evaluator_episodes", + "executor_episodes", + "executor_steps", + ] + get_keys.extend(count_names) + counts = {name: variables[name] for name in count_names} + + variable_client = variable_utils.VariableClient( + client=variable_source, + variables=variables, + get_keys=get_keys, + set_keys=set_keys, + update_period=10, + ) - q_networks = networks["values"] - target_q_networks = networks["target_values"] - - agents = self._config.environment_spec.get_agent_ids() - agent_types = self._config.environment_spec.get_agent_types() - - # Check if we should use fingerprints - fingerprint = True if self._replay_stabiliser_fn is not None else False + # Get all the initial variables + variable_client.get_all_and_wait() + + # Convert network keys for the trainer. + trainer_agents = self._agents[: len(trainer_table_entry)] + trainer_agent_net_keys = { + agent: trainer_table_entry[a_i] for a_i, agent in enumerate(trainer_agents) + } + trainer_config: Dict[str, Any] = { + "agents": trainer_agents, + "agent_types": agent_types, + "value_networks": networks["values"], + "observation_networks": networks["observations"], + "target_value_networks": networks["target_values"], + "target_observation_networks": networks["target_observations"], + "agent_net_keys": trainer_agent_net_keys, + "optimizer": self._config.optimizer, + "max_gradient_norm": max_gradient_norm, + "discount": discount, + "target_averaging": target_averaging, + "target_update_period": target_update_period, + "target_update_rate": target_update_rate, + "variable_client": variable_client, + "dataset": dataset, + "counts": counts, + "logger": logger, + "learning_rate_scheduler_fn": self._config.learning_rate_scheduler_fn, + } # The learner updates the parameters (and initializes them). - trainer = self._trainer_fn( - agents=agents, - agent_types=agent_types, - discount=self._config.discount, - q_networks=q_networks, - target_q_networks=target_q_networks, - agent_net_keys=self._config.agent_net_keys, - optimizer=self._config.optimizer, - target_update_period=self._config.target_update_period, - max_gradient_norm=self._config.max_gradient_norm, - communication_module=communication_module, - dataset=dataset, - counter=counter, - fingerprint=fingerprint, - logger=logger, - checkpoint=self._config.checkpoint, - checkpoint_subpath=self._config.checkpoint_subpath, - checkpoint_minute_interval=self._config.checkpoint_minute_interval, - learning_rate_scheduler_fn=self._config.learning_rate_scheduler_fn, - ) + trainer = self._trainer_fn(**trainer_config) - trainer = DetailedTrainerStatistics(trainer) # type:ignore + trainer = ScaledDetailedTrainerStatistics( # type: ignore + trainer, metrics=["value_loss"] + ) return trainer diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index 1efd8a2c4..ff27b5883 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -13,32 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- """MADQN system executor implementation.""" - -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import dm_env import numpy as np import sonnet as snt import tensorflow as tf +import tensorflow_probability as tfp from acme import types +from acme.specs import EnvironmentSpec + +# Internal imports. from acme.tf import utils as tf2_utils from acme.tf import variable_utils as tf2_variable_utils +from dm_env import specs from mava import adders -from mava.components.tf.modules.communication import BaseCommunicationModule +from mava import core +from mava.systems.tf import executors +from mava.utils.sort_utils import sample_new_agent_keys, sort_str_num from mava.components.tf.modules.exploration.exploration_scheduling import ( BaseExplorationTimestepScheduler, ) -from mava.systems.tf.executors import ( - FeedForwardExecutor, - RecurrentCommExecutor, - RecurrentExecutor, -) -from mava.systems.tf.madqn.training import MADQNTrainer -from mava.types import OLT +Array = specs.Array +BoundedArray = specs.BoundedArray +DiscreteArray = specs.DiscreteArray +tfd = tfp.distributions class DQNExecutor: def __init__(self, action_selectors: Dict): @@ -46,7 +48,6 @@ def __init__(self, action_selectors: Dict): def _get_epsilon(self) -> Union[float, np.ndarray]: """Return epsilon. - Returns: epsilon values. """ @@ -81,7 +82,6 @@ def after_action_selection(self, time_t: int) -> None: def get_stats(self) -> Dict: """Return extra stats to log. - Returns: epsilon information. """ @@ -91,502 +91,412 @@ def get_stats(self) -> Dict: } -class MADQNFeedForwardExecutor(FeedForwardExecutor, DQNExecutor): - """A feed-forward executor. +class MADQNFeedForwardExecutor(executors.FeedForwardExecutor, DQNExecutor): + """A feed-forward executor for discrete actions. An executor based on a feed-forward policy for each agent in the system. """ def __init__( self, - q_networks: Dict[str, snt.Module], + observation_networks: Dict[str, snt.Module], + value_networks: Dict[str, snt.Module], action_selectors: Dict[str, snt.Module], - trainer: MADQNTrainer, + agent_specs: Dict[str, EnvironmentSpec], agent_net_keys: Dict[str, str], + network_sampling_setup: List, + net_keys_to_ids: Dict[str, int], + evaluator: bool = False, adder: Optional[adders.ParallelAdder] = None, + counts: Optional[Dict[str, Any]] = None, variable_client: Optional[tf2_variable_utils.VariableClient] = None, - communication_module: Optional[BaseCommunicationModule] = None, - fingerprint: bool = False, - evaluator: bool = False, interval: Optional[dict] = None, ): - """Initialise the system executor + """Initialise the system executor Args: - q_networks (Dict[str, snt.Module]): q-value networks for each agent in the - system. - action_selectors (Dict[str, Any]): policy action selector method, e.g. - epsilon greedy. - trainer (MADQNTrainer, optional): system trainer. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - adder (Optional[adders.ParallelAdder], optional): adder which sends data + policy_networks: policy networks for each agent in + the system. + agent_specs: agent observation and action + space specifications. + agent_net_keys: specifies what network each agent uses. + network_sampling_setup: List of networks that are randomly + sampled from by the executors at the start of an environment run. + net_keys_to_ids: Specifies a mapping from network keys to their integer id. + adder: adder which sends data to a replay buffer. Defaults to None. 
- variable_client (Optional[tf2_variable_utils.VariableClient], optional): + counts: Count values used to record excutor episode and steps. + variable_client: client to copy weights from the trainer. Defaults to None. - communication_module (BaseCommunicationModule): module for enabling - communication protocols between agents. Defaults to None. - fingerprint (bool, optional): whether to use fingerprint stabilisation to - stabilise experience replay. Defaults to False. - evaluator (bool, optional): whether the executor will be used for - evaluation. Defaults to False. + evaluator: whether the executor will be used for + evaluation. interval: interval that evaluations are run at. """ # Store these for later use. - self._adder = adder - self._variable_client = variable_client - self._q_networks = q_networks - self._action_selectors = action_selectors - self._trainer = trainer - self._agent_net_keys = agent_net_keys - self._fingerprint = fingerprint + self._agent_specs = agent_specs + self._network_sampling_setup = network_sampling_setup + self._counts = counts + self._network_int_keys_extras: Dict[str, np.ndarray] = {} + self._net_keys_to_ids = net_keys_to_ids self._evaluator = evaluator self._interval = interval + self._observation_networks = observation_networks + self._action_selectors = action_selectors + self._value_networks = value_networks + self._agent_net_keys=agent_net_keys + self._adder=adder + self._variable_client=variable_client @tf.function def _policy( - self, - agent: str, + self, agent: str, observation: types.NestedTensor, - legal_actions: types.NestedTensor, - fingerprint: Optional[tf.Tensor] = None, + legal_actions: types.NestedTensor ) -> types.NestedTensor: """Agent specific policy function Args: - agent (str): agent id - observation (types.NestedTensor): observation tensor received from the + agent: agent id + observation: observation tensor received from the environment. - legal_actions (types.NestedTensor): actions allowed to be taken at the - current observation. - fingerprint (Optional[tf.Tensor], optional): policy fingerprints. Defaults - to None. + + Raises: + NotImplementedError: unknown action space Returns: types.NestedTensor: agent action """ + # Add a dummy batch dimension and as a side effect convert numpy to TF. batched_observation = tf2_utils.add_batch_dim(observation) - batched_legals = tf2_utils.add_batch_dim(legal_actions) + batched_legal_actions = tf2_utils.add_batch_dim(legal_actions) # index network either on agent type or on agent id - agent_net_key = self._agent_net_keys[agent] + agent_key = self._agent_net_keys[agent] - # Compute the policy, conditioned on the observation and - # possibly the fingerprint. - if fingerprint is not None: - q_values = self._q_networks[agent_net_key](batched_observation, fingerprint) - else: - q_values = self._q_networks[agent_net_key](batched_observation) + # Pass through observation network + embed = self._observation_networks[agent_key](batched_observation) - action = self._action_selectors[agent]( - action_values=q_values, legal_actions_mask=batched_legals - ) + # Compute the action_values, conditioned on the observation. 
+ action_values = self._value_networks[agent_key](embed) + + # Pass action values through action selector + action = self._action_selectors[agent](action_values, batched_legal_actions) return action def select_action( self, agent: str, observation: types.NestedArray - ) -> types.NestedArray: + ) -> Tuple[types.NestedArray, types.NestedArray]: """select an action for a single agent in the system Args: - agent (str): agent id - observation (types.NestedArray): observation tensor received from the - environment. + agent: agent id. + observation: observation tensor received + from the environment. Returns: - types.NestedArray: agent action + agent action and policy. """ + # Get the action from the policy, conditioned on the observation + action = self._policy(agent, observation.observation, observation.legal_actions) - if self._fingerprint: - trainer_step = self._trainer.get_trainer_steps() - fingerprint = tf.concat([self._get_epsilon(), trainer_step], axis=0) - fingerprint = tf.expand_dims(fingerprint, axis=0) - fingerprint = tf.cast(fingerprint, "float32") - else: - fingerprint = None - - action = self._policy( - agent=agent, - observation=observation.observation, - legal_actions=observation.legal_actions, - fingerprint=fingerprint, - ) - + # Return a numpy array with squeezed out batch dimension. action = tf2_utils.to_numpy_squeeze(action) return action + def select_actions( + self, observations: Dict[str, types.NestedArray] + ) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]: + """select the actions for all agents in the system + + Args: + observations: agent observations from the + environment. + + Returns: + actions and policies for all agents in the system. + """ + + actions = {} + for agent, observation in observations.items(): + actions[agent] = self.select_action(agent, observation) + + return actions + def observe_first( self, timestep: dm_env.TimeStep, extras: Dict[str, types.NestedArray] = {}, ) -> None: - """record first observed timestep from the environment - + """Record first observed timestep from the environment Args: - timestep (dm_env.TimeStep): data emitted by an environment at first step of + timestep: data emitted by an environment at first step of interaction. - extras (Dict[str, types.NestedArray], optional): possible extra information + extras: possible extra information to record during the first step. Defaults to {}. """ + if not self._adder: + return + + "Select new networks from the sampler at the start of each episode." + agents = sort_str_num(list(self._agent_net_keys.keys())) + self._network_int_keys_extras, self._agent_net_keys = sample_new_agent_keys( + agents, + self._network_sampling_setup, + self._net_keys_to_ids, + ) - if self._fingerprint and self._trainer is not None: - epsilon = self._get_epsilon() - trainer_step = self._trainer.get_trainer_steps() - fingerprint = np.array([epsilon, trainer_step]) - extras.update({"fingerprint": fingerprint}) - - if self._adder: - self._adder.add_first(timestep, extras) + extras["network_int_keys"] = self._network_int_keys_extras + + + self._adder.add_first(timestep, extras) def observe( self, - actions: Dict[str, types.NestedArray], + actions: Union[ + Dict[str, types.NestedArray], List[Dict[str, types.NestedArray]] + ], next_timestep: dm_env.TimeStep, next_extras: Dict[str, types.NestedArray] = {}, ) -> None: """record observed timestep from the environment - Args: - actions (Dict[str, types.NestedArray]): system agents' actions. 
- next_timestep (dm_env.TimeStep): data emitted by an environment during + actions: system agents' actions. + next_timestep: data emitted by an environment during interaction. - next_extras (Dict[str, types.NestedArray], optional): possible extra + next_extras: possible extra information to record during the transition. Defaults to {}. """ + if not self._adder: + return - if self._fingerprint and self._trainer is not None: - trainer_step = self._trainer.get_trainer_steps() - epsilon = self._get_epsilon() - fingerprint = np.array([epsilon, trainer_step]) - next_extras.update({"fingerprint": fingerprint}) - - if self._adder: - self._adder.add(actions, next_timestep, next_extras) - - def select_actions( - self, observations: Dict[str, OLT] - ) -> Dict[str, types.NestedArray]: - """select the actions for all agents in the system - - Args: - observations (Dict[str, OLT]): transition object containing observations, - legal actions and terminals. - - Returns: - Dict[str, types.NestedArray]: actions for all agents in the system. - """ - actions = {} - for agent, observation in observations.items(): - actions[agent] = self.select_action(agent, observation) - - # Return a numpy array with squeezed out batch dimension. - return actions + next_extras["network_int_keys"] = self._network_int_keys_extras + # TODO (dries): Sort out this mypy issue. + self._adder.add(actions, next_timestep, next_extras) # type: ignore def update(self, wait: bool = False) -> None: - """update executor variables - - Args: - wait (bool, optional): whether to stall the executor's request for new - variables. Defaults to False. - """ - + """Update the policy variables.""" if self._variable_client: - self._variable_client.update(wait) + self._variable_client.get_async() -class MADQNRecurrentExecutor(RecurrentExecutor, DQNExecutor): - """A recurrent executor. - An executor based on a recurrent policy for each agent in the system +class MADQNRecurrentExecutor(executors.RecurrentExecutor, MADQNFeedForwardExecutor): + """A recurrent executor for MADQN. + + An executor based on a recurrent policy for each agent in the system. """ def __init__( self, - q_networks: Dict[str, snt.Module], + observation_networks :Dict[str, snt.Module], action_selectors: Dict[str, snt.Module], + value_networks: Dict[str, snt.Module], + agent_specs: Dict[str, EnvironmentSpec], agent_net_keys: Dict[str, str], + network_sampling_setup: List, + net_keys_to_ids: Dict[str, int], + evaluator: bool = False, adder: Optional[adders.ParallelAdder] = None, + counts: Optional[Dict[str, Any]] = None, variable_client: Optional[tf2_variable_utils.VariableClient] = None, store_recurrent_state: bool = True, - trainer: MADQNTrainer = None, - communication_module: Optional[BaseCommunicationModule] = None, - fingerprint: bool = False, - evaluator: bool = False, interval: Optional[dict] = None, ): """Initialise the system executor - Args: - q_networks (Dict[str, snt.Module]): q-value networks for each agent in the - system. - action_selectors (Dict[str, Any]): policy action selector method, e.g. - epsilon greedy. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - agent_net_keys (Dict[str, Any]): specifies what network each agent uses. - adder (Optional[adders.ParallelAdder], optional): adder which sends data + policy_networks: policy networks for each agent in + the system. + agent_specs: agent observation and action + space specifications. + agent_net_keys: specifies what network each agent uses. 
+ network_sampling_setup: List of networks that are randomly + sampled from by the executors at the start of an environment run. + net_keys_to_ids: Specifies a mapping from network keys to their integer id. + adder: adder which sends data to a replay buffer. Defaults to None. - variable_client (Optional[tf2_variable_utils.VariableClient], optional): + counts: Count values used to record excutor episode and steps. + variable_client: client to copy weights from the trainer. Defaults to None. - store_recurrent_state (bool, optional): boolean to store the recurrent + store_recurrent_state: boolean to store the recurrent network hidden state. Defaults to True. - trainer (MADQNTrainer, optional): system trainer. Defaults to None. - communication_module (BaseCommunicationModule): module for enabling - communication protocols between agents. Defaults to None. - fingerprint (bool, optional): whether to use fingerprint stabilisation to - stabilise experience replay. Defaults to False. - evaluator (bool, optional): whether the executor will be used for - evaluation. Defaults to False. + evaluator: whether the executor will be used for + evaluation. interval: interval that evaluations are run at. """ # Store these for later use. - self._adder = adder - self._variable_client = variable_client - self._q_networks = q_networks - self._policy_networks = q_networks - self._action_selectors = action_selectors - self._store_recurrent_state = store_recurrent_state - self._trainer = trainer - self._agent_net_keys = agent_net_keys - self._interval = interval + self._agent_specs = agent_specs + self._network_sampling_setup = network_sampling_setup + self._counts = counts + self._net_keys_to_ids = net_keys_to_ids + self._network_int_keys_extras: Dict[str, np.ndarray] = {} self._evaluator = evaluator - + self._interval = interval + self._value_networks = value_networks + self._agent_net_keys=agent_net_keys + self._adder=adder + self._variable_client=variable_client + self._store_recurrent_state=store_recurrent_state + self._observation_networks = observation_networks + self._action_selectors = action_selectors self._states: Dict[str, Any] = {} + @tf.function def _policy( self, agent: str, observation: types.NestedTensor, - state: types.NestedTensor, legal_actions: types.NestedTensor, - ) -> types.NestedTensor: + state: types.NestedTensor, + ) -> Tuple[types.NestedTensor, types.NestedTensor, types.NestedTensor]: """Agent specific policy function - Args: - agent (str): agent id - observation (types.NestedTensor): observation tensor received from the + agent: agent id + observation: observation tensor received from the environment. - state (types.NestedTensor): recurrent network state. - message (types.NestedTensor): received agent messsage. - legal_actions (types.NestedTensor): actions allowed to be taken at the - current observation. - + state: recurrent network state. + Raises: + NotImplementedError: unknown action space Returns: - types.NestedTensor: action and new recurrent hidden state + action, policy and new recurrent hidden state """ # Add a dummy batch dimension and as a side effect convert numpy to TF. batched_observation = tf2_utils.add_batch_dim(observation) - batched_legals = tf2_utils.add_batch_dim(legal_actions) + batched_legal_actions = tf2_utils.add_batch_dim(legal_actions) # index network either on agent type or on agent id agent_key = self._agent_net_keys[agent] - # Compute the policy, conditioned on the observation. 
- q_values, new_state = self._q_networks[agent_key](batched_observation, state) + # Pass through observation network + embed = self._observation_networks[agent_key](batched_observation) - # select legal action - action = self._action_selectors[agent](q_values, batched_legals) + # Compute the policy, conditioned on the observation. + action_values, new_state = self._value_networks[agent_key](embed, state) + # Pass action values through action selector + action = self._action_selectors[agent](action_values, batched_legal_actions) + return action, new_state def select_action( self, agent: str, observation: types.NestedArray ) -> types.NestedArray: """select an action for a single agent in the system - Args: - agent (str): agent id - observation (types.NestedArray): observation tensor received from the + agent: agent id + observation: observation tensor received from the environment. - - Raises: - NotImplementedError: has not been implemented for this training type. + Returns: + action and policy. """ - policy_output, new_state = self._policy( - agent, - observation.observation, - self._states[agent], - observation.legal_actions, + # Initialize the RNN state if necessary. + if self._states[agent] is None: + # index network either on agent type or on agent id + agent_key = self._agent_net_keys[agent] + self._states[agent] = self._value_networks[agent_key].initial_state(1) + + # Step the recurrent policy forward given the current observation and state. + action, new_state = self._policy( + agent, observation.observation, observation.legal_actions, self._states[agent] ) - self._states[agent] = new_state + # Bookkeeping of recurrent states for the observe method. + self._update_state(agent, new_state) + + # Return a numpy array with squeezed out batch dimension. + action = tf2_utils.to_numpy_squeeze(action) - return tf2_utils.to_numpy_squeeze(policy_output) + return action def select_actions( - self, observations: Dict[str, OLT] - ) -> Dict[str, types.NestedArray]: + self, observations: Dict[str, types.NestedArray] + ) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]: """select the actions for all agents in the system - Args: - observations (Dict[str, OLT]): transition object containing observations, - legal actions and terminals. - + observations: agent observations from the + environment. Returns: - Dict[str, types.NestedArray]: actions for all agents in the system. + actions and policies for all agents in the system. """ actions = {} for agent, observation in observations.items(): actions[agent] = self.select_action(agent, observation) - - # Return a numpy array with squeezed out batch dimension. return actions - -class MADQNRecurrentCommExecutor(RecurrentCommExecutor, DQNExecutor): - """A recurrent executor with communication. - An executor based on a recurrent policy for each agent in the system using learned - communication. - """ - - def __init__( - self, - q_networks: Dict[str, snt.Module], - action_selectors: Dict[str, snt.Module], - communication_module: BaseCommunicationModule, - agent_net_keys: Dict[str, str], - adder: Optional[adders.ParallelAdder] = None, - variable_client: Optional[tf2_variable_utils.VariableClient] = None, - store_recurrent_state: bool = True, - trainer: MADQNTrainer = None, - fingerprint: bool = False, - evaluator: bool = False, - interval: Optional[dict] = None, - ): - """Initialise the system executor - - Args: - q_networks (Dict[str, snt.Module]): q-value networks for each agent in the - system.
- action_selectors (Dict[str, Any]): policy action selector method, e.g. - epsilon greedy. - communication_module (BaseCommunicationModule): module for enabling - communication protocols between agents. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - adder (Optional[adders.ParallelAdder], optional): adder which sends data - to a replay buffer. Defaults to None. - variable_client (Optional[tf2_variable_utils.VariableClient], optional): - client to copy weights from the trainer. Defaults to None. - store_recurrent_state (bool, optional): boolean to store the recurrent - network hidden state. Defaults to True. - trainer (MADQNTrainer, optional): system trainer. Defaults to None. - fingerprint (bool, optional): whether to use fingerprint stabilisation to - stabilise experience replay. Defaults to False. - evaluator (bool, optional): whether the executor will be used for - evaluation. Defaults to False. - interval: interval that evaluations are run at. - """ - - # Store these for later use. - self._adder = adder - self._variable_client = variable_client - self._q_networks = q_networks - self._policy_networks = q_networks - self._communication_module = communication_module - self._action_selectors = action_selectors - self._store_recurrent_state = store_recurrent_state - self._trainer = trainer - self._agent_net_keys = agent_net_keys - self._interval = interval - self._evaluator = evaluator - - self._states: Dict[str, Any] = {} - self._messages: Dict[str, Any] = {} - - @tf.function - def _policy( + def observe_first( self, - agent: str, - observation: types.NestedTensor, - state: types.NestedTensor, - message: types.NestedTensor, - legal_actions: types.NestedTensor, - ) -> types.NestedTensor: - """Agent specific policy function - - Args: - agent (str): agent id - observation (types.NestedTensor): observation tensor received from the - environment. - state (types.NestedTensor): Recurrent network state. - message (types.NestedTensor): received agent messsage. - legal_actions (types.NestedTensor): actions allowed to be taken at the - current observation. - - Returns: - types.NestedTensor: action and new recurrent hidden state - """ - - # Add a dummy batch dimension and as a side effect convert numpy to TF. - batched_observation = tf2_utils.add_batch_dim(observation) - batched_legals = tf2_utils.add_batch_dim(legal_actions) - - # index network either on agent type or on agent id - agent_key = self._agent_net_keys[agent] - - # Compute the policy, conditioned on the observation. - (q_values, m_values), new_state = self._q_networks[agent_key]( - batched_observation, state, message - ) - - # select legal action - action = self._action_selectors[agent](q_values, batched_legals) - - return (action, m_values), new_state - - def select_action( - self, agent: str, observation: types.NestedArray - ) -> types.NestedArray: - """select an action for a single agent in the system + timestep: dm_env.TimeStep, + extras: Dict[str, types.NestedArray] = {}, + ) -> None: + """record first observed timestep from the environment Args: - agent (str): agent id - observation (types.NestedArray): observation tensor received from the - environment. - - Raises: - NotImplementedError: has not been implemented for this training type. + timestep: data emitted by an environment at first step of + interaction. + extras: possible extra information + to record during the first step. 
""" - message_inputs = self._communication_module.process_messages(self._messages) - (policy_output, new_message), new_state = self._policy( - agent, - observation.observation, - self._states[agent], - message_inputs[agent], - observation.legal_actions, + # Re-initialize the RNN state. + for agent, _ in timestep.observation.items(): + # index network either on agent type or on agent id + agent_key = self._agent_net_keys[agent] + self._states[agent] = self._value_networks[agent_key].initial_state(1) + + if not self._adder: + return + + # Sample new agent_net_keys. + agents = sort_str_num(list(self._agent_net_keys.keys())) + self._network_int_keys_extras, self._agent_net_keys = sample_new_agent_keys( + agents, + self._network_sampling_setup, + self._net_keys_to_ids, ) - self._states[agent] = new_state - self._messages[agent] = new_message - - return tf2_utils.to_numpy_squeeze(policy_output) - - def select_actions( - self, observations: Dict[str, OLT] - ) -> Dict[str, types.NestedArray]: - """select the actions for all agents in the system + if self._store_recurrent_state: + numpy_states = { + agent: tf2_utils.to_numpy_squeeze(_state) + for agent, _state in self._states.items() + } + extras.update({"core_states": numpy_states}) + extras["network_int_keys"] = self._network_int_keys_extras + self._adder.add_first(timestep, extras) + def observe( + self, + actions: Dict[str, types.NestedArray], + next_timestep: dm_env.TimeStep, + next_extras: Dict[str, types.NestedArray] = {}, + ) -> None: + """record observed timestep from the environment Args: - observations (Dict[str, OLT]): transition object containing observations, - legal actions and terminals. - - Returns: - Dict[str, types.NestedArray]: actions for all agents in the system. + actions: system agents' actions. + next_timestep: data emitted by an environment during + interaction. + next_extras: possible extra + information to record during the transition. """ - actions = {} + if not self._adder: + return - for agent, observation in observations.items(): - actions[agent] = self.select_action(agent, observation) + if self._store_recurrent_state: + numpy_states = { + agent: tf2_utils.to_numpy_squeeze(_state) + for agent, _state in self._states.items() + } + next_extras.update({"core_states": numpy_states}) + next_extras["network_int_keys"] = self._network_int_keys_extras + self._adder.add(actions, next_timestep, next_extras) # type: ignore - # Return a numpy array with squeezed out batch dimension. - return actions + def update(self, wait: bool = False) -> None: + """Update the policy variables.""" + if self._variable_client: + self._variable_client.get_async() diff --git a/mava/systems/tf/madqn/networks.py b/mava/systems/tf/madqn/networks.py index 519bc4033..50f5fea04 100644 --- a/mava/systems/tf/madqn/networks.py +++ b/mava/systems/tf/madqn/networks.py @@ -12,138 +12,127 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- from typing import Dict, Mapping, Optional, Sequence, Union +import numpy as np import sonnet as snt +import tensorflow as tf from acme import types -from acme.tf.networks.atari import DQNAtariNetwork +from acme.tf import utils as tf2_utils +from dm_env import specs from mava import specs as mava_specs from mava.components.tf import networks -from mava.components.tf.networks.communication import CommunicationNetwork +from mava.utils.enums import ArchitectureType from mava.components.tf.networks.epsilon_greedy import EpsilonGreedy -from mava.utils.enums import ArchitectureType, Network + +Array = specs.Array +BoundedArray = specs.BoundedArray +DiscreteArray = specs.DiscreteArray -# TODO Use fingerprints variable def make_default_networks( environment_spec: mava_specs.MAEnvironmentSpec, agent_net_keys: Dict[str, str], - policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = None, + value_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = None, archecture_type: ArchitectureType = ArchitectureType.feedforward, - network_type: Network = Network.mlp, - fingerprints: bool = False, - message_size: Optional[int] = None, seed: Optional[int] = None, ) -> Mapping[str, types.TensorTransformation]: - """Default networks for madqn. + """Default networks for madqn. Args: - environment_spec (mava_specs.MAEnvironmentSpec): description of the action and + environment_spec: description of the action and observation spaces etc. for each agent in the system. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - policy_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional): - size of policy networks. - archecture_type (ArchitectureType, optional): archecture used + agent_net_keys: specifies what network each agent uses. + value_networks_layer_sizes: size of the value networks. + archecture_type: architecture used for agent networks. Can be feedforward or recurrent. Defaults to ArchitectureType.feedforward. - network_type (Network, optional): Agent network type. - Can be mlp, atari_dqn_network or coms_network. - Defaults to Network.mlp. - fingerprints (bool, optional): whether to apply replay stabilisation using - policy fingerprints. Defaults to False. - message_size (Optional[int], optional): size of message passed, - if using a coms network. Defaults to None. - seed (int, optional): random seed for network initialization. + seed: random seed for network initialization. Returns: - Mapping[str, types.TensorTransformation]: returned agent networks. + returned agent networks. """ - - # Set Policy function and layer size. + # Set Policy function and layer size # Default size per arch type.
if archecture_type == ArchitectureType.feedforward: - if not policy_networks_layer_sizes: - policy_networks_layer_sizes = (512, 512, 256) - q_network_func = snt.Sequential + if not value_networks_layer_sizes: + value_networks_layer_sizes = ( + 256, + 256, + 256, + ) + value_network_func = snt.Sequential elif archecture_type == ArchitectureType.recurrent: - if not policy_networks_layer_sizes: - policy_networks_layer_sizes = (128, 128) - q_network_func = snt.DeepRNN + if not value_networks_layer_sizes: + value_networks_layer_sizes = (128, 64) + value_network_func = snt.DeepRNN - assert policy_networks_layer_sizes is not None - assert q_network_func is not None + assert value_networks_layer_sizes is not None + assert value_network_func is not None specs = environment_spec.get_agent_specs() # Create agent_type specs specs = {agent_net_keys[key]: specs[key] for key in specs.keys()} - if isinstance(policy_networks_layer_sizes, Sequence): - policy_networks_layer_sizes = { - key: policy_networks_layer_sizes for key in specs.keys() + + if isinstance(value_networks_layer_sizes, Sequence): + value_networks_layer_sizes = { + key: value_networks_layer_sizes for key in specs.keys() + } + if isinstance(value_networks_layer_sizes, Sequence): + value_networks_layer_sizes = { + key: value_networks_layer_sizes for key in specs.keys() } - q_networks = {} + observation_networks = {} + value_networks = {} action_selectors = {} - for key in specs.keys(): + for key, spec in specs.items(): + num_actions = spec.actions.num_values - # Get total number of action dimensions from action spec. - num_dimensions = specs[key].actions.num_values + # An optional network to process observations + observation_network = tf2_utils.to_sonnet_module(tf.identity) # Create the policy network. - if network_type == Network.atari_dqn_network: - q_network = DQNAtariNetwork(num_dimensions) - elif network_type == Network.coms_network: - assert message_size is not None, "Message size not set." 
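
The layer-size handling above accepts either a single sequence shared by all networks or a per-network dict; a small sketch of that normalisation follows, with made-up network keys (the helper name is illustrative only):

    from typing import Dict, Sequence, Union

    def resolve_layer_sizes(
        layer_sizes: Union[Dict[str, Sequence[int]], Sequence[int]],
        net_keys: Sequence[str],
    ) -> Dict[str, Sequence[int]]:
        """Broadcast a single sequence of layer sizes to every network key."""
        if isinstance(layer_sizes, dict):
            return dict(layer_sizes)
        return {key: tuple(layer_sizes) for key in net_keys}

    print(resolve_layer_sizes((256, 256, 256), ["network_0", "network_1"]))
    # {'network_0': (256, 256, 256), 'network_1': (256, 256, 256)}
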
- q_network = CommunicationNetwork( - networks.LayerNormMLP((128,), activate_final=True, seed=seed), - networks.LayerNormMLP((128,), activate_final=True, seed=seed), - snt.LSTM(128), - snt.Sequential( - [ - networks.LayerNormMLP((128,), activate_final=True, seed=seed), - networks.NearZeroInitializedLinear(num_dimensions, seed=seed), - networks.TanhToSpec(specs[key].actions), - ] + if archecture_type == ArchitectureType.feedforward: + value_network = [ + networks.LayerNormMLP( + value_networks_layer_sizes[key], activate_final=True, seed=seed ), - snt.Sequential( - [ - networks.LayerNormMLP( - (128, message_size), activate_final=True, seed=seed - ), - ] + ] + elif archecture_type == ArchitectureType.recurrent: + value_network = [ + networks.LayerNormMLP( + value_networks_layer_sizes[key][:-1], + activate_final=True, + seed=seed, ), - message_size=message_size, - ) - else: - if archecture_type == ArchitectureType.feedforward: - q_network = [ - networks.LayerNormMLP( - list(policy_networks_layer_sizes[key]) + [num_dimensions], - activate_final=False, - seed=seed, - ), - ] - elif archecture_type == ArchitectureType.recurrent: - q_network = [ - networks.LayerNormMLP( - policy_networks_layer_sizes[key][:-1], - activate_final=True, - seed=seed, - ), - snt.LSTM(policy_networks_layer_sizes[key][-1]), - snt.Linear(num_dimensions), - ] - - q_network = q_network_func(q_network) - - q_networks[key] = q_network + snt.GRU(value_networks_layer_sizes[key][-1]), + ] + + value_network += [ + networks.NearZeroInitializedLinear(num_actions, seed=seed), + ] + + value_network = value_network_func(value_network) + + observation_networks[key] = observation_network + value_networks[key] = value_network action_selectors[key] = EpsilonGreedy return { - "q_networks": q_networks, + "values": value_networks, "action_selectors": action_selectors, + "observations": observation_networks, } diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index 7eb5d9de0..065c7ac7e 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -16,7 +16,7 @@ """MADQN system implementation.""" import functools -from typing import Any, Callable, Dict, Mapping, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, Mapping import acme import dm_env @@ -25,31 +25,40 @@ import sonnet as snt from acme import specs as acme_specs from acme.tf import utils as tf2_utils -from acme.utils import counting +from acme.utils import loggers +from dm_env import specs import mava from mava import core from mava import specs as mava_specs -from mava.components.tf.architectures import DecentralisedValueActor -from mava.components.tf.modules.communication import BaseCommunicationModule +from mava.components.tf.architectures import ( + DecentralisedValueActor, +) from mava.components.tf.modules.exploration.exploration_scheduling import ( ConstantScheduler, ) -from mava.components.tf.modules.stabilising import FingerPrintStabalisation +from mava.types import EpsilonScheduler from mava.environment_loop import ParallelEnvironmentLoop from mava.systems.tf import executors -from mava.systems.tf import savers as tf2_savers -from mava.systems.tf.madqn import builder, execution, training -from mava.types import EpsilonScheduler -from mava.utils import lp_utils +from mava.systems.tf.madqn import builder, training +from mava.systems.tf.madqn.execution import ( + MADQNFeedForwardExecutor, + sample_new_agent_keys, +) +from mava.systems.tf.variable_sources import VariableSource as 
MavaVariableSource +from mava.utils import enums from mava.utils.loggers import MavaLogger, logger_utils +from mava.utils.sort_utils import sort_str_num from mava.wrappers import DetailedPerAgentStatistics class MADQN: """MADQN system.""" - def __init__( + """TODO: Implement faster adders to speed up training times when + using multiple trainers with non-shared weights.""" + + def __init__( # noqa self, environment_factory: Callable[[bool], dm_env.Environment], network_factory: Callable[[acme_specs.BoundedArray], Dict[str, snt.Module]], @@ -59,36 +68,37 @@ def __init__( Mapping[str, Mapping[str, EpsilonScheduler]], ], logger_factory: Callable[[str], MavaLogger] = None, - architecture: Type[DecentralisedValueActor] = DecentralisedValueActor, - trainer_fn: Union[ - Type[training.MADQNTrainer], Type[training.MADQNRecurrentTrainer] - ] = training.MADQNTrainer, - communication_module: Type[BaseCommunicationModule] = None, - executor_fn: Type[core.Executor] = execution.MADQNFeedForwardExecutor, - replay_stabilisation_fn: Optional[Type[FingerPrintStabalisation]] = None, + architecture: Type[ + DecentralisedValueActor + ] = DecentralisedValueActor, + trainer_fn: Type[training.MADQNTrainer] = training.MADQNTrainer, + executor_fn: Type[core.Executor] = MADQNFeedForwardExecutor, num_executors: int = 1, - num_caches: int = 0, - environment_spec: mava_specs.MAEnvironmentSpec = None, + trainer_networks: Union[ + Dict[str, List], enums.Trainer + ] = enums.Trainer.single_trainer, + network_sampling_setup: Union[ + List, enums.NetworkSampler + ] = enums.NetworkSampler.fixed_agent_networks, shared_weights: bool = True, - agent_net_keys: Dict[str, str] = {}, + environment_spec: mava_specs.MAEnvironmentSpec = None, + discount: float = 0.99, batch_size: int = 256, prefetch_size: int = 4, + target_averaging: bool = False, + target_update_period: int = 100, + target_update_rate: Optional[float] = None, + executor_variable_update_period: int = 1000, min_replay_size: int = 1000, max_replay_size: int = 1000000, samples_per_insert: Optional[float] = 32.0, + optimizer: Union[ + snt.Optimizer, Dict[str, snt.Optimizer] + ] = snt.optimizers.Adam(learning_rate=1e-4), n_step: int = 5, sequence_length: int = 20, - importance_sampling_exponent: Optional[float] = None, - max_priority_weight: float = 0.9, period: int = 20, max_gradient_norm: float = None, - discount: float = 0.99, - optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]] = snt.optimizers.Adam( - learning_rate=1e-4 - ), - target_update_period: int = 100, - executor_variable_update_period: int = 1000, - max_executor_steps: int = None, checkpoint: bool = True, checkpoint_subpath: str = "~/mava/", checkpoint_minute_interval: int = 5, @@ -97,95 +107,107 @@ def __init__( eval_loop_fn: Callable = ParallelEnvironmentLoop, train_loop_fn_kwargs: Dict = {}, eval_loop_fn_kwargs: Dict = {}, + termination_condition: Optional[Dict[str, int]] = None, evaluator_interval: Optional[dict] = None, - learning_rate_scheduler_fn: Optional[Callable[[int], None]] = None, - seed: Optional[int] = None, + learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Initialise the madqn system. - + """Initialise the system Args: - environment_factory : function to + environment_factory: function to instantiate an environment. - network_factory : function to instantiate system networks. - exploration_scheduler_fn : function specifying a decaying scheduler for epsilon exploration. - This can be - 1. The same across all agents & executors, - e.g. 
LinearExplorationTimestepScheduler(...), - 2. Or at an executor level, - e.g. see examples/debugging/simple_spread/feedforward/decentralised/run_madqn_configurable_epsilon.py # noqa: E501 - 3. Or at an agent level (same across all executors), - e.g. { "agent_0": LinearExplorationTimestepScheduler(...),"agent_1": LinearExplorationTimestepScheduler(...))}. # noqa: E501 - logger_factory : function to + network_factory: function to instantiate system networks. + logger_factory: function to instantiate a system logger. - architecture : system architecture, - e.g. decentralised ,centralised or networked. - trainer_fn : training type + architecture: + system architecture, e.g. decentralised or centralised. + trainer_fn: training type associated with executor and architecture, e.g. centralised training. - communication_module : module for enabling communication protocols between agents. - executor_fn : executor type, e.g. + executor_fn: executor type, e.g. feedforward or recurrent. - replay_stabilisation_fn : replay buffer stabilisaiton function, e.g. fingerprints. - num_executors : number of executor processes to run in - parallel. - num_caches : number of trainer node caches. - environment_spec : escription of - the action, observation spaces etc. - shared_weights : whether agents should share weights or not. - agent_net_keys : specifies what network each agent uses. - batch_size : sample batch size for updates. - prefetch_size : size to prefetch from replay. - min_replay_size : minimum replay size before updating. - max_replay_size : maximum replay size. - samples_per_insert : number of samples to take + num_executors: number of executor processes to run in + parallel.. + environment_spec: description of + the action, observation spaces etc. for each agent in the system. + trainer_networks: networks each + trainer trains on. + network_sampling_setup: List of networks that are randomly + sampled from by the executors at the start of an environment run. + enums.NetworkSampler settings: + fixed_agent_networks: Keeps the networks + used by each agent fixed throughout training. + random_agent_networks: Creates N network policies, where N is the + number of agents. Randomly select policies from this sets for each + agent at the start of a episode. This sampling is done with + replacement so the same policy can be selected for more than one + agent for a given episode. + Custom list: Alternatively one can specify a custom nested list, + with network keys in, that will be used by the executors at + the start of each episode to sample networks for each agent. + shared_weights: whether agents should share weights or not. + When network_sampling_setup are provided the value of shared_weights is + ignored. + discount: discount factor to use for TD updates. + batch_size: sample batch size for updates. + prefetch_size: size to prefetch from replay. + target_averaging: whether to use polyak averaging for + target network updates. + target_update_period: number of steps before target + networks are updated. + target_update_rate: update rate when using + averaging. + executor_variable_update_period: number of steps before + updating executor variables from the variable source. + min_replay_size: minimum replay size before updating. + max_replay_size: maximum replay size. + samples_per_insert: number of samples to take from replay for every insert that is made. - n_step : number of steps to include prior to boostrapping. - sequence_length : recurrent sequence rollout length. 
- importance_sampling_exponent : value of importance sampling - exponent (usually around 0.2). - max_priority_weight : Required if importance_sampling_exponent - is not None. - period : consecutive starting points for overlapping + policy_optimizer: optimizer(s) for updating policy networks. + critic_optimizer: optimizer for updating critic + networks. + n_step: number of steps to include prior to boostrapping. + sequence_length: recurrent sequence rollout length. + period: Consecutive starting points for overlapping rollouts across a sequence. - max_gradient_norm : maximum allowed norm for gradients + max_gradient_norm: maximum allowed norm for gradients before clipping is applied. - discount : discount factor to use for TD updates. - optimizer : type of optimizer to use to update network parameters. - target_update_period : number of steps before target - networks are updated. - executor_variable_update_period : number of steps before - updating executor variables from the variable source. - max_executor_steps : maximum number of steps and executor - can in an episode. - checkpoint : whether to checkpoint models. - checkpoint_subpath : subdirectory specifying where to store + checkpoint: whether to checkpoint models. + checkpoint_minute_interval: The number of minutes to wait between checkpoints. - checkpoint_minute_interval : The number of minutes to wait between + checkpoint_subpath: subdirectory specifying where to store checkpoints. - logger_config : additional configuration settings for the + logger_config: additional configuration settings for the logger factory. - train_loop_fn : function to instantiate a train loop. - eval_loop_fn : function to instantiate a eval loop. - train_loop_fn_kwargs : possible keyword arguments to send + train_loop_fn: function to instantiate a train loop. + eval_loop_fn: function to instantiate an evaluation + loop. + train_loop_fn_kwargs: possible keyword arguments to send to the training loop. - eval_loop_fn_kwargs :possible keyword arguments to send to - the evaluation loop. - learning_rate_scheduler_fn : function/class that takes in a trainer step t - and returns the current learning rate. - seed: seed for reproducible sampling (used for epsilon greedy action selection). + eval_loop_fn_kwargs: possible keyword arguments to send to + the evaluation loop. + termination_condition: An optional terminal condition can be + provided that stops the program once the condition is + satisfied. Available options include specifying maximum + values for trainer_steps, trainer_walltime, evaluator_steps, + evaluator_episodes, executor_episodes or executor_steps. + E.g. termination_condition = {'trainer_steps': 100000}. + learning_rate_scheduler_fn: dict with two functions/classes (one for the + policy and one for the critic optimizer), that takes in a trainer + step t and returns the current learning rate, + e.g. {"policy": policy_lr_schedule ,"critic": critic_lr_schedule}. + See + examples/debugging/simple_spread/feedforward/decentralised/run_maddpg_lr_schedule.py + for an example. evaluator_interval: An optional condition that is used to evaluate/test system performance after [evaluator_interval] condition has been met. If None, evaluation will happen at every timestep. E.g. to evaluate a system after every 100 executor episodes, evaluator_interval = {"executor_episodes": 100}. 
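
To make the sampling options described above concrete, here is a hedged sketch of the three accepted forms of network_sampling_setup and a toy version of the per-episode sampling they drive for three agents; the key names and the helper function are illustrative, not the library code:

    import random

    agents = ["agent_0", "agent_1", "agent_2"]

    # fixed_agent_networks with shared weights: one sample, every agent uses the same network.
    fixed_shared = [["network_0", "network_0", "network_0"]]

    # random_agent_networks: one single-network sample per policy, drawn with replacement.
    random_per_agent = [["network_0"], ["network_1"], ["network_2"]]

    # Custom nested list: each inner list is one joint sample covering a block of agents.
    custom = [["network_0", "network_1", "network_0"], ["network_1", "network_0", "network_1"]]

    def sample_agent_net_keys(agents, sampling_setup):
        """Toy per-episode sampling: draw samples until every agent has a network key.

        Assumes sample lengths divide the number of agents evenly, as asserted above."""
        agent_net_keys = {}
        while len(agent_net_keys) < len(agents):
            sample = random.choice(sampling_setup)
            start = len(agent_net_keys)
            for i, net_key in enumerate(sample):
                agent_net_keys[agents[start + i]] = net_key
        return agent_net_keys

    print(sample_agent_net_keys(agents, random_per_agent))
    # e.g. {'agent_0': 'network_2', 'agent_1': 'network_0', 'agent_2': 'network_1'}
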
- Raises: - ValueError: [description] - """ if not environment_spec: environment_spec = mava_specs.MAEnvironmentSpec( - environment_factory(evaluation=False) # type:ignore + environment_factory(evaluation=False) # type: ignore ) # set default logger if no logger provided @@ -197,42 +219,120 @@ def __init__( time_delta=10, ) - self._architecture = architecture - self._communication_module_fn = communication_module - self._environment_factory = environment_factory - self._network_factory = network_factory - self._logger_factory = logger_factory - self._environment_spec = environment_spec - # Setup agent networks - self._agent_net_keys = agent_net_keys - if not agent_net_keys: - agents = environment_spec.get_agent_ids() - self._agent_net_keys = { - agent: agent.split("_")[0] if shared_weights else agent - for agent in agents - } - self._num_exectors = num_executors - self._num_caches = num_caches + # Setup agent networks and network sampling setup + agents = sort_str_num(environment_spec.get_agent_ids()) + self._network_sampling_setup = network_sampling_setup + + if type(network_sampling_setup) is not list: + if network_sampling_setup == enums.NetworkSampler.fixed_agent_networks: + # if no network_sampling_setup is fixed, use shared_weights to + # determine setup + self._agent_net_keys = { + agent: "network_0" if shared_weights else f"network_{i}" + for i, agent in enumerate(agents) + } + self._network_sampling_setup = [ + [ + self._agent_net_keys[key] + for key in sort_str_num(self._agent_net_keys.keys()) + ] + ] + elif network_sampling_setup == enums.NetworkSampler.random_agent_networks: + """Create N network policies, where N is the number of agents. Randomly + select policies from this sets for each agent at the start of a + episode. This sampling is done with replacement so the same policy + can be selected for more than one agent for a given episode.""" + if shared_weights: + raise ValueError( + "Shared weights cannot be used with random policy per agent" + ) + self._agent_net_keys = { + agents[i]: f"network_{i}" for i in range(len(agents)) + } + self._network_sampling_setup = [ + [ + [self._agent_net_keys[key]] + for key in sort_str_num(self._agent_net_keys.keys()) + ] + ] + else: + raise ValueError( + "network_sampling_setup must be a dict or fixed_agent_networks" + ) - self._max_executor_steps = max_executor_steps - self._checkpoint_subpath = checkpoint_subpath - self._checkpoint = checkpoint - self._logger_config = logger_config - self._train_loop_fn = train_loop_fn - self._train_loop_fn_kwargs = train_loop_fn_kwargs - self._eval_loop_fn = eval_loop_fn - self._eval_loop_fn_kwargs = eval_loop_fn_kwargs - self._checkpoint_minute_interval = checkpoint_minute_interval - self._seed = seed - self._evaluator_interval = evaluator_interval + else: + # if a dictionary is provided, use network_sampling_setup to determine setup + _, self._agent_net_keys = sample_new_agent_keys( + agents, + self._network_sampling_setup, # type: ignore + ) - if issubclass(executor_fn, executors.RecurrentExecutor): - extra_specs = self._get_extra_specs() + + # Check that the environment and agent_net_keys has the same amount of agents + sample_length = len(self._network_sampling_setup[0]) # type: ignore + assert len(environment_spec.get_agent_ids()) == len(self._agent_net_keys.keys()) + + # Check if the samples are of the same length and that they perfectly fit + # into the total number of agents + assert len(self._agent_net_keys.keys()) % sample_length == 0 + for i in range(1, len(self._network_sampling_setup)): 
# type: ignore + assert len(self._network_sampling_setup[i]) == sample_length # type: ignore + + # Get all the unique agent network keys + all_samples = [] + for sample in self._network_sampling_setup: # type: ignore + all_samples.extend(sample) + unique_net_keys = list(sort_str_num(list(set(all_samples)))) + + # Create mapping from ints to networks + net_keys_to_ids = {net_key: i for i, net_key in enumerate(unique_net_keys)} + + # Setup trainer_networks + if type(trainer_networks) is not dict: + if trainer_networks == enums.Trainer.single_trainer: + self._trainer_networks = {"trainer": unique_net_keys} + elif trainer_networks == enums.Trainer.one_trainer_per_network: + self._trainer_networks = { + f"trainer_{i}": [unique_net_keys[i]] + for i in range(len(unique_net_keys)) + } + else: + raise ValueError( + "trainer_networks does not support this enums setting." + ) else: - extra_specs = {} + self._trainer_networks = trainer_networks # type: ignore + + # Get all the unique trainer network keys + all_trainer_net_keys = [] + for trainer_nets in self._trainer_networks.values(): + all_trainer_net_keys.extend(trainer_nets) + unique_trainer_net_keys = sort_str_num(list(set(all_trainer_net_keys))) + + # Check that all agent_net_keys are in trainer_networks + assert unique_net_keys == unique_trainer_net_keys + # Setup specs for each network + self._net_spec_keys = {} + for i in range(len(unique_net_keys)): + self._net_spec_keys[unique_net_keys[i]] = agents[i % len(agents)] + + # Setup table_network_config + table_network_config = {} + for trainer_key in self._trainer_networks.keys(): + most_matches = 0 + trainer_nets = self._trainer_networks[trainer_key] + for sample in self._network_sampling_setup: # type: ignore + matches = 0 + for entry in sample: + if entry in trainer_nets: + matches += 1 + if most_matches < matches: + matches = most_matches + table_network_config[trainer_key] = sample # Setup epsilon schedules # If we receive a single schedule, we use that for all agents. 
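
For clarity, a sketch of a best-match selection of the kind used to build table_network_config above, assuming the intent is to assign each trainer the sampling-setup entry that shares the most network keys with the networks that trainer trains; the function below is illustrative, not the code in this patch:

    from typing import Dict, List

    def best_matching_samples(
        trainer_networks: Dict[str, List[str]],
        network_sampling_setup: List[List[str]],
    ) -> Dict[str, List[str]]:
        """Pick, per trainer, the sample with the most overlap with the trainer's networks."""
        table_network_config = {}
        for trainer_key, trainer_nets in trainer_networks.items():
            best_sample, most_matches = network_sampling_setup[0], -1
            for sample in network_sampling_setup:
                matches = sum(1 for entry in sample if entry in trainer_nets)
                if matches > most_matches:
                    most_matches, best_sample = matches, sample
            table_network_config[trainer_key] = best_sample
        return table_network_config

    print(best_matching_samples(
        {"trainer_0": ["network_0"], "trainer_1": ["network_1"]},
        [["network_0", "network_0"], ["network_1", "network_1"]],
    ))
    # {'trainer_0': ['network_0', 'network_0'], 'trainer_1': ['network_1', 'network_1']}
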
+ self._num_exectors = num_executors if not isinstance(exploration_scheduler_fn, dict): self._exploration_scheduler_fn: Dict = {} for executor_id in range(self._num_exectors): @@ -261,255 +361,163 @@ def __init__( + f" level config: {exploration_scheduler_fn}" ) + self._table_network_config = table_network_config + self._architecture = architecture + self._environment_factory = environment_factory + self._network_factory = network_factory + self._logger_factory = logger_factory + self._environment_spec = environment_spec + self._checkpoint_subpath = checkpoint_subpath + self._checkpoint = checkpoint + self._logger_config = logger_config + self._train_loop_fn = train_loop_fn + self._train_loop_fn_kwargs = train_loop_fn_kwargs + self._eval_loop_fn = eval_loop_fn + self._eval_loop_fn_kwargs = eval_loop_fn_kwargs + self._evaluator_interval = evaluator_interval + + extra_specs = {} + if issubclass(executor_fn, executors.RecurrentExecutor): + extra_specs = self._get_extra_specs() + + int_spec = specs.DiscreteArray(len(unique_net_keys)) + agents = environment_spec.get_agent_ids() + net_spec = {"network_keys": {agent: int_spec for agent in agents}} + extra_specs.update(net_spec) + + self._builder = builder.MADQNBuilder( builder.MADQNConfig( environment_spec=environment_spec, agent_net_keys=self._agent_net_keys, + trainer_networks=self._trainer_networks, + table_network_config=table_network_config, + num_executors=num_executors, + network_sampling_setup=self._network_sampling_setup, # type: ignore + net_keys_to_ids=net_keys_to_ids, + unique_net_keys=unique_net_keys, discount=discount, batch_size=batch_size, prefetch_size=prefetch_size, + target_averaging=target_averaging, target_update_period=target_update_period, + target_update_rate=target_update_rate, executor_variable_update_period=executor_variable_update_period, min_replay_size=min_replay_size, max_replay_size=max_replay_size, samples_per_insert=samples_per_insert, n_step=n_step, sequence_length=sequence_length, - importance_sampling_exponent=importance_sampling_exponent, - max_priority_weight=max_priority_weight, period=period, max_gradient_norm=max_gradient_norm, checkpoint=checkpoint, optimizer=optimizer, checkpoint_subpath=checkpoint_subpath, checkpoint_minute_interval=checkpoint_minute_interval, + termination_condition=termination_condition, evaluator_interval=evaluator_interval, learning_rate_scheduler_fn=learning_rate_scheduler_fn, ), trainer_fn=trainer_fn, executor_fn=executor_fn, extra_specs=extra_specs, - replay_stabilisation_fn=replay_stabilisation_fn, ) def _get_extra_specs(self) -> Any: - """Helper to establish specs for extra information. 
- + """helper to establish specs for extra information Returns: dictionary containing extra specs """ agents = self._environment_spec.get_agent_ids() core_state_specs = {} - core_message_specs = {} - networks = self._network_factory( # type: ignore environment_spec=self._environment_spec, agent_net_keys=self._agent_net_keys, ) for agent in agents: - agent_type = agent.split("_")[0] + agent_net_key = self._agent_net_keys[agent] core_state_specs[agent] = ( tf2_utils.squeeze_batch_dim( - networks["q_networks"][agent_type].initial_state(1) + networks["values"][agent_net_key].initial_state(1) ), ) - if self._communication_module_fn is not None: - core_message_specs[agent] = ( - tf2_utils.squeeze_batch_dim( - networks["q_networks"][agent_type].initial_message(1) - ), - ) - - extras = { - "core_states": core_state_specs, - "core_messages": core_message_specs, - } - return extras + return {"core_states": core_state_specs} def replay(self) -> Any: - """Replay data storage. - - Returns: - Any: replay data table built according the environment specification. - """ - - return self._builder.make_replay_tables(self._environment_spec) - - def counter(self, checkpoint: bool) -> Any: """Step counter - - Args: - checkpoint (bool): whether to checkpoint the counter. - - Returns: - Any: checkpointing object logging steps in a counter subdirectory. - """ - - if checkpoint: - return tf2_savers.CheckpointingRunner( - counting.Counter(), - time_delta_minutes=self._checkpoint_minute_interval, - directory=self._checkpoint_subpath, - subdirectory="counter", - ) - else: - return counting.Counter() - - def coordinator(self, counter: counting.Counter) -> Any: - """Coordination helper for a distributed program - Args: - counter (counting.Counter): step counter object. - + checkpoint: whether to checkpoint the counter. Returns: - Any: step limiter object. + step counter object. """ + return self._builder.make_replay_tables(self._environment_spec) - return lp_utils.StepsLimiter(counter, self._max_executor_steps) # type: ignore - - def trainer( + def create_system( self, - replay: reverb.Client, - counter: counting.Counter, - ) -> mava.core.Trainer: - """System trainer - - Args: - replay (reverb.Client): replay data table to pull data from. - counter (counting.Counter): step counter object. - - Returns: - mava.core.Trainer: system trainer. - """ - + ) -> Tuple[Dict[str, Dict[str, snt.Module]], Dict[str, Dict[str, snt.Module]]]: + """Initialise the system variables from the network factory.""" # Create the networks to optimize (online) networks = self._network_factory( # type: ignore environment_spec=self._environment_spec, agent_net_keys=self._agent_net_keys, ) - # Create system architecture with target networks. 
- architecture = self._architecture( - environment_spec=self._environment_spec, - value_networks=networks["q_networks"], - agent_net_keys=self._agent_net_keys, - ) - - if self._builder._replay_stabiliser_fn is not None: - architecture = self._builder._replay_stabiliser_fn( # type: ignore - architecture - ) - communication_module = None - if self._communication_module_fn is not None: - communication_module = self._communication_module_fn( - architecture=architecture, - shared=True, - channel_size=1, - channel_noise=0, - ) - system_networks = communication_module.create_system() - else: - system_networks = architecture.create_system() + # architecture args + architecture_config = { + "environment_spec": self._environment_spec, + "observation_networks": networks["observations"], + "value_networks": networks["values"], + "action_selectors": networks["action_selectors"], + "agent_net_keys": self._agent_net_keys, + } - # create logger - trainer_logger_config = {} - if self._logger_config and "trainer" in self._logger_config: - trainer_logger_config = self._logger_config["trainer"] - trainer_logger = self._logger_factory( # type: ignore - "trainer", **trainer_logger_config - ) + system = self._architecture(**architecture_config) + networks = system.create_system() - dataset = self._builder.make_dataset_iterator(replay) - counter = counting.Counter(counter, "trainer") + return networks - return self._builder.make_trainer( - networks=system_networks, - dataset=dataset, - replay_client=replay, - counter=counter, - communication_module=communication_module, - logger=trainer_logger, - ) + def variable_server(self) -> MavaVariableSource: + """Create the variable server.""" + # Create the system + networks = self.create_system() + return self._builder.make_variable_server(networks) def executor( self, executor_id: str, replay: reverb.Client, variable_source: acme.VariableSource, - counter: counting.Counter, - trainer: Optional[ - Union[training.MADQNTrainer, training.MADQNRecurrentTrainer] - ] = None, ) -> mava.ParallelEnvironmentLoop: """System executor - Args: - executor_id (str): id to identify the executor process for logging purposes. - replay (reverb.Client): replay data table to push data to. - variable_source (acme.VariableSource): variable server for updating + executor_id: id to identify the executor process for logging purposes. + replay: replay data table to push data to. + variable_source: variable server for updating network variables. - counter (counting.Counter): step counter object. - trainer (Optional[training.MADQNRecurrentCommTrainer], optional): - system trainer. Defaults to None. - Returns: mava.ParallelEnvironmentLoop: environment-executor loop instance. """ - # Create the behavior policy. - networks = self._network_factory( # type: ignore - environment_spec=self._environment_spec, - agent_net_keys=self._agent_net_keys, - ) - - # Create system architecture with target networks. 
- architecture = self._architecture( - environment_spec=self._environment_spec, - value_networks=networks["q_networks"], - agent_net_keys=self._agent_net_keys, - ) - - if self._builder._replay_stabiliser_fn is not None: - architecture = self._builder._replay_stabiliser_fn( # type: ignore - architecture - ) - - communication_module = None - if self._communication_module_fn is not None: - communication_module = self._communication_module_fn( - architecture=architecture, - shared=True, - channel_size=1, - channel_noise=0, - ) - system_networks = communication_module.create_system() - else: - system_networks = architecture.create_system() + # Create the system + networks = self.create_system() # Create the executor. executor = self._builder.make_executor( - q_networks=system_networks["values"], - action_selectors=networks["action_selectors"], - communication_module=communication_module, - adder=self._builder.make_adder(replay), - variable_source=variable_source, - trainer=trainer, - evaluator=False, + networks=networks, exploration_schedules=self._exploration_scheduler_fn[ f"executor_{executor_id}" ], - seed=self._seed, + adder=self._builder.make_adder(replay), + variable_source=variable_source, + evaluator=False, ) # TODO (Arnu): figure out why factory function are giving type errors # Create the environment. environment = self._environment_factory(evaluation=False) # type: ignore - # Create logger and counter; actors will not spam bigtable. - counter = counting.Counter(counter, "executor") - # Create executor logger executor_logger_config = {} if self._logger_config and "executor" in self._logger_config: @@ -522,7 +530,6 @@ def executor( train_loop = self._train_loop_fn( environment, executor, - counter=counter, logger=exec_logger, **self._train_loop_fn_kwargs, ) @@ -534,72 +541,36 @@ def executor( def evaluator( self, variable_source: acme.VariableSource, - counter: counting.Counter, - trainer: training.MADQNTrainer, + logger: loggers.Logger = None, ) -> Any: - """System evaluator (an executor process not connected to a dataset). - + """System evaluator (an executor process not connected to a dataset) Args: - variable_source : variable server for updating + variable_source: variable server for updating network variables. - counter : step counter object. - trainer : system trainer. - + logger: logger object. Returns: environment-executor evaluation loop instance for evaluating the performance of a system. """ - # Create the behavior policy. - networks = self._network_factory( # type: ignore - environment_spec=self._environment_spec, - agent_net_keys=self._agent_net_keys, - ) - - # Create system architecture with target networks. - architecture = self._architecture( - environment_spec=self._environment_spec, - value_networks=networks["q_networks"], - agent_net_keys=self._agent_net_keys, - ) - - if self._builder._replay_stabiliser_fn is not None: - architecture = self._builder._replay_stabiliser_fn( # type: ignore - architecture - ) - - communication_module = None - if self._communication_module_fn is not None: - communication_module = self._communication_module_fn( - architecture=architecture, - shared=True, - channel_size=1, - channel_noise=0, - ) - system_networks = communication_module.create_system() - else: - system_networks = architecture.create_system() + # Create the system + networks = self.create_system() # Create the agent. 
executor = self._builder.make_executor( - q_networks=system_networks["values"], - action_selectors=networks["action_selectors"], - variable_source=variable_source, - communication_module=communication_module, - trainer=trainer, evaluator=True, + networks=networks, exploration_schedules={ agent: ConstantScheduler(epsilon=0.0) for agent in self._environment_spec.get_agent_ids() }, - seed=self._seed, + variable_source=variable_source, ) # Make the environment. environment = self._environment_factory(evaluation=True) # type: ignore # Create logger and counter. - counter = counting.Counter(counter, "evaluator") evaluator_logger_config = {} if self._logger_config and "evaluator" in self._logger_config: evaluator_logger_config = self._logger_config["evaluator"] @@ -612,7 +583,6 @@ def evaluator( eval_loop = self._eval_loop_fn( environment, executor, - counter=counter, logger=eval_logger, **self._eval_loop_fn_kwargs, ) @@ -620,62 +590,74 @@ def evaluator( eval_loop = DetailedPerAgentStatistics(eval_loop) return eval_loop - def build(self, name: str = "madqn") -> Any: - """Build the distributed system as a graph program. - + def trainer( + self, + trainer_id: str, + replay: reverb.Client, + variable_source: MavaVariableSource, + ) -> mava.core.Trainer: + """System trainer Args: - name (str, optional): system name. Defaults to "madqn". - + trainer_id: Id of the trainer being created. + replay: replay data table to pull data from. + variable_source: variable server for updating + network variables. Returns: - Any: graph program for distributed system training. + system trainer. """ + # create logger + trainer_logger_config = {} + if self._logger_config and "trainer" in self._logger_config: + trainer_logger_config = self._logger_config["trainer"] + trainer_logger = self._logger_factory( # type: ignore + trainer_id, **trainer_logger_config + ) + + # Create the system + networks = self.create_system() + + dataset = self._builder.make_dataset_iterator(replay, trainer_id) + + return self._builder.make_trainer( + networks=networks, + trainer_networks=self._trainer_networks[trainer_id], + trainer_table_entry=self._table_network_config[trainer_id], + dataset=dataset, + logger=trainer_logger, + variable_source=variable_source, + ) + + def build(self, name: str = "maddpg") -> Any: + """Build the distributed system as a graph program. + Args: + name: system name. + Returns: + graph program for distributed system training. + """ program = lp.Program(name=name) with program.group("replay"): replay = program.add_node(lp.ReverbNode(self.replay)) - with program.group("counter"): - counter = program.add_node(lp.CourierNode(self.counter, self._checkpoint)) - - if self._max_executor_steps: - with program.group("coordinator"): - _ = program.add_node(lp.CourierNode(self.coordinator, counter)) + with program.group("variable_server"): + variable_server = program.add_node(lp.CourierNode(self.variable_server)) with program.group("trainer"): - trainer = program.add_node(lp.CourierNode(self.trainer, replay, counter)) + # Add executors which pull round-robin from our variable sources. + for trainer_id in self._trainer_networks.keys(): + program.add_node( + lp.CourierNode(self.trainer, trainer_id, replay, variable_server) + ) with program.group("evaluator"): - program.add_node(lp.CourierNode(self.evaluator, trainer, counter, trainer)) - - if not self._num_caches: - # Use the trainer as a single variable source. - sources = [trainer] - else: - with program.group("cacher"): - # Create a set of trainer caches. 
- sources = [] - for _ in range(self._num_caches): - cacher = program.add_node( - lp.CacherNode( - trainer, refresh_interval_ms=2000, stale_after_ms=4000 - ) - ) - sources.append(cacher) + program.add_node(lp.CourierNode(self.evaluator, variable_server)) with program.group("executor"): # Add executors which pull round-robin from our variable sources. for executor_id in range(self._num_exectors): - source = sources[executor_id % len(sources)] program.add_node( - lp.CourierNode( - self.executor, - executor_id, - replay, - source, - counter, - trainer, - ) + lp.CourierNode(self.executor, executor_id, replay, variable_server) ) return program diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index 1a81bcc89..3b4f46495 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""MADQN system trainer implementation.""" + +"""MADQN trainer implementation.""" import copy import time @@ -25,15 +26,16 @@ import tensorflow as tf import tree import trfl +from acme.tf import losses from acme.tf import utils as tf2_utils -from acme.types import NestedArray -from acme.utils import counting, loggers +from acme.utils import loggers import mava from mava import types as mava_types -from mava.adders import reverb as reverb_adders -from mava.components.tf.modules.communication import BaseCommunicationModule -from mava.systems.tf import savers as tf2_savers +from mava.adders.reverb.base import Trajectory +from mava.components.tf.losses.sequence import recurrent_n_step_critic_loss +from mava.systems.tf.madqn.execution import MADQNFeedForwardExecutor +from mava.systems.tf.variable_utils import VariableClient from mava.utils import training_utils as train_utils from mava.utils.sort_utils import sort_str_num @@ -42,7 +44,7 @@ class MADQNTrainer(mava.Trainer): """MADQN trainer. - This is the trainer component of a MADQN system. IE it takes a dataset as input + This is the trainer component of a MADDPG system. IE it takes a dataset as input and implements update functionality to learn from this dataset. 
""" @@ -50,108 +52,108 @@ def __init__( self, agents: List[str], agent_types: List[str], - q_networks: Dict[str, snt.Module], - target_q_networks: Dict[str, snt.Module], + value_networks: Dict[str, snt.Module], + target_value_networks: Dict[str, snt.Module], + optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], + discount: float, + target_averaging: bool, target_update_period: int, + target_update_rate: float, dataset: tf.data.Dataset, - optimizer: Union[Dict[str, snt.Optimizer], snt.Optimizer], - discount: float, + observation_networks: Dict[str, snt.Module], + target_observation_networks: Dict[str, snt.Module], + variable_client: VariableClient, + counts: Dict[str, Any], agent_net_keys: Dict[str, str], - checkpoint_minute_interval: int, max_gradient_norm: float = None, - importance_sampling_exponent: Optional[float] = None, - replay_client: Optional[reverb.TFClient] = None, - max_priority_weight: float = 0.9, - fingerprint: bool = False, - counter: counting.Counter = None, logger: loggers.Logger = None, - checkpoint: bool = True, - checkpoint_subpath: str = "~/mava/", - replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE, - communication_module: Optional[BaseCommunicationModule] = None, - learning_rate_scheduler_fn: Optional[Callable[[int], None]] = None, + learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Initialise MADQN trainer - + """Initialise MADDPG trainer Args: - agents (List[str]): agent ids, e.g. "agent_0". - agent_types (List[str]): agent types, e.g. "speaker" or "listener". - q_networks (Dict[str, snt.Module]): q-value networks. - target_q_networks (Dict[str, snt.Module]): target q-value networks. - target_update_period (int): number of steps before updating target networks. - dataset (tf.data.Dataset): training dataset. - optimizer (Union[snt.Optimizer, Dict[str, snt.Optimizer]]): type of - optimizer for updating the parameters of the networks. - discount (float): discount factor for TD updates. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. - max_gradient_norm (float, optional): maximum allowed norm for gradients - before clipping is applied. Defaults to None. - fingerprint (bool, optional): whether to apply replay stabilisation using - policy fingerprints. Defaults to False. - counter (counting.Counter, optional): step counter object. Defaults to None. - logger (loggers.Logger, optional): logger object for logging trainer - statistics. Defaults to None. - checkpoint (bool, optional): whether to checkpoint networks. Defaults to - True. - checkpoint_subpath (str, optional): subdirectory for storing checkpoints. - Defaults to "~/mava/". - communication_module (BaseCommunicationModule): module for communication - between agents. Defaults to None. - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. + agents: agent ids, e.g. "agent_0". + agent_types: agent types, e.g. "speaker" or "listener". + policy_networks: policy networks for each agent in + the system. + critic_networks: critic network(s), shared or for + each agent in the system. + target_policy_networks: target policy networks. + target_critic_networks: target critic networks. + policy_optimizer: + optimizer(s) for updating policy networks. + critic_optimizer: + optimizer for updating critic networks. + discount: discount factor for TD updates. 
+ target_averaging: whether to use polyak averaging for target network + updates. + target_update_period: number of steps before target networks are + updated. + target_update_rate: update rate when using averaging. + dataset: training dataset. + observation_networks: network for feature + extraction from raw observation. + target_observation_networks: target observation + network. + variable_client: The client used to manage the variables. + counts: step counter object. + agent_net_keys: specifies what network each agent uses. + max_gradient_norm: maximum allowed norm for gradients + before clipping is applied. + logger: logger object for logging trainer + statistics. + learning_rate_scheduler_fn: dict with two functions (one for the policy and + one for the critic optimizer), that takes in a trainer step t and + returns the current learning rate. """ self._agents = agents self._agent_types = agent_types self._agent_net_keys = agent_net_keys - self._checkpoint = checkpoint + self._variable_client = variable_client self._learning_rate_scheduler_fn = learning_rate_scheduler_fn - # Store online and target q-networks. - self._q_networks = q_networks - self._target_q_networks = target_q_networks + # Setup counts + self._counts = counts + + # Store online and target networks. + self._value_networks = value_networks + self._target_value_networks = target_value_networks + + # Ensure obs and target networks are sonnet modules + self._observation_networks = { + k: tf2_utils.to_sonnet_module(v) for k, v in observation_networks.items() + } + self._target_observation_networks = { + k: tf2_utils.to_sonnet_module(v) + for k, v in target_observation_networks.items() + } # General learner book-keeping and loggers. - self._counter = counter or counting.Counter() - self._logger = logger + self._logger = logger or loggers.make_default_logger("trainer") # Other learner parameters. self._discount = discount + # Set up gradient clipping. if max_gradient_norm is not None: self._max_gradient_norm = tf.convert_to_tensor(max_gradient_norm) else: # A very large number. Infinity results in NaNs. self._max_gradient_norm = tf.convert_to_tensor(1e10) - self._fingerprint = fingerprint - # Necessary to track when to update target networks. - self._num_steps = tf.Variable(0, trainable=False) + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_averaging = target_averaging self._target_update_period = target_update_period + self._target_update_rate = target_update_rate # Create an iterator to go through the dataset. - self._iterator = dataset - - # Importance sampling hyper-parameters - self._max_priority_weight = max_priority_weight - self._importance_sampling_exponent = importance_sampling_exponent + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - # Replay client for updating priorities. - self._replay_client = replay_client - self._replay_table_name = replay_table_name + # Dictionary with unique network keys. + self.unique_net_keys = sort_str_num(self._value_networks.keys()) - # NOTE We make replay_client optional to make changes to MADQN trainer - # compatible with the other systems that inherit from it (VDN, QMIX etc.) - # TODO Include importance sampling in the other systems so that we can remove - # this check. - if self._importance_sampling_exponent is not None: - assert isinstance(self._replay_client, reverb.Client) - - # Dictionary with network keys for each agent. 
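
# NOTE (editor): illustrative sketch of the two target-update modes configured
# by target_averaging, target_update_rate and target_update_period documented
# above. Standalone toy example with dummy variables; names are hypothetical.
import tensorflow as tf

online_vars = [tf.Variable([1.0, 2.0]), tf.Variable([3.0])]
target_vars = [tf.Variable([0.0, 0.0]), tf.Variable([0.0])]

def polyak_update(online, target, tau=0.01):
    # target <- (1 - tau) * target + tau * online, applied every trainer step.
    for src, dest in zip(online, target):
        dest.assign(dest * (1.0 - tau) + src * tau)

def periodic_update(online, target, num_steps, period=200):
    # Hard copy of the online weights every `period` trainer steps.
    if num_steps % period == 0:
        for src, dest in zip(online, target):
            dest.assign(src)
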
- self.unique_net_keys = sort_str_num(self._q_networks.keys()) + # Get the agents which shoud be updated and ran + self._trainer_agent_list = self._agents # Create optimizers for different agent types. if not isinstance(optimizer, dict): @@ -162,206 +164,123 @@ def __init__( self._optimizers = optimizer # Expose the variables. - q_networks_to_expose = {} self._system_network_variables: Dict[str, Dict[str, snt.Module]] = { - "q_network": {}, + "observations": {}, + "values": {}, } for agent_key in self.unique_net_keys: - q_network_to_expose = self._target_q_networks[agent_key] - - q_networks_to_expose[agent_key] = q_network_to_expose - - self._system_network_variables["q_network"][ + self._system_network_variables["observations"][ agent_key - ] = q_network_to_expose.variables - - # Checkpointer - self._system_checkpointer = {} - if checkpoint: - for agent_key in self.unique_net_keys: - - checkpointer = tf2_savers.Checkpointer( - directory=checkpoint_subpath, - time_delta_minutes=checkpoint_minute_interval, - objects_to_save={ - "counter": self._counter, - "q_network": self._q_networks[agent_key], - "target_q_network": self._target_q_networks[agent_key], - "optimizer": self._optimizers, - "num_steps": self._num_steps, - }, - enable_checkpointing=checkpoint, - ) - - self._system_checkpointer[agent_key] = checkpointer + ] = self._target_observation_networks[agent_key].variables + self._system_network_variables["values"][ + agent_key + ] = self._value_networks[agent_key].variables # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. - self._timestamp: Optional[float] = None - def get_trainer_steps(self) -> float: - """get trainer step count - - Returns: - float: number of trainer steps - """ - - return self._num_steps.numpy() - def _update_target_networks(self) -> None: - """Sync the target network parameters with the latest online network - parameters""" - + """Update the target networks using either target averaging or + by directy copying the weights of the online networks every few steps.""" for key in self.unique_net_keys: # Update target network. - online_variables = (*self._q_networks[key].variables,) - - target_variables = (*self._target_q_networks[key].variables,) + online_variables = ( + *self._observation_networks[key].variables, + *self._value_networks[key].variables, + ) + target_variables = ( + *self._target_observation_networks[key].variables, + *self._target_value_networks[key].variables, + ) - # Make online -> target network update ops. - if tf.math.mod(self._num_steps, self._target_update_period) == 0: + if self._target_averaging: + assert 0.0 < self._target_update_rate < 1.0 + tau = self._target_update_rate for src, dest in zip(online_variables, target_variables): - dest.assign(src) + dest.assign(dest * (1.0 - tau) + src * tau) + else: + # Make online -> target network update ops. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(online_variables, target_variables): + dest.assign(src) self._num_steps.assign_add(1) - def _update_sample_priorities(self, keys: tf.Tensor, priorities: tf.Tensor) -> None: - """Update sample priorities in replay table using importance weights. 
+    def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray]]:
+        """Deprecated."""
+        pass
+
+    def _transform_observations(
+        self, obs: Dict[str, mava_types.OLT], next_obs: Dict[str, mava_types.OLT]
+    ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
+        """Transform the observations using the observation networks of each agent.
 
         Args:
-            keys (tf.Tensor): Keys of the replay samples.
-            priorities (tf.Tensor): New priorities for replay samples.
+            obs: observations at timestep t-1
+            next_obs: observations at timestep t
+        Returns:
+            Transformed observations
         """
-        # Maybe update the sample priorities in the replay buffer.
-        if (
-            self._importance_sampling_exponent is not None
-            and self._replay_client is not None
-        ):
-            self._replay_client.mutate_priorities(
-                table=self._replay_table_name,
-                updates=dict(zip(keys.numpy(), priorities.numpy())),
+        o_tm1 = {}
+        o_t = {}
+        for agent in self._agents:
+            agent_key = self._agent_net_keys[agent]
+            o_tm1[agent] = self._observation_networks[agent_key](obs[agent].observation)
+            o_t[agent] = self._target_observation_networks[agent_key](
+                next_obs[agent].observation
             )
+            # This stop_gradient prevents gradients from propagating into the target
+            # observation network. In addition, since the online policy network is
+            # evaluated at o_t, this also means the policy loss does not influence
+            # the observation network training.
+            o_t[agent] = tree.map_structure(tf.stop_gradient, o_t[agent])
+        return o_tm1, o_t
 
-    def _get_feed(
+    @tf.function
+    def _step(
         self,
-        o_tm1_trans: Dict[str, mava_types.OLT],
-        o_t_trans: Dict[str, mava_types.OLT],
-        a_tm1: Dict[str, np.ndarray],
-        agent: str,
-    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
-        """get data to feed to the agent networks
-
-        Args:
-            o_tm1_trans (Dict[str, np.ndarray]): transformed (e.g. using observation
-                network) observation at timestep t-1
-            o_t_trans (Dict[str, np.ndarray]): transformed observation at timestep t
-            a_tm1 (Dict[str, np.ndarray]): action at timestep t-1
-            agent (str): agent id
+    ) -> Dict[str, Dict[str, Any]]:
+        """Trainer forward and backward passes.
 
         Returns:
-            Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: agent network feeds, observations
-                at t-1, t and action at time t.
+            losses
         """
-        # Decentralised
-        o_tm1_feed = o_tm1_trans[agent].observation
-        o_t_feed = o_t_trans[agent].observation
-        a_tm1_feed = a_tm1[agent]
-
-        return o_tm1_feed, o_t_feed, a_tm1_feed
-
-    def step(self) -> None:
-        """trainer step to update the parameters of the agents in the system"""
-
-        # Run the learning step.
-        fetches = self._step()
-
-        # Compute elapsed time.
-        timestamp = time.time()
-        if self._timestamp:
-            elapsed_time = timestamp - self._timestamp
-        else:
-            elapsed_time = 0
-        self._timestamp = timestamp  # type: ignore
-
-        # Update our counts and record it.
-        counts = self._counter.increment(steps=1, walltime=elapsed_time)
-        fetches.update(counts)
-
-        # Checkpoint and attempt to write the logs.
-        if self._checkpoint:
-            train_utils.checkpoint_networks(self._system_checkpointer)
-
-        if self._logger:
-            self._logger.write(fetches)
-
-    @tf.function
-    def _forward_backward(self) -> Tuple:
-        # Get data from replay (dropping extras if any). Note there is no
-        # extra data here because we do not insert any into Reverb.
-        inputs = next(self._iterator)
-
-        self._forward(inputs)
-
-        self._backward()
-
-        extras = {}
-
-        if self._importance_sampling_exponent is not None:
-            extras.update(
-                {"keys": self._sample_keys, "priorities": self._sample_priorities}
-            )
-
-        # Return Q-value losses.
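
# NOTE (editor): minimal illustration of the stop_gradient trick used in
# _transform_observations above: gradients flow through the online observation
# network but not into the target observation network.
import tensorflow as tf

x = tf.Variable(2.0)
with tf.GradientTape() as tape:
    y = tf.stop_gradient(x * 3.0) + x
# Only the un-stopped branch contributes, so dy/dx == 1.0.
assert tape.gradient(y, x).numpy() == 1.0
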
- fetches = self._q_network_losses - - return fetches, extras - - @tf.function - def _step(self) -> Dict: - """Trainer forward and backward passes.""" - # Update the target networks self._update_target_networks() - fetches, extras = self._forward_backward() + # Draw a batch of data from replay. + sample: reverb.ReplaySample = next(self._iterator) + + self._forward(sample) - # Maybe update priorities. - # NOTE _update_sample_priorities must happen outside of - # tf.function. That is why we seperate out forward_backward(). - if self._importance_sampling_exponent is not None: - self._update_sample_priorities(extras["keys"], extras["priorities"]) + self._backward() - # Log losses - return fetches + # Log losses per agent + return train_utils.map_losses_per_agent_value( + self.value_losses + ) + # Forward pass that calculates loss. def _forward(self, inputs: reverb.ReplaySample) -> None: """Trainer forward pass - Args: - inputs (Any): input data from the data table (transitions) + inputs: input data from the data table (transitions) """ - # Get info about the samples from reverb. - sample_info = inputs.info - sample_keys = tf.transpose(inputs.info.key) - sample_probs = tf.transpose(sample_info.probability) - - # Initialize sample priorities at zero. - sample_priorities = np.zeros(len(inputs.info.key)) - # Unpack input data as follows: # o_tm1 = dictionary of observations one for each agent # a_tm1 = dictionary of actions taken from obs in o_tm1 + # e_tm1 [Optional] = extra data for timestep t-1 + # that the agents persist in replay. # r_t = dictionary of rewards or rewards sequences # (if using N step transitions) ensuing from actions a_tm1 # d_t = environment discount ensuing from actions a_tm1. # This discount is applied to future rewards after r_t. # o_t = dictionary of next observations or next observation sequences - # e_t [Optional] = extra data that the agents persist in replay. + # e_t [Optional] = extra data for timestep t that the agents persist in replay. trans = mava_types.Transition(*inputs.data) - o_tm1, o_t, a_tm1, r_t, d_t, e_tm1, e_t = ( trans.observations, trans.next_observations, @@ -372,124 +291,72 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: trans.next_extras, ) + self.value_losses = {} + # Do forward passes through the networks and calculate the losses with tf.GradientTape(persistent=True) as tape: - q_network_losses: Dict[str, NestedArray] = {} - for agent in self._agents: + o_tm1_trans, o_t_trans = self._transform_observations(o_tm1, o_t) + for agent in self._trainer_agent_list: agent_key = self._agent_net_keys[agent] - # Cast the additional discount to match the environment discount dtype. - discount = tf.cast(self._discount, dtype=d_t[agent].dtype) + # Double Q-learning + q_tm1 = self._value_networks[agent_key](o_tm1_trans[agent]) + q_t_value = self._target_value_networks[agent_key](o_t_trans[agent]) + q_t_selector = self._value_networks[agent_key](o_t_trans[agent]) - # Maybe transform the observation before feeding into policy and critic. - # Transforming the observations this way at the start of the learning - # step effectively means that the policy and critic share observation - # network weights. 
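
# NOTE (editor): small self-contained example of the double Q-learning rule the
# value loss below implements via trfl.double_qlearning: the online network
# (q_t_selector) picks the greedy action and the target network (q_t_value)
# evaluates it, i.e. target = r_t + discount * d_t * q_t_value[argmax q_t_selector].
# Shapes and numbers are illustrative only.
import tensorflow as tf
import trfl

q_tm1 = tf.constant([[1.0, 2.0, 3.0]])         # online Q(o_tm1, .)
a_tm1 = tf.constant([2])                       # action taken at t-1
r_t = tf.constant([1.0])                       # reward
pcont_t = tf.constant([0.99])                  # discount * d_t
q_t_value = tf.constant([[1.5, 0.5, 0.0]])     # target network at o_t
q_t_selector = tf.constant([[1.0, 2.0, 0.0]])  # online network at o_t
loss, extra = trfl.double_qlearning(
    q_tm1, a_tm1, r_t, pcont_t, q_t_value, q_t_selector
)
# extra.td_error holds the per-transition TD error used for the squared loss.
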
+ # TODO Legal action masking + # q_t_selector = tf.where(o_t[agent].legal_actions, q_t_selector, -999999999) - o_tm1_feed, o_t_feed, a_tm1_feed = self._get_feed( - o_tm1, o_t, a_tm1, agent - ) + # pcont + discount = tf.cast(self._discount, dtype=d_t[agent].dtype) - if self._fingerprint: - f_tm1 = e_tm1["fingerprint"] - f_tm1 = tf.convert_to_tensor(f_tm1) - f_tm1 = tf.cast(f_tm1, "float32") - - f_t = e_t["fingerprint"] - f_t = tf.convert_to_tensor(f_t) - f_t = tf.cast(f_t, "float32") - - q_tm1 = self._q_networks[agent_key](o_tm1_feed, f_tm1) - q_t_value = self._target_q_networks[agent_key](o_t_feed, f_t) - q_t_selector = self._q_networks[agent_key](o_t_feed, f_t) - else: - q_tm1 = self._q_networks[agent_key](o_tm1_feed) - q_t_value = self._target_q_networks[agent_key](o_t_feed) - q_t_selector = self._q_networks[agent_key](o_t_feed) - - # Q-network learning - loss, loss_extras = trfl.double_qlearning( - q_tm1, - a_tm1_feed, - r_t[agent], - discount * d_t[agent], - q_t_value, - q_t_selector, + # Value loss. + value_loss, _ = trfl.double_qlearning( + q_tm1, a_tm1[agent], r_t[agent], discount * d_t[agent], q_t_value, q_t_selector ) - # Maybe do importance sampling. - if self._importance_sampling_exponent is not None: - importance_weights = 1.0 / sample_probs # [B] - importance_weights **= self._importance_sampling_exponent - importance_weights /= tf.reduce_max(importance_weights) - - # Reweight loss. - loss *= tf.cast(importance_weights, loss.dtype) # [B] - - # Update priorities. - errors = loss_extras.td_error - abs_errors = tf.abs(errors) - mean_priority = tf.reduce_mean(abs_errors, axis=0) - max_priority = tf.reduce_max(abs_errors, axis=0) - sample_priorities += ( - self._max_priority_weight * max_priority - + (1 - self._max_priority_weight) * mean_priority - ) - - loss = tf.reduce_mean(loss) - q_network_losses[agent] = {"policy_loss": loss} - - # Store losses and tape - self._q_network_losses = q_network_losses - self.tape = tape + self.value_losses[agent] = tf.reduce_mean(value_loss, axis=0) - # Store sample keys and priorities - self._sample_keys = sample_keys - self._sample_priorities = sample_priorities / len( - self._agents - ) # averaged over agents. + self.tape = tape + # Backward pass that calculates gradients and updates network. def _backward(self) -> None: """Trainer backward pass updating network parameters""" - q_network_losses = self._q_network_losses + # Calculate the gradients and update the networks + value_losses = self.value_losses tape = self.tape - for agent in self._agents: + for agent in self._trainer_agent_list: agent_key = self._agent_net_keys[agent] - # Get trainable variables - q_network_variables = self._q_networks[agent_key].trainable_variables + # Get trainable variables. + variables = ( + self._observation_networks[agent_key].trainable_variables + + self._value_networks[agent_key].trainable_variables + ) + - # Compute gradients - gradients = tape.gradient(q_network_losses[agent], q_network_variables) + # Compute gradients. + # Note: Warning "WARNING:tensorflow:Calling GradientTape.gradient + # on a persistent tape inside its context is significantly less efficient + # than calling it outside the context." caused by losses.dpg, which calls + # tape.gradient. + gradients = tape.gradient(value_losses[agent], variables) - # Clip gradients. - gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0] + # Maybe clip gradients. + gradients = tf.clip_by_global_norm( + gradients, self._max_gradient_norm + )[0] # Apply gradients. 
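
# NOTE (editor): standalone illustration of the global-norm clipping used above.
# All gradients are rescaled jointly so that their combined L2 norm does not
# exceed the clip value.
import tensorflow as tf

grads = [tf.constant([3.0, 4.0]), tf.constant([12.0])]   # global norm = 13
clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=5.0)
# Each gradient is scaled by 5 / 13; global_norm evaluates to 13.0.
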
- self._optimizers[agent_key].apply(gradients, q_network_variables) + self._optimizers[agent_key].apply(gradients, variables) train_utils.safe_del(self, "tape") - def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray]]: - """get network variables - - Args: - names (Sequence[str]): network names - - Returns: - Dict[str, Dict[str, np.ndarray]]: network variables - """ + def step(self) -> None: + """trainer step to update the parameters of the agents in the system""" - variables: Dict[str, Dict[str, np.ndarray]] = {} - for network_type in names: - variables[network_type] = { - agent: tf2_utils.to_numpy( - self._system_network_variables[network_type][agent] - ) - for agent in self.unique_net_keys - } - return variables + raise NotImplementedError("A trainer statistics wrapper should overwrite this.") def after_trainer_step(self) -> None: """Optionally decay lr after every training step.""" @@ -506,7 +373,6 @@ def after_trainer_step(self) -> None: def _decay_lr(self, trainer_step: int) -> None: """Decay lr. - Args: trainer_step : trainer step time t. """ @@ -515,8 +381,9 @@ def _decay_lr(self, trainer_step: int) -> None: ) -class MADQNRecurrentTrainer(MADQNTrainer): +class MADQNRecurrentTrainer: """Recurrent MADQN trainer. + This is the trainer component of a MADQN system. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ @@ -525,234 +392,249 @@ def __init__( self, agents: List[str], agent_types: List[str], - q_networks: Dict[str, snt.Module], - target_q_networks: Dict[str, snt.Module], - target_update_period: int, - dataset: tf.data.Dataset, + value_networks: Dict[str, snt.Module], + target_value_networks: Dict[str, snt.Module], optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], discount: float, + target_averaging: bool, + target_update_period: int, + target_update_rate: float, + dataset: tf.data.Dataset, + observation_networks: Dict[str, snt.Module], + target_observation_networks: Dict[str, snt.Module], + variable_client: VariableClient, + counts: Dict[str, Any], agent_net_keys: Dict[str, str], - checkpoint_minute_interval: int, max_gradient_norm: float = None, - counter: counting.Counter = None, logger: loggers.Logger = None, - fingerprint: bool = False, - checkpoint: bool = True, - checkpoint_subpath: str = "~/mava/", - communication_module: Optional[BaseCommunicationModule] = None, - learning_rate_scheduler_fn: Optional[Callable[[int], None]] = None, + #bootstrap_n: int = 10, + learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Initialise recurrent MADQN trainer - + """Initialise Recurrent MADDPG trainer Args: - agents (List[str]): agent ids, e.g. "agent_0". - agent_types (List[str]): agent types, e.g. "speaker" or "listener". - q_networks (Dict[str, snt.Module]): q-value networks. - target_q_networks (Dict[str, snt.Module]): target q-value networks. - target_update_period (int): number of steps before updating target networks. - dataset (tf.data.Dataset): training dataset. - optimizer (Union[snt.Optimizer, Dict[str, snt.Optimizer]]): type of - optimizer for updating the parameters of the networks. - discount (float): discount factor for TD updates. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. - max_gradient_norm (float, optional): maximum allowed norm for gradients - before clipping is applied. Defaults to None. 
- counter (counting.Counter, optional): step counter object. Defaults to None. - logger (loggers.Logger, optional): logger object for logging trainer - statistics. Defaults to None. - fingerprint (bool, optional): whether to apply replay stabilisation using - policy fingerprints. Defaults to False. - checkpoint (bool, optional): whether to checkpoint networks. Defaults to - True. - checkpoint_subpath (str, optional): subdirectory for storing checkpoints. - Defaults to "~/mava/". - communication_module (BaseCommunicationModule): module for communication - between agents. Defaults to None. - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. + agents: agent ids, e.g. "agent_0". + agent_types: agent types, e.g. "speaker" or "listener". + policy_networks: policy networks for each agent in + the system. + critic_networks: critic network(s), shared or for + each agent in the system. + target_policy_networks: target policy networks. + target_critic_networks: target critic networks. + policy_optimizer: + optimizer(s) for updating policy networks. + critic_optimizer: + optimizer for updating critic networks. + discount: discount factor for TD updates. + target_averaging: whether to use polyak averaging for target network + updates. + target_update_period: number of steps before target networks are + updated. + target_update_rate: update rate when using averaging. + dataset: training dataset. + observation_networks: network for feature + extraction from raw observation. + target_observation_networks: target observation + network. + variable_client: The client used to manage the variables. + counts: step counter object. + agent_net_keys: specifies what network each agent uses. + max_gradient_norm: maximum allowed norm for gradients + before clipping is applied. + logger: logger object for logging trainer + statistics. + learning_rate_scheduler_fn: dict with two functions (one for the policy and + one for the critic optimizer), that takes in a trainer step t and + returns the current learning rate. """ + #self._bootstrap_n = bootstrap_n - super().__init__( - agents=agents, - agent_types=agent_types, - q_networks=q_networks, - target_q_networks=target_q_networks, - target_update_period=target_update_period, - dataset=dataset, - optimizer=optimizer, - discount=discount, - agent_net_keys=agent_net_keys, - checkpoint_minute_interval=checkpoint_minute_interval, - max_gradient_norm=max_gradient_norm, - counter=counter, - logger=logger, - fingerprint=fingerprint, - checkpoint=checkpoint, - checkpoint_subpath=checkpoint_subpath, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, - ) + self._agents = agents + self._agent_type = agent_types + self._agent_net_keys = agent_net_keys + self._variable_client = variable_client + self._learning_rate_scheduler_fn = learning_rate_scheduler_fn - def _forward(self, inputs: Any) -> None: - """Trainer forward pass + # Setup counts + self._counts = counts - Args: - inputs (Any): input data from the data table (transitions) - """ + # Store online and target networks. 
+ self._value_networks = value_networks + self._target_value_networks = target_value_networks + + # Ensure obs and target networks are sonnet modules + self._observation_networks = { + k: tf2_utils.to_sonnet_module(v) for k, v in observation_networks.items() + } + self._target_observation_networks = { + k: tf2_utils.to_sonnet_module(v) + for k, v in target_observation_networks.items() + } - data = tree.map_structure( - lambda v: tf.expand_dims(v, axis=0) if len(v.shape) <= 1 else v, inputs.data - ) - data = tf2_utils.batch_to_sequence(data) + # General learner book-keeping and loggers. + self._logger = logger or loggers.make_default_logger("trainer") - observations, actions, rewards, discounts, _, _ = ( - data.observations, - data.actions, - data.rewards, - data.discounts, - data.start_of_episode, - data.extras, - ) + # Other learner parameters. + self._discount = discount - # Using extra directly from inputs due to shape. - core_state = tree.map_structure( - lambda s: s[:, 0, :], inputs.data.extras["core_states"] - ) + # Set up gradient clipping. + if max_gradient_norm is not None: + self._max_gradient_norm = tf.convert_to_tensor(max_gradient_norm) + else: # A very large number. Infinity results in NaNs. + self._max_gradient_norm = tf.convert_to_tensor(1e10) - with tf.GradientTape(persistent=True) as tape: - q_network_losses: Dict[str, NestedArray] = {} + # Necessary to track when to update target networks. + self._num_steps = tf.Variable(0, dtype=tf.int32) + self._target_averaging = target_averaging + self._target_update_period = target_update_period + self._target_update_rate = target_update_rate - for agent in self._agents: - agent_key = self._agent_net_keys[agent] - # Cast the additional discount to match the environment discount dtype. - discount = tf.cast(self._discount, dtype=discounts[agent][0].dtype) + # Create an iterator to go through the dataset. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - q, s = snt.static_unroll( - self._q_networks[agent_key], - observations[agent].observation, - core_state[agent][0], - ) + # Dictionary with unique network keys. + self.unique_net_keys = sort_str_num(self._value_networks.keys()) - q_targ, s = snt.static_unroll( - self._target_q_networks[agent_key], - observations[agent].observation, - core_state[agent][0], - ) + # Get the agents which shoud be updated and ran + self._trainer_agent_list = self._agents - q_network_losses[agent] = {"policy_loss": tf.zeros(())} - for t in range(1, q.shape[0]): - loss, _ = trfl.qlearning( - q[t - 1], - actions[agent][t - 1], - rewards[agent][t], - discount * discounts[agent][t], - q_targ[t], - ) + # Create optimizers for different agent types. + if not isinstance(optimizer, dict): + self._optimizers: Dict[str, snt.Optimizer] = {} + for agent in self.unique_net_keys: + self._optimizers[agent] = copy.deepcopy(optimizer) + else: + self._optimizers = optimizer - loss = tf.reduce_mean(loss) - q_network_losses[agent]["policy_loss"] += loss + # Expose the variables. + self._system_network_variables: Dict[str, Dict[str, snt.Module]] = { + "observations": {}, + "values": {}, + } + for agent_key in self.unique_net_keys: + self._system_network_variables["observations"][ + agent_key + ] = self._target_observation_networks[agent_key].variables + self._system_network_variables["values"][ + agent_key + ] = self._target_value_networks[agent_key].variables - self._q_network_losses = q_network_losses - self.tape = tape + # Do not record timestamps until after the first learning step is done. 
+ # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. + self._timestamp: Optional[float] = None + def _transform_observations( + self, observations: Dict[str, mava_types.OLT] + ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: + """apply the observation networks to the raw observations from the dataset + Args: + obs: raw agent observations + next_obs: raw next observations + Returns: + transformed + observations (features) + """ -class MADQNRecurrentCommTrainer(MADQNTrainer): - """Recurrent MADQN trainer with communication. - This is the trainer component of a MADQN system. IE it takes a dataset as input - and implements update functionality to learn from this dataset. - """ + # Note (dries): We are assuming that only the policy network + # is recurrent and not the observation network. + obs_trans = {} + obs_target_trans = {} + for agent in self._agents: + agent_key = self._agent_net_keys[agent] - def __init__( + reshaped_obs, dims = train_utils.combine_dim( + observations[agent].observation + ) + + obs_trans[agent] = train_utils.extract_dim( + self._observation_networks[agent_key](reshaped_obs), dims + ) + + obs_target_trans[agent] = train_utils.extract_dim( + self._target_observation_networks[agent_key](reshaped_obs), + dims, + ) + + # This stop_gradient prevents gradients to propagate into the target + # observation network. In addition, since the online policy network is + # evaluated at o_t, this also means the policy loss does not influence + # the observation network training. + obs_target_trans[agent] = tree.map_structure( + tf.stop_gradient, obs_target_trans[agent] + ) + return obs_trans, obs_target_trans + + def _update_target_networks(self) -> None: + """Update the target networks using either target averaging or + by directy copying the weights of the online networks every few steps.""" + for key in self.unique_net_keys: + # Update target network. + online_variables = ( + *self._observation_networks[key].variables, + *self._value_networks[key].variables, + ) + target_variables = ( + *self._target_observation_networks[key].variables, + *self._target_value_networks[key].variables, + ) + + if self._target_averaging: + assert 0.0 < self._target_update_rate < 1.0 + tau = self._target_update_rate + for src, dest in zip(online_variables, target_variables): + dest.assign(dest * (1.0 - tau) + src * tau) + else: + # Make online -> target network update ops. 
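
# NOTE (editor): minimal sketch of the [T, B] flattening that the recurrent
# _transform_observations above relies on (combine_dim / extract_dim live in
# mava.utils.training_utils; this is only an illustrative equivalent).
import tensorflow as tf

def combine_dim(x: tf.Tensor):
    dims = x.shape[:2]
    return tf.reshape(x, [-1, *x.shape[2:]]), dims

def extract_dim(x: tf.Tensor, dims):
    return tf.reshape(x, [*dims, *x.shape[1:]])

obs = tf.ones([20, 32, 48])          # [T, B, obs_dim]
flat, dims = combine_dim(obs)        # [T * B, obs_dim]
restored = extract_dim(flat, dims)   # back to [T, B, obs_dim]
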
+ if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(online_variables, target_variables): + dest.assign(src) + self._num_steps.assign_add(1) + + def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray]]: + """Depricated""" + pass + + @tf.function + def _step( self, - agents: List[str], - agent_types: List[str], - q_networks: Dict[str, snt.Module], - target_q_networks: Dict[str, snt.Module], - target_update_period: int, - dataset: tf.data.Dataset, - optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], - discount: float, - agent_net_keys: Dict[str, str], - checkpoint_minute_interval: int, - communication_module: BaseCommunicationModule, - max_gradient_norm: float = None, - fingerprint: bool = False, - counter: counting.Counter = None, - logger: loggers.Logger = None, - checkpoint: bool = True, - checkpoint_subpath: str = "~/mava/", - learning_rate_scheduler_fn: Optional[Callable[[int], None]] = None, - ): - """Initialise recurrent MADQN trainer with communication + ) -> Dict[str, Dict[str, Any]]: + """Trainer forward and backward passes. - Args: - agents (List[str]): agent ids, e.g. "agent_0". - agent_types (List[str]): agent types, e.g. "speaker" or "listener". - q_networks (Dict[str, snt.Module]): q-value networks. - target_q_networks (Dict[str, snt.Module]): target q-value networks. - target_update_period (int): number of steps before updating target networks. - dataset (tf.data.Dataset): training dataset. - optimizer (Union[snt.Optimizer, Dict[str, snt.Optimizer]]): type of - optimizer for updating the parameters of the networks. - discount (float): discount factor for TD updates. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. - communication_module (BaseCommunicationModule): module for communication - between agents. - max_gradient_norm (float, optional): maximum allowed norm for gradients - before clipping is applied. Defaults to None. - fingerprint (bool, optional): whether to apply replay stabilisation using - policy fingerprints. Defaults to False. - counter (counting.Counter, optional): step counter object. Defaults to None. - logger (loggers.Logger, optional): logger object for logging trainer - statistics. Defaults to None. - checkpoint (bool, optional): whether to checkpoint networks. Defaults to - True. - checkpoint_subpath (str, optional): subdirectory for storing checkpoints. - Defaults to "~/mava/". - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. + Returns: + losses """ - super().__init__( - agents=agents, - agent_types=agent_types, - q_networks=q_networks, - target_q_networks=target_q_networks, - target_update_period=target_update_period, - dataset=dataset, - optimizer=optimizer, - discount=discount, - agent_net_keys=agent_net_keys, - checkpoint_minute_interval=checkpoint_minute_interval, - max_gradient_norm=max_gradient_norm, - fingerprint=fingerprint, - counter=counter, - logger=logger, - checkpoint=checkpoint, - checkpoint_subpath=checkpoint_subpath, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, - ) + # Update the target networks + self._update_target_networks() - self._communication_module = communication_module + # Draw a batch of data from replay. 
+ sample: reverb.ReplaySample = next(self._iterator) - def _forward(self, inputs: Any) -> None: - """Trainer forward pass + self._forward(sample) + + self._backward() + # Log losses per agent + return train_utils.map_losses_per_agent_value( + self.value_losses + ) + + # Forward pass that calculates loss. + def _forward(self, inputs: reverb.ReplaySample) -> None: + """Trainer forward pass Args: - inputs (Any): input data from the data table (transitions) + inputs: input data from the data table (transitions) """ - + # Convert to time major data = tree.map_structure( lambda v: tf.expand_dims(v, axis=0) if len(v.shape) <= 1 else v, inputs.data ) data = tf2_utils.batch_to_sequence(data) - observations, actions, rewards, discounts, _, _ = ( + print(data) + + # Note (dries): The unused variable is start_of_episodes. + observations, actions, rewards, discounts, _, extras = ( data.observations, data.actions, data.rewards, @@ -761,79 +643,132 @@ def _forward(self, inputs: Any) -> None: data.extras, ) - # Using extra directly from inputs due to shape. - core_state = tree.map_structure( - lambda s: s[:, 0, :], inputs.data.extras["core_states"] - ) - core_message = tree.map_structure( - lambda s: s[:, 0, :], inputs.data.extras["core_messages"] - ) - + # Get initial state for the LSTM from replay and + # extract the first state in the sequence. + # NOTE maybe indexing the wrong thing here?! + # core_state = tree.map_structure(lambda s: s[:, 0, :], extras["core_states"]) + # target_core_state = tree.map_structure(tf.identity, core_state) + core_state = tree.map_structure(lambda s: s[0, :, :], extras["core_states"]) + target_core_state = tree.map_structure(lambda s: s[1, :, :], extras["core_states"]) + + # TODO (dries): Take out all the data_points that does not need + # to be processed here at the start. Therefore it does not have + # to be done later on and saves processing time. + + self.value_losses: Dict[str, tf.Tensor] = {} + + # Do forward passes through the networks and calculate the losses with tf.GradientTape(persistent=True) as tape: - q_network_losses: Dict[str, NestedArray] = { - agent: {"policy_loss": tf.zeros(())} for agent in self._agents - } - - T = actions[self._agents[0]].shape[0] - - state = {agent: core_state[agent][0] for agent in self._agents} - target_state = {agent: core_state[agent][0] for agent in self._agents} - - message = {agent: core_message[agent][0] for agent in self._agents} - target_message = {agent: core_message[agent][0] for agent in self._agents} + # Note (dries): We are assuming that only the policy network + # is recurrent and not the observation network. 
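
# NOTE (editor): sketch of the layout conversion at the top of _forward above.
# Replay sequences arrive batch-major ([B, T, ...]); batch_to_sequence
# transposes them to time-major ([T, B, ...]) and the stored recurrent state at
# the first timestep is used to start the unroll. Shapes are illustrative.
import tensorflow as tf

core_states = tf.ones([32, 20, 64])                     # [B, T, hidden]
time_major = tf.transpose(core_states, perm=[1, 0, 2])  # [T, B, hidden]
initial_state = time_major[0, :, :]                     # state at t=0 -> [B, hidden]
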
+ obs_trans, target_obs_trans = self._transform_observations(observations) - # _target_q_networks must be 1 step ahead - target_channel = self._communication_module.process_messages(target_message) for agent in self._agents: agent_key = self._agent_net_keys[agent] - (q_targ, m), s = self._target_q_networks[agent_key]( - observations[agent].observation[0], - target_state[agent], - target_channel[agent], + + # Double Q-learning + q, _ = snt.static_unroll( + self._value_networks[agent_key], obs_trans, core_state[agent][0] + ) + q_tm1 = q[:-1] + q_t_selector = q[1:] + q_t_value, _ = snt.static_unroll( + self._target_value_networks[agent_key], target_obs_trans[1:], target_core_state[agent][0] ) - target_state[agent] = s - target_message[agent] = m - for t in range(1, T, 1): - channel = self._communication_module.process_messages(message) - target_channel = self._communication_module.process_messages( - target_message + # TODO Legal action masking + # q_t_selector = tf.where(observations[agent].legal_actions, q_t_selector, -999999999) + + # Cast the additional discount to match + # the environment discount dtype. + discount = tf.cast(self._discount, dtype=discounts[agent].dtype) + + # Flatten out time and batch dim + q_tm1, dims = train_utils.combine_dim( + q_tm1 ) + q_t_selector, _ = train_utils.combine_dim( + q_t_selector + ) + q_t_value, _ = train_utils.combine_dim( + q_t_value + ) + a_tm1, _ = train_utils.combine_dim( + actions[agent][:-1] + ) + r_t, _ = train_utils.combine_dim( + rewards[agent][:-1] + ) + d_t, _ = train_utils.combine_dim( + discounts[agent][:-1] + ) + + # Value loss + value_loss, _ = trfl.double_qlearning(q_tm1, a_tm1, r_t, discount * d_t, q_t_value, q_t_selector) + + # TODO zero padding mask + + self.value_losses[agent] = tf.reduce_mean(value_loss, axis=0) - for agent in self._agents: - agent_key = self._agent_net_keys[agent] - - # Cast the additional discount - # to match the environment discount dtype. - - discount = tf.cast(self._discount, dtype=discounts[agent][0].dtype) - - (q_targ, m), s = self._target_q_networks[agent_key]( - observations[agent].observation[t], - target_state[agent], - target_channel[agent], - ) - target_state[agent] = s - target_message[agent] = m - - (q, m), s = self._q_networks[agent_key]( - observations[agent].observation[t - 1], - state[agent], - channel[agent], - ) - state[agent] = s - message[agent] = m - - loss, _ = trfl.qlearning( - q, - actions[agent][t - 1], - rewards[agent][t - 1], - discount * discounts[agent][t], - q_targ, - ) - - loss = tf.reduce_mean(loss) - q_network_losses[agent]["policy_loss"] += loss - - self._q_network_losses = q_network_losses self.tape = tape + + # Backward pass that calculates gradients and updates network. + def _backward(self) -> None: + """Trainer backward pass updating network parameters""" + + # Calculate the gradients and update the networks + value_losses = self.value_losses + tape = self.tape + for agent in self._trainer_agent_list: + agent_key = self._agent_net_keys[agent] + + # Get trainable variables. + variables = ( + self._observation_networks[agent_key].trainable_variables + + self._value_networks[agent_key].trainable_variables + ) + + + # Compute gradients. + # Note: Warning "WARNING:tensorflow:Calling GradientTape.gradient + # on a persistent tape inside its context is significantly less efficient + # than calling it outside the context." caused by losses.dpg, which calls + # tape.gradient. + gradients = tape.gradient(value_losses[agent], variables) + + # Maybe clip gradients. 
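
# NOTE (editor): toy version of the sequence alignment in the recurrent forward
# pass above: the value RNN is unrolled over all T steps, and predictions q[:-1]
# are trained towards targets built from q[1:] of the target network.
# Hypothetical small network; shapes are illustrative only.
import sonnet as snt
import tensorflow as tf

T, B, obs_dim, num_actions = 10, 4, 16, 5
core, head = snt.LSTM(32), snt.Linear(num_actions)
obs_seq = tf.ones([T, B, obs_dim])
outputs, _ = snt.static_unroll(core, obs_seq, core.initial_state(B))
q = snt.BatchApply(head)(outputs)   # [T, B, num_actions]
q_tm1 = q[:-1]                      # predictions for steps 0 .. T-2
q_t = q[1:]                         # bootstrap values for steps 1 .. T-1
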
+ gradients = tf.clip_by_global_norm( + gradients, self._max_gradient_norm + )[0] + + # Apply gradients. + self._optimizers[agent_key].apply(gradients, variables) + + train_utils.safe_del(self, "tape") + + def step(self) -> None: + """trainer step to update the parameters of the agents in the system""" + + raise NotImplementedError("A trainer statistics wrapper should overwrite this.") + + def after_trainer_step(self) -> None: + """Optionally decay lr after every training step.""" + if self._learning_rate_scheduler_fn: + self._decay_lr(self._num_steps) + info: Dict[str, Dict[str, float]] = {} + for agent in self._agents: + info[agent] = {} + info[agent]["learning_rate"] = self._optimizers[ + self._agent_net_keys[agent] + ].learning_rate + if self._logger: + self._logger.write(info) + + def _decay_lr(self, trainer_step: int) -> None: + """Decay lr. + Args: + trainer_step : trainer step time t. + """ + train_utils.decay_lr( + self._learning_rate_scheduler_fn, self._optimizers, trainer_step + ) \ No newline at end of file diff --git a/mava/utils/training_utils.py b/mava/utils/training_utils.py index c62dfc5ad..4eb15bc33 100644 --- a/mava/utils/training_utils.py +++ b/mava/utils/training_utils.py @@ -125,6 +125,17 @@ def map_losses_per_agent_ac(critic_losses: Dict, policy_losses: Dict) -> Dict: return logged_losses +# Map value losses to dict, grouped by agent. +def map_losses_per_agent_value(value_losses: Dict) -> Dict: + assert len(value_losses) > 0 , "Invalid System Checkpointer." + logged_losses: Dict[str, Dict[str, Any]] = {} + for agent in value_losses.keys(): + logged_losses[agent] = { + "value_loss": value_losses[agent], + } + + return logged_losses + def combine_dim(inputs: Union[tf.Tensor, List, Tuple]) -> tf.Tensor: if isinstance(inputs, tf.Tensor): diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index a4eb4cfd0..fcb03c1d6 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -31,3 +31,5 @@ NetworkStatisticsMixing, ScaledDetailedTrainerStatistics, ) +from mava.wrappers.flatland import FlatlandEnvWrapper +from mava.wrappers.smac import SMACWrapper diff --git a/mava/wrappers/smac.py b/mava/wrappers/smac.py new file mode 100644 index 000000000..e1dde7c18 --- /dev/null +++ b/mava/wrappers/smac.py @@ -0,0 +1,331 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Wraps a PettingZoo MARL environment to be used as a dm_env environment.""" +from typing import Any, Dict, List, Optional, Union + +import dm_env +import numpy as np +from acme import specs + +from smac.env import StarCraft2Env + +from mava import types +from mava.utils.wrapper_utils import ( + convert_np_type, + parameterized_restart, +) +from mava.wrappers.env_wrappers import ParallelEnvWrapper + +class SMACWrapper(ParallelEnvWrapper): + """Environment wrapper for PettingZoo MARL environments.""" + + def __init__( + self, + environment: StarCraft2Env, + return_state_info: bool = False, + ): + """Constructor for parallel PZ wrapper. + + Args: + environment (ParallelEnv): parallel PZ env. + env_preprocess_wrappers (Optional[List], optional): Wrappers + that preprocess envs. + Format (env_preprocessor, dict_with_preprocessor_params). + """ + self._environment = environment + self._return_state_info = return_state_info + self._agents = [f"agent_{n}" for n in range(self._environment.n_agents)] + + self._reset_next_step = True + self._done = False + + def reset(self) -> dm_env.TimeStep: + """Resets the env. + + Returns: + dm_env.TimeStep: dm timestep. + """ + # Reset the environment + self._environment.reset() + self._done = False + + self._reset_next_step = False + self._step_type = dm_env.StepType.FIRST + + # Get observation from env + observation = self.environment.get_obs() + legal_actions = self._get_legal_actions() + observations = self._convert_observations(observation, legal_actions, self._done) + + # Set env discount to 1 for all agents + discount_spec = self.discount_spec() + self._discounts = { + agent: convert_np_type(discount_spec[agent].dtype, 1) + for agent in self._agents + } + + # Set reward to zero for all agents + rewards_spec = self.reward_spec() + rewards = { + agent: convert_np_type(rewards_spec[agent].dtype, 0) + for agent in self._agents + } + + # Possibly add state information to extras + if self._return_state_info: + state = self.get_state() + extras = {"s_t": state} + else: + extras = {} + + return parameterized_restart(rewards, self._discounts, observations), extras + + + def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: + """Steps in env. + + Args: + actions (Dict[str, np.ndarray]): actions per agent. 
+ + Returns: + dm_env.TimeStep: dm timestep + """ + # Possibly reset the environment + if self._reset_next_step: + return self.reset() + + # Convert dict of actions to list for SMAC + actions = list(actions.values()) + + # Step the SMAC environment + reward, self._done, self._info = self._environment.step(actions) + + # Get the next observations + next_observations = self._environment.get_obs() + legal_actions = self._get_legal_actions() + next_observations = self._convert_observations(next_observations, legal_actions, self._done) + + # Convert team reward to agent-wise rewards + rewards = self._convert_reward(reward) + + # Possibly add state information to extras + if self._return_state_info: + state = self.get_state() + extras = {"s_t": state} + else: + extras = {} + + if self._done: + self._step_type = dm_env.StepType.LAST + self._reset_next_step = True + + # Discount on last timestep set to zero + self._discounts = { + agent: convert_np_type(self.discount_spec()[agent].dtype, 0.0) + for agent in self._agents + } + else: + self._step_type = dm_env.StepType.MID + + # Create timestep object + timestep = dm_env.TimeStep( + observation=next_observations, + reward=rewards, + discount=self._discounts, + step_type=self._step_type, + ) + + return timestep, extras + + def env_done(self) -> bool: + """Check if env is done. + + Returns: + bool: bool indicating if env is done. + """ + return self._done + + def _convert_reward(self, reward: float) -> Dict[str, float]: + """Convert rewards to be dm_env compatible. + + Args: + rewards: rewards per agent. + """ + rewards_spec = self.reward_spec() + rewards = {} + for agent in self._agents: + rewards[agent] = convert_np_type( + rewards_spec[agent].dtype, reward + ) + return rewards + + def _get_legal_actions(self): + legal_actions = [] + for i, _ in enumerate(self._agents): + legal_actions.append( + np.array(self._environment.get_avail_agent_actions(i), dtype='int') + ) + return legal_actions + + def _convert_observations( + self, observations: List, legal_actions: List, done: bool + ) -> types.Observation: + """Convert PettingZoo observation so it's dm_env compatible. + + Args: + observes (Dict[str, np.ndarray]): observations per agent. + dones (Dict[str, bool]): dones per agent. + + Returns: + types.Observation: dm compatible observations. + """ + olt_observations = {} + for i, agent in enumerate(self._agents): + + olt_observations[agent] = types.OLT( + observation=observations[i], + legal_actions=legal_actions[i], + terminal=np.asarray([done], dtype=np.float32), + ) + + return olt_observations + + def extra_spec(self) -> Dict[str, specs.BoundedArray]: + """Function returns extra spec (format) of the env. + + Returns: + Dict[str, specs.BoundedArray]: extra spec. + """ + if self._return_state_info: + return {"s_t": self._environment.get_state()} + else: + return {} + + def observation_spec(self) -> Dict[str, types.OLT]: + """Observation spec. + + Returns: + types.Observation: spec for environment. + """ + self._environment.reset() + + observations = self._environment.get_obs() + legal_actions = self._get_legal_actions() + + observation_specs = {} + for i, agent in enumerate(self._agents): + + observation_specs[agent] = types.OLT( + observation=observations[i], + legal_actions=legal_actions[i], + terminal=np.asarray([True], dtype=np.float32), + ) + + return observation_specs + + def action_spec(self) -> Dict[str, Union[specs.DiscreteArray, specs.BoundedArray]]: + """Action spec. 
+ + Returns: + Dict[str, Union[specs.DiscreteArray, specs.BoundedArray]]: spec for actions. + """ + action_specs = {} + for agent in self._agents: + action_specs[agent] = specs.DiscreteArray( + num_values=self._environment.n_actions, dtype=int + ) + return action_specs + + def reward_spec(self) -> Dict[str, specs.Array]: + """Reward spec. + + Returns: + Dict[str, specs.Array]: spec for rewards. + """ + reward_specs = {} + for agent in self._agents: + reward_specs[agent] = specs.Array((), np.float32) + return reward_specs + + def discount_spec(self) -> Dict[str, specs.BoundedArray]: + """Discount spec. + + Returns: + Dict[str, specs.BoundedArray]: spec for discounts. + """ + discount_specs = {} + for agent in self._agents: + discount_specs[agent] = specs.BoundedArray( + (), np.float32, minimum=0, maximum=1.0 + ) + return discount_specs + + def get_stats(self) -> Optional[Dict]: + """Return extra stats to be logged. + + Returns: + extra stats to be logged. + """ + + return {"win_rate": self._info["win_rate"]} + + @property + def agents(self) -> List: + """Agents still alive in env (not done). + + Returns: + List: alive agents in env. + """ + return self._agents + + @property + def possible_agents(self) -> List: + """All possible agents in env. + + Returns: + List: all possible agents in env. + """ + return self._agents + + @property + def environment(self) -> StarCraft2Env: + """Returns the wrapped environment. + + Returns: + ParallelEnv: parallel env. + """ + return self._environment + + def __getattr__(self, name: str) -> Any: + """Expose any other attributes of the underlying environment. + + Args: + name (str): attribute. + + Returns: + Any: return attribute from env or underlying env. + """ + if hasattr(self.__class__, name): + return self.__getattribute__(name) + else: + return getattr(self._environment, name) + + + + +env = StarCraft2Env(map_name="3m") + +wrapped_env = SMACWrapper(env) \ No newline at end of file From f4539874d98fe43dc2b4f17be5d16555731744a4 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 19 Jan 2022 16:48:24 +0200 Subject: [PATCH 02/56] Fix Recurrent MADQN --- mava/systems/tf/madqn/builder.py | 2 ++ mava/systems/tf/madqn/execution.py | 2 +- mava/systems/tf/madqn/training.py | 11 ++++++----- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/mava/systems/tf/madqn/builder.py b/mava/systems/tf/madqn/builder.py index 71747811b..1b5b79521 100644 --- a/mava/systems/tf/madqn/builder.py +++ b/mava/systems/tf/madqn/builder.py @@ -319,6 +319,8 @@ def make_adder( for table_key in self._config.table_network_config.keys() } + print() + # Select adder if issubclass(self._executor_fn, executors.FeedForwardExecutor): adder = reverb_adders.ParallelNStepTransitionAdder( diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index ff27b5883..444678f8a 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -282,7 +282,7 @@ def update(self, wait: bool = False) -> None: self._variable_client.get_async() -class MADQNRecurrentExecutor(executors.RecurrentExecutor, MADQNFeedForwardExecutor): +class MADQNRecurrentExecutor(executors.RecurrentExecutor, DQNExecutor): """A recurrent executor for MADQN. An executor based on a recurrent policy for each agent in the system. 
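
# NOTE (editor): minimal interaction loop for the SMACWrapper added above,
# assuming the smac package and StarCraft II are installed. As written, the
# wrapper's reset() and step() return a (timestep, extras) tuple rather than a
# bare dm_env.TimeStep. Random legal actions are used purely for illustration.
import numpy as np
from smac.env import StarCraft2Env
from mava.wrappers import SMACWrapper

env = SMACWrapper(StarCraft2Env(map_name="3m"))
timestep, extras = env.reset()
while not env.env_done():
    actions = {
        agent: np.random.choice(
            np.nonzero(timestep.observation[agent].legal_actions)[0]
        )
        for agent in env.agents
    }
    timestep, extras = env.step(actions)
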
diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index 3b4f46495..51e99dcd8 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -631,8 +631,6 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: ) data = tf2_utils.batch_to_sequence(data) - print(data) - # Note (dries): The unused variable is start_of_episodes. observations, actions, rewards, discounts, _, extras = ( data.observations, @@ -649,7 +647,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # core_state = tree.map_structure(lambda s: s[:, 0, :], extras["core_states"]) # target_core_state = tree.map_structure(tf.identity, core_state) core_state = tree.map_structure(lambda s: s[0, :, :], extras["core_states"]) - target_core_state = tree.map_structure(lambda s: s[1, :, :], extras["core_states"]) + target_core_state = tree.map_structure(lambda s: s[0, :, :], extras["core_states"]) # TODO (dries): Take out all the data_points that does not need # to be processed here at the start. Therefore it does not have @@ -667,14 +665,17 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: agent_key = self._agent_net_keys[agent] # Double Q-learning + print(core_state[agent][0]) + print(obs_trans[agent].shape) q, _ = snt.static_unroll( - self._value_networks[agent_key], obs_trans, core_state[agent][0] + self._value_networks[agent_key], obs_trans[agent], core_state[agent][0] ) q_tm1 = q[:-1] q_t_selector = q[1:] q_t_value, _ = snt.static_unroll( - self._target_value_networks[agent_key], target_obs_trans[1:], target_core_state[agent][0] + self._target_value_networks[agent_key], target_obs_trans[agent], target_core_state[agent][0] ) + q_t_value = q_t_value[1:] # TODO Legal action masking # q_t_selector = tf.where(observations[agent].legal_actions, q_t_selector, -999999999) From 207f5f24921345458364404799b429cfcb527a11 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Thu, 20 Jan 2022 07:51:11 +0200 Subject: [PATCH 03/56] Working Rec MADQN on SMAC. 
--- .../smac/recurrent/decentralised/run_madqn.py | 73 ++++++------------- mava/systems/tf/madqn/training.py | 4 +- mava/wrappers/smac.py | 4 +- 3 files changed, 25 insertions(+), 56 deletions(-) diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index 8fa1d2bc1..c169caae9 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -27,18 +27,23 @@ from mava import specs as mava_specs from mava.components.tf import networks from mava.components.tf.modules.exploration.exploration_scheduling import ( - LinearExplorationTimestepScheduler, + LinearExplorationScheduler, ) from mava.components.tf.networks.epsilon_greedy import EpsilonGreedy from mava.systems.tf import madqn from mava.utils import lp_utils from mava.utils.environments import pettingzoo_utils from mava.utils.loggers import logger_utils +from mava.utils.enums import ArchitectureType + +from smac.env import StarCraft2Env +from mava.wrappers import SMACWrapper + FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - "3m", + "8m", "Starcraft 2 micromanagement map name (str).", ) @@ -49,62 +54,24 @@ ) flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") +def smac_env_factory(env_name="3m", evaluation = False): + env = StarCraft2Env(map_name=env_name) + env = SMACWrapper(env) -def custom_recurrent_network( - environment_spec: mava_specs.MAEnvironmentSpec, - agent_net_keys: Dict[str, str], - q_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = [128, 128], -) -> Mapping[str, types.TensorTransformation]: - """Creates networks used by the agents.""" - - specs = environment_spec.get_agent_specs() - - # Create agent_type specs - specs = {agent_net_keys[key]: specs[key] for key in specs.keys()} - - if isinstance(q_networks_layer_sizes, Sequence): - q_networks_layer_sizes = {key: q_networks_layer_sizes for key in specs.keys()} - - q_networks = {} - action_selectors = {} - for key in specs.keys(): - - # Get total number of action dimensions from action spec. - num_dimensions = specs[key].actions.num_values - - # Create the policy network. - q_network = snt.DeepRNN( - [ - snt.Linear(q_networks_layer_sizes[key][0]), - tf.nn.relu, - snt.GRU(q_networks_layer_sizes[key][1]), - networks.NearZeroInitializedLinear(num_dimensions), - ] - ) - - # epsilon greedy action selector - action_selector = EpsilonGreedy - - q_networks[key] = q_network - action_selectors[key] = action_selector - - return { - "q_networks": q_networks, - "action_selectors": action_selectors, - } + return env def main(_: Any) -> None: """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" # environment environment_factory = functools.partial( - pettingzoo_utils.make_environment, env_class="smac", env_name=FLAGS.map_name + smac_env_factory, env_name=FLAGS.map_name ) # Networks. 
network_factory = lp_utils.partial_kwargs( - custom_recurrent_network, - q_networks_layer_sizes=[128, 128], + madqn.make_default_networks, + archecture_type=ArchitectureType.recurrent ) # Checkpointer appends "Checkpoints" to checkpoint_dir @@ -127,17 +94,21 @@ def main(_: Any) -> None: network_factory=network_factory, logger_factory=logger_factory, num_executors=1, - exploration_scheduler_fn=LinearExplorationTimestepScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=50000 + exploration_scheduler_fn=LinearExplorationScheduler( + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 ), checkpoint_subpath=checkpoint_dir, batch_size=32, - executor_variable_update_period=100, + executor_variable_update_period=200, target_update_period=200, - max_gradient_norm=10.0, + max_gradient_norm=20.0, + sequence_length=60, + period=60, + min_replay_size=100, + max_replay_size=4000, trainer_fn=madqn.training.MADQNRecurrentTrainer, executor_fn=madqn.execution.MADQNRecurrentExecutor, ).build() diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index 51e99dcd8..ad8933f69 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -665,8 +665,6 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: agent_key = self._agent_net_keys[agent] # Double Q-learning - print(core_state[agent][0]) - print(obs_trans[agent].shape) q, _ = snt.static_unroll( self._value_networks[agent_key], obs_trans[agent], core_state[agent][0] ) @@ -678,7 +676,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: q_t_value = q_t_value[1:] # TODO Legal action masking - # q_t_selector = tf.where(observations[agent].legal_actions, q_t_selector, -999999999) + q_t_selector = tf.where(tf.cast(observations[agent].legal_actions[1:], 'bool'), q_t_selector, -999999999) # Cast the additional discount to match # the environment discount dtype. diff --git a/mava/wrappers/smac.py b/mava/wrappers/smac.py index e1dde7c18..9dc19f6f2 100644 --- a/mava/wrappers/smac.py +++ b/mava/wrappers/smac.py @@ -279,8 +279,8 @@ def get_stats(self) -> Optional[Dict]: Returns: extra stats to be logged. """ - - return {"win_rate": self._info["win_rate"]} + pass + # return {"win_rate": self._info["win_rate"]} @property def agents(self) -> List: From 6d62a0b1ddf21564dd9584dbef768a8680b8fc35 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Thu, 20 Jan 2022 10:31:45 +0200 Subject: [PATCH 04/56] Start Env Wrappers. --- .../smac/recurrent/decentralised/run_madqn.py | 2 +- mava/systems/tf/madqn/execution.py | 6 ---- mava/systems/tf/madqn/training.py | 19 +++++------ mava/wrappers/env_preprocess_wrappers.py | 33 +++++++++++++++++++ 4 files changed, 42 insertions(+), 18 deletions(-) diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index c169caae9..3e9a632d1 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -43,7 +43,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - "8m", + "corridor", "Starcraft 2 micromanagement map name (str).", ) diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index 444678f8a..2426eec4d 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -394,12 +394,6 @@ def select_action( action and policy. """ - # Initialize the RNN state if necessary. 
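
# NOTE (editor): standalone illustration of the legal-action masking added to
# the recurrent trainer above: illegal actions are pushed to a very large
# negative value so the greedy selector can never pick them.
import tensorflow as tf

q_values = tf.constant([[1.0, 5.0, 3.0]])
legal_actions = tf.constant([[1, 0, 1]])    # action 1 is illegal here
masked = tf.where(tf.cast(legal_actions, "bool"), q_values, -999999999.0)
best_legal = tf.argmax(masked, axis=-1)     # -> [2]
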
- if self._states[agent] is None: - # index network either on agent type or on agent id - agent_key = self._agent_net_keys[agent] - self._states[agent] = self._value_networks[agent_key].initia_state(1) - # Step the recurrent policy forward given the current observation and state. action, new_state = self._policy( agent, observation.observation, observation.legal_actions, self._states[agent] diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index ad8933f69..1af377ea0 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -643,9 +643,6 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # Get initial state for the LSTM from replay and # extract the first state in the sequence. - # NOTE maybe indexing the wrong thing here?! - # core_state = tree.map_structure(lambda s: s[:, 0, :], extras["core_states"]) - # target_core_state = tree.map_structure(tf.identity, core_state) core_state = tree.map_structure(lambda s: s[0, :, :], extras["core_states"]) target_core_state = tree.map_structure(lambda s: s[0, :, :], extras["core_states"]) @@ -668,14 +665,14 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: q, _ = snt.static_unroll( self._value_networks[agent_key], obs_trans[agent], core_state[agent][0] ) - q_tm1 = q[:-1] - q_t_selector = q[1:] + q_tm1 = q[:-1] # Chop off last timestep + q_t_selector = q[1:] # Chop off first timestep q_t_value, _ = snt.static_unroll( self._target_value_networks[agent_key], target_obs_trans[agent], target_core_state[agent][0] ) - q_t_value = q_t_value[1:] + q_t_value = q_t_value[1:] # Chop off first timestep - # TODO Legal action masking + # Legal action masking q_t_selector = tf.where(tf.cast(observations[agent].legal_actions[1:], 'bool'), q_t_selector, -999999999) # Cast the additional discount to match @@ -683,7 +680,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: discount = tf.cast(self._discount, dtype=discounts[agent].dtype) # Flatten out time and batch dim - q_tm1, dims = train_utils.combine_dim( + q_tm1, _ = train_utils.combine_dim( q_tm1 ) q_t_selector, _ = train_utils.combine_dim( @@ -693,13 +690,13 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: q_t_value ) a_tm1, _ = train_utils.combine_dim( - actions[agent][:-1] + actions[agent][:-1] # Chop off last timestep ) r_t, _ = train_utils.combine_dim( - rewards[agent][:-1] + rewards[agent][:-1] # Chop off last timestep ) d_t, _ = train_utils.combine_dim( - discounts[agent][:-1] + discounts[agent][:-1] # Chop off last timestep ) # Value loss diff --git a/mava/wrappers/env_preprocess_wrappers.py b/mava/wrappers/env_preprocess_wrappers.py index e68db7da9..eb2994cba 100644 --- a/mava/wrappers/env_preprocess_wrappers.py +++ b/mava/wrappers/env_preprocess_wrappers.py @@ -360,3 +360,36 @@ def _modify_observation(self, observation: Observation) -> Observation: def _modify_action(self, action: Action) -> Action: return action + +class ConcatAgentIdToObservation: + """Concat one-hot vector of agent ID to obs. + + We assume the environment has an ordered list + self.possible_agents. + """ + + def __init__(self): + pass + + def reset(self): + pass + + def step(self, actions: Dict) -> Any: + pass + +class ConcatPrevActionToObservation: + """Concat one-hot vector of agent prev_action to obs. + + We assume the environment has discreet actions. + + TODO support continuous actions. 
+ """ + + def __init__(self): + pass + + def reset(self): + pass + + def step(self, actions: Dict) -> Any: + pass From 78408cc87a6c0685f66e78b73da042f916a89d03 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Thu, 20 Jan 2022 14:09:28 +0200 Subject: [PATCH 05/56] Agent ID wrapper. --- .../recurrent/decentralised/run_madqn.py | 127 ++++++ .../smac/feedforward/decentralised/test.py | 3 + .../smac/recurrent/decentralised/run_madqn.py | 4 +- .../smac/recurrent/decentralised/run_qmix.py | 124 ++++++ .../smac/recurrent/decentralised/run_vdn.py | 122 ++++++ mava/systems/tf/madqn/execution.py | 3 - mava/systems/tf/madqn/networks.py | 10 +- mava/systems/tf/madqn/training.py | 10 +- .../tf/value_decomposition/__init__.py | 21 + mava/systems/tf/value_decomposition/mixer.py | 119 ++++++ .../tf/value_decomposition/networks.py | 120 ++++++ mava/systems/tf/value_decomposition/system.py | 276 +++++++++++++ .../tf/value_decomposition/training.py | 362 ++++++++++++++++++ mava/wrappers/env_preprocess_wrappers.py | 89 ++++- mava/wrappers/smac.py | 5 +- 15 files changed, 1368 insertions(+), 27 deletions(-) create mode 100644 examples/flatland/recurrent/decentralised/run_madqn.py create mode 100644 examples/smac/recurrent/decentralised/run_qmix.py create mode 100644 examples/smac/recurrent/decentralised/run_vdn.py create mode 100644 mava/systems/tf/value_decomposition/__init__.py create mode 100644 mava/systems/tf/value_decomposition/mixer.py create mode 100644 mava/systems/tf/value_decomposition/networks.py create mode 100644 mava/systems/tf/value_decomposition/system.py create mode 100644 mava/systems/tf/value_decomposition/training.py diff --git a/examples/flatland/recurrent/decentralised/run_madqn.py b/examples/flatland/recurrent/decentralised/run_madqn.py new file mode 100644 index 000000000..0dca5c70b --- /dev/null +++ b/examples/flatland/recurrent/decentralised/run_madqn.py @@ -0,0 +1,127 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import functools +from datetime import datetime +from typing import Any, Dict + +import launchpad as lp +import sonnet as snt +from absl import app, flags + +from mava.components.tf.modules.exploration.exploration_scheduling import ( + LinearExplorationScheduler, +) +from mava.systems.tf import madqn +from mava.utils import lp_utils +from mava.utils.environments.flatland_utils import flatland_env_factory +from mava.utils.loggers import logger_utils +from mava.utils.enums import ArchitectureType + + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "mava_id", + str(datetime.now()), + "Experiment identifier that can be used to continue experiments.", +) +flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") + + +# flatland environment config +flatland_env_config: Dict = { + "n_agents": 3, + "x_dim": 30, + "y_dim": 30, + "n_cities": 2, + "max_rails_between_cities": 2, + "max_rails_in_city": 3, + "seed": 0, + "malfunction_rate": 1 / 200, + "malfunction_min_duration": 20, + "malfunction_max_duration": 50, + "observation_max_path_depth": 30, + "observation_tree_depth": 2, +} + + +def main(_: Any) -> None: + + # Environment. + environment_factory = functools.partial( + flatland_env_factory, env_config=flatland_env_config, include_agent_info=False + ) + + # Networks. + network_factory = lp_utils.partial_kwargs( + madqn.make_default_networks, + architecture_type=ArchitectureType.recurrent + ) + + # Checkpointer appends "Checkpoints" to checkpoint_dir + checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" + + # Log every [log_every] seconds. + log_every = 10 + logger_factory = functools.partial( + logger_utils.make_logger, + directory=FLAGS.base_dir, + to_terminal=True, + to_tensorboard=True, + time_stamp=FLAGS.mava_id, + time_delta=log_every, + ) + + # distributed program + program = madqn.MADQN( + environment_factory=environment_factory, + network_factory=network_factory, + logger_factory=logger_factory, + num_executors=1, + exploration_scheduler_fn=LinearExplorationScheduler( + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 + ), + optimizer=snt.optimizers.Adam(learning_rate=1e-4), + batch_size=32, + executor_variable_update_period=200, + target_update_period=200, + max_gradient_norm=20.0, + sequence_length=20, + period=10, + min_replay_size=100, + max_replay_size=5000, + trainer_fn=madqn.training.MADQNRecurrentTrainer, + executor_fn=madqn.execution.MADQNRecurrentExecutor, + checkpoint_subpath=checkpoint_dir, + ).build() + + # Ensure only trainer runs on gpu, while other processes run on cpu. + local_resources = lp_utils.to_device( + program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] + ) + + # Launch. 
+ lp.launch( + program, + lp.LaunchType.LOCAL_MULTI_PROCESSING, + terminal="current_terminal", + local_resources=local_resources, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/smac/feedforward/decentralised/test.py b/examples/smac/feedforward/decentralised/test.py index 909f7a37f..6cf7c8aa7 100644 --- a/examples/smac/feedforward/decentralised/test.py +++ b/examples/smac/feedforward/decentralised/test.py @@ -1,5 +1,6 @@ from smac.env import StarCraft2Env from mava.wrappers import SMACWrapper +from mava.wrappers.env_preprocess_wrappers import ConcatAgentIdToObservation import numpy as np @@ -7,6 +8,8 @@ env = SMACWrapper(env) +env = ConcatAgentIdToObservation(env) + spec = env.action_spec() spec = env.observation_spec() diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index 3e9a632d1..b2391c24f 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -43,7 +43,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - "corridor", + "3m", "Starcraft 2 micromanagement map name (str).", ) @@ -71,7 +71,7 @@ def main(_: Any) -> None: # Networks. network_factory = lp_utils.partial_kwargs( madqn.make_default_networks, - archecture_type=ArchitectureType.recurrent + architecture_type=ArchitectureType.recurrent ) # Checkpointer appends "Checkpoints" to checkpoint_dir diff --git a/examples/smac/recurrent/decentralised/run_qmix.py b/examples/smac/recurrent/decentralised/run_qmix.py new file mode 100644 index 000000000..fe0b8b3fa --- /dev/null +++ b/examples/smac/recurrent/decentralised/run_qmix.py @@ -0,0 +1,124 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools +from datetime import datetime +from typing import Any + +import launchpad as lp +import sonnet as snt +from absl import app, flags + +from mava.components.tf.modules.exploration.exploration_scheduling import ( + LinearExplorationScheduler, +) +from mava.systems.tf import value_decomposition +from mava.utils import lp_utils +from mava.utils.loggers import logger_utils + +from mava.systems.tf.value_decomposition.mixer import QMIX + +from smac.env import StarCraft2Env +from mava.wrappers import SMACWrapper + + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "map_name", + "3m", + "Starcraft 2 micromanagement map name (str).", +) + +flags.DEFINE_string( + "mava_id", + str(datetime.now()), + "Experiment identifier that can be used to continue experiments.", +) +flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") + +def smac_env_factory(env_name="3m", evaluation = False): + env = StarCraft2Env(map_name=env_name) + env = SMACWrapper(env) + + return env + + +def main(_: Any) -> None: + """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" + # environment + environment_factory = functools.partial( + smac_env_factory, env_name=FLAGS.map_name + ) + + # Networks. 
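The network factory configured just below (and defined later in this series in mava/systems/tf/value_decomposition/networks.py) builds one recurrent value network per agent network key: a layer-normalised MLP feeding a GRU and a linear head over actions. A rough, self-contained sketch of that network being unrolled, with snt.nets.MLP standing in for Mava's LayerNormMLP and a plain snt.Linear for NearZeroInitializedLinear:

import sonnet as snt
import tensorflow as tf

T, B, obs_dim, num_actions = 6, 4, 30, 9  # illustrative sizes

value_network = snt.DeepRNN([
    snt.nets.MLP([128], activate_final=True),  # stand-in for networks.LayerNormMLP
    snt.GRU(64),
    snt.Linear(num_actions),                   # stand-in for NearZeroInitializedLinear
])

observations = tf.random.uniform((T, B, obs_dim))
core_state = value_network.initial_state(B)
q_values, final_state = snt.static_unroll(value_network, observations, core_state)
# q_values has shape [T, B, num_actions]; the trainer unrolls it the same way.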
+ network_factory = lp_utils.partial_kwargs( + value_decomposition.make_default_networks, + ) + + # Checkpointer appends "Checkpoints" to checkpoint_dir + checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" + + # Log every [log_every] seconds. + log_every = 10 + logger_factory = functools.partial( + logger_utils.make_logger, + directory=FLAGS.base_dir, + to_terminal=True, + to_tensorboard=True, + time_stamp=FLAGS.mava_id, + time_delta=log_every, + ) + + num_agents = len(environment_factory().possible_agents) + + # distributed program + program = value_decomposition.ValueDecomposition( + environment_factory=environment_factory, + network_factory=network_factory, + mixer=QMIX(num_agents=num_agents), + logger_factory=logger_factory, + num_executors=1, + exploration_scheduler_fn=LinearExplorationScheduler( + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=3e-5 + ), + optimizer=snt.optimizers.RMSProp( + learning_rate=0.0005, epsilon=0.00001, decay=0.99 + ), + checkpoint_subpath=checkpoint_dir, + batch_size=32, + executor_variable_update_period=200, + target_update_period=200, + max_gradient_norm=20.0, + sequence_length=60, + period=60, + min_replay_size=100, + max_replay_size=5000, + ).build() + + # launch + local_resources = lp_utils.to_device( + program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] + ) + lp.launch( + program, + lp.LaunchType.LOCAL_MULTI_PROCESSING, + terminal="current_terminal", + local_resources=local_resources, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/smac/recurrent/decentralised/run_vdn.py b/examples/smac/recurrent/decentralised/run_vdn.py new file mode 100644 index 000000000..7c486533d --- /dev/null +++ b/examples/smac/recurrent/decentralised/run_vdn.py @@ -0,0 +1,122 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
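A note before this example: the VDN mixer is parameter-free and simply sums the per-agent chosen-action values, so it ignores the global state and needs no agent count. QMIX, by contrast, conditions its mixing weights on the global environment state that the trainer reads from extras["s_t"], which is why the SMAC wrapper in this series now returns state info by default. Illustrative construction only:

from mava.systems.tf.value_decomposition.mixer import QMIX, VDN

mixer = VDN()                 # no state or agent count needed
# mixer = QMIX(num_agents=3)  # state-conditioned; also needs the agent count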
+ + +import functools +from datetime import datetime +from typing import Any + +import launchpad as lp +import sonnet as snt +from absl import app, flags + +from mava.components.tf.modules.exploration.exploration_scheduling import ( + LinearExplorationScheduler, +) +from mava.systems.tf import value_decomposition +from mava.utils import lp_utils +from mava.utils.loggers import logger_utils + +from mava.systems.tf.value_decomposition.mixer import VDN + +from smac.env import StarCraft2Env +from mava.wrappers import SMACWrapper + + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "map_name", + "3m", + "Starcraft 2 micromanagement map name (str).", +) + +flags.DEFINE_string( + "mava_id", + str(datetime.now()), + "Experiment identifier that can be used to continue experiments.", +) +flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") + +def smac_env_factory(env_name="3m", evaluation = False): + env = StarCraft2Env(map_name=env_name) + env = SMACWrapper(env) + + return env + + +def main(_: Any) -> None: + """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" + # environment + environment_factory = functools.partial( + smac_env_factory, env_name=FLAGS.map_name + ) + + # Networks. + network_factory = lp_utils.partial_kwargs( + value_decomposition.make_default_networks, + ) + + # Checkpointer appends "Checkpoints" to checkpoint_dir + checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" + + # Log every [log_every] seconds. + log_every = 10 + logger_factory = functools.partial( + logger_utils.make_logger, + directory=FLAGS.base_dir, + to_terminal=True, + to_tensorboard=True, + time_stamp=FLAGS.mava_id, + time_delta=log_every, + ) + + # distributed program + program = value_decomposition.ValueDecomposition( + environment_factory=environment_factory, + network_factory=network_factory, + mixer=VDN(), + logger_factory=logger_factory, + num_executors=1, + exploration_scheduler_fn=LinearExplorationScheduler( + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=6e-5 + ), + optimizer=snt.optimizers.RMSProp( + learning_rate=0.0005, epsilon=0.00001, decay=0.99 + ), + checkpoint_subpath=checkpoint_dir, + batch_size=32, + executor_variable_update_period=200, + target_update_period=200, + max_gradient_norm=20.0, + sequence_length=60, + period=60, + min_replay_size=100, + max_replay_size=5000, + ).build() + + # launch + local_resources = lp_utils.to_device( + program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] + ) + lp.launch( + program, + lp.LaunchType.LOCAL_MULTI_PROCESSING, + terminal="current_terminal", + local_resources=local_resources, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index 2426eec4d..a4533d9dc 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -417,7 +417,6 @@ def select_actions( Returns: actions and policies for all agents in the system. """ - actions = {} for agent, observation in observations.items(): actions[agent] = self.select_action(agent, observation) @@ -436,7 +435,6 @@ def observe_first( extras: possible extra information to record during the first step. """ - # Re-initialize the RNN state. for agent, _ in timestep.observation.items(): # index network either on agent type or on agent id @@ -477,7 +475,6 @@ def observe( next_extras: possible extra information to record during the transition. 
""" - if not self._adder: return diff --git a/mava/systems/tf/madqn/networks.py b/mava/systems/tf/madqn/networks.py index 50f5fea04..4a6e7a516 100644 --- a/mava/systems/tf/madqn/networks.py +++ b/mava/systems/tf/madqn/networks.py @@ -35,7 +35,7 @@ def make_default_networks( environment_spec: mava_specs.MAEnvironmentSpec, agent_net_keys: Dict[str, str], value_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = None, - archecture_type: ArchitectureType = ArchitectureType.feedforward, + architecture_type: ArchitectureType = ArchitectureType.feedforward, seed: Optional[int] = None, ) -> Mapping[str, types.TensorTransformation]: """Default networks for maddpg. @@ -64,7 +64,7 @@ def make_default_networks( """ # Set Policy function and layer size # Default size per arch type. - if archecture_type == ArchitectureType.feedforward: + if architecture_type == ArchitectureType.feedforward: if not value_networks_layer_sizes: value_networks_layer_sizes = ( 256, @@ -72,7 +72,7 @@ def make_default_networks( 256, ) value_network_func = snt.Sequential - elif archecture_type == ArchitectureType.recurrent: + elif architecture_type == ArchitectureType.recurrent: if not value_networks_layer_sizes: value_networks_layer_sizes = (128, 64) value_network_func = snt.DeepRNN @@ -105,13 +105,13 @@ def make_default_networks( observation_network = tf2_utils.to_sonnet_module(tf.identity) # Create the policy network. - if archecture_type == ArchitectureType.feedforward: + if architecture_type == ArchitectureType.feedforward: value_network = [ networks.LayerNormMLP( value_networks_layer_sizes[key], activate_final=True, seed=seed ), ] - elif archecture_type == ArchitectureType.recurrent: + elif architecture_type == ArchitectureType.recurrent: value_network = [ networks.LayerNormMLP( value_networks_layer_sizes[key][:-1], diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index 1af377ea0..4046d64b8 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -407,7 +407,6 @@ def __init__( agent_net_keys: Dict[str, str], max_gradient_norm: float = None, logger: loggers.Logger = None, - #bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): """Initialise Recurrent MADDPG trainer @@ -594,7 +593,7 @@ def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray] """Depricated""" pass - @tf.function + # @tf.function def _step( self, ) -> Dict[str, Dict[str, Any]]: @@ -603,10 +602,6 @@ def _step( Returns: losses """ - - # Update the target networks - self._update_target_networks() - # Draw a batch of data from replay. sample: reverb.ReplaySample = next(self._iterator) @@ -614,6 +609,9 @@ def _step( self._backward() + # Update the target networks + self._update_target_networks() + # Log losses per agent return train_utils.map_losses_per_agent_value( self.value_losses diff --git a/mava/systems/tf/value_decomposition/__init__.py b/mava/systems/tf/value_decomposition/__init__.py new file mode 100644 index 000000000..4b4fdbbfa --- /dev/null +++ b/mava/systems/tf/value_decomposition/__init__.py @@ -0,0 +1,21 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementations of a Value Decompostion agent.""" +from mava.systems.tf.value_decomposition.networks import make_default_networks +from mava.systems.tf.value_decomposition.system import ValueDecomposition +from mava.systems.tf.value_decomposition.training import ( + ValueDecompositionRecurrentTrainer, +) diff --git a/mava/systems/tf/value_decomposition/mixer.py b/mava/systems/tf/value_decomposition/mixer.py new file mode 100644 index 000000000..6565c7944 --- /dev/null +++ b/mava/systems/tf/value_decomposition/mixer.py @@ -0,0 +1,119 @@ +import sonnet as snt +import tensorflow as tf + +@snt.allow_empty_variables +class BaseMixer(snt.Module): + """Base mixing class. + + Base mixer should take in agent q-values and environment global state tensors. + """ + def __init__(self): + super().__init__() + + def __call__(self, agent_qs: tf.Tensor , states: tf.Tensor): + return agent_qs + + """Initialize Base Mixer class + Args: + agent_qs: Tensor containing the q-values of actions chosen by agents + states: Tensor containing global environment state. + """ + +@snt.allow_empty_variables +class VDN(BaseMixer): + """VDN mixing network.""" + + def __init__(self): + super().__init__() + + def __call__(self, agent_qs: tf.Tensor, states: tf.Tensor): + return tf.reduce_sum(agent_qs, axis=-1, keepdims=True) + + """Initialize VDN class + Args: + agent_qs: Tensor containing the q-values of actions chosen by agents + states: Tensor containing global environment state. + Returns: + Tensor with total q-value. 
+ """ + +class QMIX(BaseMixer): + """QMIX mixing network.""" + + def __init__( + self, + num_agents: int, + embed_dim: int = 32, + hypernet_embed: int = 64 + ): + """Inialize QMIX mixing network + + Args: + num_agents: Number of agents in the enviroment + state_dim: Dimensions of the global environment state + embed_dim: TODO (Ruan): Cluade please add + hypernet_embed: TODO (Ruan): Claude Please add + """ + + super().__init__() + self.num_agents = num_agents + self.embed_dim = embed_dim + self.hypernet_embed = hypernet_embed + + + self.hyper_w_1 = snt.Sequential( + [ + snt.Linear(self.hypernet_embed), + tf.nn.relu, + snt.Linear(self.embed_dim * self.num_agents) + ] + ) + + self.hyper_w_final = snt.Sequential( + [ + snt.Linear(self.hypernet_embed), + tf.nn.relu, + snt.Linear(self.embed_dim) + ] + ) + + # State dependent bias for hidden layer + self.hyper_b_1 = snt.Linear(self.embed_dim) + + # V(s) instead of a bias for the last layers + self.V = snt.Sequential( + [ + snt.Linear(self.embed_dim), + tf.nn.relu, + snt.Linear(1) + ] + ) + + def __call__(self, agent_qs, states): + bs = agent_qs.shape[1] + state_dim = states.shape[-1] + + agent_qs = tf.reshape(agent_qs, (-1, 1, self.num_agents)) + states = tf.reshape(states, (-1, state_dim)) + + # First layer + w1 = tf.abs(self.hyper_w_1(states)) + b1 = self.hyper_b_1(states) + w1 = tf.reshape(w1, (-1, self.num_agents, self.embed_dim)) + b1 = tf.reshape(b1, (-1, 1, self.embed_dim)) + hidden = tf.nn.elu(tf.matmul(agent_qs, w1) + b1) + + # Second layer + w_final = tf.abs(self.hyper_w_final(states)) + w_final = tf.reshape(w_final, (-1, self.embed_dim, 1)) + + # State-dependent bias + v = tf.reshape(self.V(states), (-1, 1, 1)) + + # Compute final output + y = tf.matmul(hidden, w_final) + v + + # Reshape and return + q_tot = tf.reshape(y, (-1, bs, 1)) # [T, B, 1] + + return q_tot \ No newline at end of file diff --git a/mava/systems/tf/value_decomposition/networks.py b/mava/systems/tf/value_decomposition/networks.py new file mode 100644 index 000000000..6250b1c59 --- /dev/null +++ b/mava/systems/tf/value_decomposition/networks.py @@ -0,0 +1,120 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Dict, Mapping, Optional, Sequence, Union + +import numpy as np +import sonnet as snt +import tensorflow as tf +from acme import types +from acme.tf import utils as tf2_utils +from dm_env import specs + +from mava import specs as mava_specs +from mava.components.tf import networks +from mava.utils.enums import ArchitectureType +from mava.components.tf.networks.epsilon_greedy import EpsilonGreedy + +Array = specs.Array +BoundedArray = specs.BoundedArray +DiscreteArray = specs.DiscreteArray + + +def make_default_networks( + environment_spec: mava_specs.MAEnvironmentSpec, + agent_net_keys: Dict[str, str], + value_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = None, + seed: Optional[int] = None, +) -> Mapping[str, types.TensorTransformation]: + """Default networks for maddpg. + + Args: + environment_spec: description of the action and + observation spaces etc. for each agent in the system. + agent_net_keys: specifies what network each agent uses. + vmin: hyperparameters for the distributional critic in mad4pg. + vmax: hyperparameters for the distributional critic in mad4pg. + net_spec_keys: specifies the specs of each network. + policy_networks_layer_sizes: size of policy networks. + critic_networks_layer_sizes: size of critic networks. + sigma: hyperparameters used to add Gaussian noise + for simple exploration. Defaults to 0.3. + archecture_type: archecture used + for agent networks. Can be feedforward or recurrent. + Defaults to ArchitectureType.feedforward. + + num_atoms: hyperparameters for the distributional critic in + mad4pg. + seed: random seed for network initialization. + + Returns: + returned agent networks. + """ + + if not value_networks_layer_sizes: + value_networks_layer_sizes = (128, 64) + + value_network_func = snt.DeepRNN + + assert value_networks_layer_sizes is not None + assert value_network_func is not None + + specs = environment_spec.get_agent_specs() + + # Create agent_type specs + specs = {agent_net_keys[key]: specs[key] for key in specs.keys()} + + + if isinstance(value_networks_layer_sizes, Sequence): + value_networks_layer_sizes = { + key: value_networks_layer_sizes for key in specs.keys() + } + if isinstance(value_networks_layer_sizes, Sequence): + value_networks_layer_sizes = { + key: value_networks_layer_sizes for key in specs.keys() + } + + observation_networks = {} + value_networks = {} + action_selectors = {} + for key, spec in specs.items(): + num_actions = spec.actions.num_values + + # An optional network to process observations + observation_network = tf2_utils.to_sonnet_module(tf.identity) + + value_network = [ + networks.LayerNormMLP( + value_networks_layer_sizes[key][:-1], + activate_final=True, + seed=seed, + ), + snt.GRU(value_networks_layer_sizes[key][-1]), + ] + + value_network += [ + networks.NearZeroInitializedLinear(num_actions, seed=seed), + ] + + value_network = value_network_func(value_network) + + observation_networks[key] = observation_network + value_networks[key] = value_network + action_selectors[key] = EpsilonGreedy + + return { + "values": value_networks, + "action_selectors": action_selectors, + "observations": observation_networks, + } diff --git a/mava/systems/tf/value_decomposition/system.py b/mava/systems/tf/value_decomposition/system.py new file mode 100644 index 000000000..929a8a608 --- /dev/null +++ b/mava/systems/tf/value_decomposition/system.py @@ -0,0 +1,276 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Value Decomposition system implementation.""" + +from typing import Callable, Dict, List, Optional, Type, Union, Mapping + +import dm_env +import reverb +import sonnet as snt +from acme import specs as acme_specs + +import mava +from mava import specs as mava_specs +from mava.components.tf.architectures import ( + DecentralisedValueActor, +) +from mava.types import EpsilonScheduler +from mava.environment_loop import ParallelEnvironmentLoop +from mava.systems.tf.madqn.execution import MADQNRecurrentExecutor +from mava.systems.tf.value_decomposition.training import ValueDecompositionRecurrentTrainer +from mava.systems.tf.variable_sources import VariableSource as MavaVariableSource +from mava.utils import enums +from mava.utils.loggers import MavaLogger +from mava.systems.tf.madqn import MADQN + + +class ValueDecomposition(MADQN): + """Value Decomposition systems.""" + + """TODO: Implement faster adders to speed up training times when + using multiple trainers with non-shared weights.""" + + def __init__( # noqa + self, + environment_factory: Callable[[bool], dm_env.Environment], + network_factory: Callable[[acme_specs.BoundedArray], Dict[str, snt.Module]], + mixer: snt.Module, + exploration_scheduler_fn: Union[ + EpsilonScheduler, + Mapping[str, EpsilonScheduler], + Mapping[str, Mapping[str, EpsilonScheduler]], + ], + logger_factory: Callable[[str], MavaLogger] = None, + architecture: Type[ + DecentralisedValueActor + ] = DecentralisedValueActor, + trainer_fn: Type[ValueDecompositionRecurrentTrainer] = ValueDecompositionRecurrentTrainer, + executor_fn: Type[MADQNRecurrentExecutor] = MADQNRecurrentExecutor, + num_executors: int = 1, + trainer_networks: Union[ + Dict[str, List], enums.Trainer + ] = enums.Trainer.single_trainer, + network_sampling_setup: Union[ + List, enums.NetworkSampler + ] = enums.NetworkSampler.fixed_agent_networks, + shared_weights: bool = True, + environment_spec: mava_specs.MAEnvironmentSpec = None, + discount: float = 0.99, + batch_size: int = 32, + prefetch_size: int = 4, + target_averaging: bool = False, + target_update_period: int = 200, + target_update_rate: Optional[float] = None, + executor_variable_update_period: int = 200, + min_replay_size: int = 100, + max_replay_size: int = 5000, + samples_per_insert: Optional[float] = 32.0, + optimizer: Union[ + snt.Optimizer, Dict[str, snt.Optimizer] + ] = snt.optimizers.Adam(learning_rate=1e-4), + mixer_optimizer: snt.Optimizer = snt.optimizers.Adam(learning_rate=1e-4), + sequence_length: int = 20, + period: int = 10, + max_gradient_norm: float = None, + checkpoint: bool = True, + checkpoint_subpath: str = "~/mava/", + checkpoint_minute_interval: int = 5, + logger_config: Dict = {}, + train_loop_fn: Callable = ParallelEnvironmentLoop, + eval_loop_fn: Callable = ParallelEnvironmentLoop, + train_loop_fn_kwargs: Dict = {}, + eval_loop_fn_kwargs: Dict = {}, + termination_condition: Optional[Dict[str, int]] = None, + evaluator_interval: Optional[dict] = None, + 
learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + ): + """Initialise the system + Args: + environment_factory: function to + instantiate an environment. + network_factory: function to instantiate system networks. + logger_factory: function to + instantiate a system logger. + architecture: + system architecture, e.g. decentralised or centralised. + trainer_fn: training type + associated with executor and architecture, e.g. centralised training. + executor_fn: executor type, e.g. + feedforward or recurrent. + num_executors: number of executor processes to run in + parallel.. + environment_spec: description of + the action, observation spaces etc. for each agent in the system. + trainer_networks: networks each + trainer trains on. + network_sampling_setup: List of networks that are randomly + sampled from by the executors at the start of an environment run. + enums.NetworkSampler settings: + fixed_agent_networks: Keeps the networks + used by each agent fixed throughout training. + random_agent_networks: Creates N network policies, where N is the + number of agents. Randomly select policies from this sets for each + agent at the start of a episode. This sampling is done with + replacement so the same policy can be selected for more than one + agent for a given episode. + Custom list: Alternatively one can specify a custom nested list, + with network keys in, that will be used by the executors at + the start of each episode to sample networks for each agent. + shared_weights: whether agents should share weights or not. + When network_sampling_setup are provided the value of shared_weights is + ignored. + discount: discount factor to use for TD updates. + batch_size: sample batch size for updates. + prefetch_size: size to prefetch from replay. + target_averaging: whether to use polyak averaging for + target network updates. + target_update_period: number of steps before target + networks are updated. + target_update_rate: update rate when using + averaging. + executor_variable_update_period: number of steps before + updating executor variables from the variable source. + min_replay_size: minimum replay size before updating. + max_replay_size: maximum replay size. + samples_per_insert: number of samples to take + from replay for every insert that is made. + policy_optimizer: optimizer(s) for updating policy networks. + critic_optimizer: optimizer for updating critic + networks. + n_step: number of steps to include prior to boostrapping. + sequence_length: recurrent sequence rollout length. + period: Consecutive starting points for overlapping + rollouts across a sequence. + max_gradient_norm: maximum allowed norm for gradients + before clipping is applied. + checkpoint: whether to checkpoint models. + checkpoint_minute_interval: The number of minutes to wait between + checkpoints. + checkpoint_subpath: subdirectory specifying where to store + checkpoints. + logger_config: additional configuration settings for the + logger factory. + train_loop_fn: function to instantiate a train loop. + eval_loop_fn: function to instantiate an evaluation + loop. + train_loop_fn_kwargs: possible keyword arguments to send + to the training loop. + eval_loop_fn_kwargs: possible keyword arguments to send to + the evaluation loop. + termination_condition: An optional terminal condition can be + provided that stops the program once the condition is + satisfied. 
Available options include specifying maximum + values for trainer_steps, trainer_walltime, evaluator_steps, + evaluator_episodes, executor_episodes or executor_steps. + E.g. termination_condition = {'trainer_steps': 100000}. + learning_rate_scheduler_fn: dict with two functions/classes (one for the + policy and one for the critic optimizer), that takes in a trainer + step t and returns the current learning rate, + e.g. {"policy": policy_lr_schedule ,"critic": critic_lr_schedule}. + See + examples/debugging/simple_spread/feedforward/decentralised/run_maddpg_lr_schedule.py + for an example. + evaluator_interval: An optional condition that is used to + evaluate/test system performance after [evaluator_interval] + condition has been met. If None, evaluation will + happen at every timestep. + E.g. to evaluate a system after every 100 executor episodes, + evaluator_interval = {"executor_episodes": 100}. + """ + super().__init__( + environment_factory=environment_factory, + network_factory=network_factory, + exploration_scheduler_fn=exploration_scheduler_fn, + logger_factory=logger_factory, + architecture=architecture, + trainer_fn=trainer_fn, + executor_fn=executor_fn, + num_executors=num_executors, + trainer_networks=trainer_networks, + network_sampling_setup=network_sampling_setup, + shared_weights=shared_weights, + environment_spec=environment_spec, + discount=discount, + batch_size=batch_size, + prefetch_size=prefetch_size, + target_averaging=target_averaging, + target_update_period=target_update_period, + target_update_rate=target_update_rate, + executor_variable_update_period=executor_variable_update_period, + min_replay_size=min_replay_size, + max_replay_size=max_replay_size, + samples_per_insert=samples_per_insert, + optimizer=optimizer, + sequence_length=sequence_length, + period=period, + max_gradient_norm=max_gradient_norm, + checkpoint=checkpoint, + checkpoint_subpath=checkpoint_subpath, + checkpoint_minute_interval=checkpoint_minute_interval, + logger_config=logger_config, + train_loop_fn=train_loop_fn, + eval_loop_fn=eval_loop_fn, + train_loop_fn_kwargs=train_loop_fn_kwargs, + eval_loop_fn_kwargs=eval_loop_fn_kwargs, + termination_condition=termination_condition, + evaluator_interval=evaluator_interval, + learning_rate_scheduler_fn=learning_rate_scheduler_fn, + ) + + self._mixer = mixer + self._mixer_optimizer = mixer_optimizer + + def trainer( + self, + trainer_id: str, + replay: reverb.Client, + variable_source: MavaVariableSource, + ) -> mava.core.Trainer: + """System trainer + Args: + trainer_id: Id of the trainer being created. + replay: replay data table to pull data from. + variable_source: variable server for updating + network variables. + Returns: + system trainer. 
+ """ + + # create logger + trainer_logger_config = {} + if self._logger_config and "trainer" in self._logger_config: + trainer_logger_config = self._logger_config["trainer"] + trainer_logger = self._logger_factory( # type: ignore + trainer_id, **trainer_logger_config + ) + + # Create the system + networks = self.create_system() + + dataset = self._builder.make_dataset_iterator(replay, trainer_id) + + trainer = self._builder.make_trainer( + networks=networks, + trainer_networks=self._trainer_networks[trainer_id], + trainer_table_entry=self._table_network_config[trainer_id], + dataset=dataset, + logger=trainer_logger, + variable_source=variable_source, + ) + + trainer.setup_mixer(self._mixer, self._mixer_optimizer) + + return trainer \ No newline at end of file diff --git a/mava/systems/tf/value_decomposition/training.py b/mava/systems/tf/value_decomposition/training.py new file mode 100644 index 000000000..aa9c5f363 --- /dev/null +++ b/mava/systems/tf/value_decomposition/training.py @@ -0,0 +1,362 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Value Decomposition trainer implementation.""" + +import copy +import time +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf +import tree +import trfl +from acme.tf import losses +from acme.tf import utils as tf2_utils +from acme.utils import loggers + +import mava +from mava import types as mava_types +from mava.adders.reverb.base import Trajectory +from mava.components.tf.losses.sequence import recurrent_n_step_critic_loss +from mava.systems.tf.madqn.execution import MADQNFeedForwardExecutor +from mava.systems.tf.madqn.training import MADQNRecurrentTrainer +from mava.systems.tf.variable_utils import VariableClient +from mava.utils import training_utils as train_utils +from mava.utils.sort_utils import sort_str_num + +train_utils.set_growing_gpu_memory() + + +class ValueDecompositionRecurrentTrainer(MADQNRecurrentTrainer): + """MADQN trainer. + This is the trainer component of a MADDPG system. IE it takes a dataset as input + and implements update functionality to learn from this dataset. + """ + + def __init__( + self, + agents: List[str], + agent_types: List[str], + value_networks: Dict[str, snt.Module], + target_value_networks: Dict[str, snt.Module], + optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], + discount: float, + target_averaging: bool, + target_update_period: int, + target_update_rate: float, + dataset: tf.data.Dataset, + observation_networks: Dict[str, snt.Module], + target_observation_networks: Dict[str, snt.Module], + variable_client: VariableClient, + counts: Dict[str, Any], + agent_net_keys: Dict[str, str], + max_gradient_norm: float = None, + logger: loggers.Logger = None, + learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + ): + """Initialise MADDPG trainer + Args: + agents: agent ids, e.g. "agent_0". 
+ agent_types: agent types, e.g. "speaker" or "listener". + policy_networks: policy networks for each agent in + the system. + critic_networks: critic network(s), shared or for + each agent in the system. + target_policy_networks: target policy networks. + target_critic_networks: target critic networks. + policy_optimizer: + optimizer(s) for updating policy networks. + critic_optimizer: + optimizer for updating critic networks. + discount: discount factor for TD updates. + target_averaging: whether to use polyak averaging for target network + updates. + target_update_period: number of steps before target networks are + updated. + target_update_rate: update rate when using averaging. + dataset: training dataset. + observation_networks: network for feature + extraction from raw observation. + target_observation_networks: target observation + network. + variable_client: The client used to manage the variables. + counts: step counter object. + agent_net_keys: specifies what network each agent uses. + max_gradient_norm: maximum allowed norm for gradients + before clipping is applied. + logger: logger object for logging trainer + statistics. + learning_rate_scheduler_fn: dict with two functions (one for the policy and + one for the critic optimizer), that takes in a trainer step t and + returns the current learning rate. + """ + + super().__init__( + agents=agents, + agent_types=agent_types, + value_networks=value_networks, + target_value_networks=target_value_networks, + optimizer=optimizer, + discount=discount, + target_averaging=target_averaging, + target_update_period=target_update_period, + target_update_rate=target_update_rate, + dataset=dataset, + observation_networks=observation_networks, + target_observation_networks=target_observation_networks, + variable_client=variable_client, + counts=counts, + agent_net_keys=agent_net_keys, + max_gradient_norm=max_gradient_norm, + logger=logger, + learning_rate_scheduler_fn=learning_rate_scheduler_fn, + ) + + self._mixer = None + self._target_mixer = None + self._mixer_optimizer = None + + def setup_mixer(self, mixer: snt.Module, mixer_optimizer: snt.Module): + self._mixer = mixer + self._target_mixer = copy.deepcopy(mixer) + self._mixer_optimizer = mixer_optimizer + + def _update_target_networks(self) -> None: + """Update the target networks using either target averaging or + by directy copying the weights of the online networks every few steps.""" + + online_variables = [] + target_variables = [] + for key in self.unique_net_keys: + # Update target network. + online_variables += list(( + *self._observation_networks[key].variables, + *self._value_networks[key].variables, + )) + target_variables += list(( + *self._target_observation_networks[key].variables, + *self._target_value_networks[key].variables, + )) + # Add mixer variables + online_variables += list(( + *self._mixer.variables, + )) + target_variables += list(( + *self._target_mixer.variables, + )) + + if self._target_averaging: + assert 0.0 < self._target_update_rate < 1.0 + tau = self._target_update_rate + for src, dest in zip(online_variables, target_variables): + dest.assign(dest * (1.0 - tau) + src * tau) + else: + # Make online -> target network update ops. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(online_variables, target_variables): + dest.assign(src) + + self._num_steps.assign_add(1) + + # Forward pass that calculates loss. 
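Before the full forward pass defined next, here is a compact sketch of the per-agent quantities it builds (toy shapes; mixer, observation networks and recurrent unrolls omitted): the online network picks the greedy action over legal actions only, the target network evaluates that action, and a one-step Bellman target is formed against the chosen-action value.

import tensorflow as tf
import trfl

T, B, A = 6, 4, 9                                   # time, batch, num actions
q_online = tf.random.uniform((T, B, A))             # online network unroll
q_target = tf.random.uniform((T, B, A))             # target network unroll
legal = tf.ones((T, B, A), dtype=tf.bool)           # legal-action mask
actions = tf.random.uniform((T, B), maxval=A, dtype=tf.int32)
rewards = tf.random.uniform((T, B))
discounts = tf.ones((T, B))

chosen_q = trfl.batched_index(q_online, actions)    # Q(o_t, a_t)
selector = tf.where(legal, q_online, -999999999.0)  # mask out illegal actions
greedy_a = tf.argmax(selector, axis=-1)             # argmax from the online net
greedy_q = trfl.batched_index(q_target, greedy_a)   # evaluated by the target net

target = tf.stop_gradient(rewards[:-1] + discounts[:-1] * greedy_q[1:])
td_error = target - chosen_q[:-1]
loss = tf.reduce_mean(0.5 * tf.square(td_error))

In the trainer below, chosen_q and greedy_q are additionally stacked across agents and passed through the mixer and target mixer before the target and TD error are computed.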
+ def _forward(self, inputs: reverb.ReplaySample) -> None: + """Trainer forward pass + Args: + inputs: input data from the data table (transitions) + """ + # Convert to time major + data = tree.map_structure( + lambda v: tf.expand_dims(v, axis=0) if len(v.shape) <= 1 else v, inputs.data + ) + data = tf2_utils.batch_to_sequence(data) + + # Note (dries): The unused variable is start_of_episodes. + observations, actions, rewards, discounts, _, extras = ( + data.observations, + data.actions, + data.rewards, + data.discounts, + data.start_of_episode, + data.extras, + ) + + # Global environment state + if "s_t" in extras: + global_env_state = extras["s_t"] + else: + global_env_state = None + + # Get initial state for the LSTM from replay and + # extract the first state in the sequence. + core_state = tree.map_structure(lambda s: s[0, :, :], extras["core_states"]) + target_core_state = tree.map_structure(lambda s: s[0, :, :], extras["core_states"]) + + # TODO (dries): Take out all the data_points that does not need + # to be processed here at the start. Therefore it does not have + # to be done later on and saves processing time. + # NOTE (Claude) or do zeropadding mask + + self.value_losses: Dict[str, tf.Tensor] = {} + # Do forward passes through the networks and calculate the losses + with tf.GradientTape(persistent=True) as tape: + # Note (dries): We are assuming that only the policy network + # is recurrent and not the observation network. + obs_trans, target_obs_trans = self._transform_observations(observations) + + # Lists for stacking tensors later + chosen_action_q_value_all_agents = [] + max_action_q_value_all_agents = [] + reward_all_agents = [] + env_discount_all_agents = [] + for agent in self._agents: + agent_key = self._agent_net_keys[agent] + + # Double Q-learning + q_tm1_values, _ = snt.static_unroll( + self._value_networks[agent_key], obs_trans[agent], core_state[agent][0] + ) + # Q-value of the action taken by agent + chosen_action_q_value = trfl.batched_index( + q_tm1_values, actions[agent] + ) + + + # Q-value of the next state + q_t_selector = tf.where( + tf.cast(observations[agent].legal_actions, 'bool'), + q_tm1_values, -999999999 + ) + q_t_values, _ = snt.static_unroll( + self._target_value_networks[agent_key], + target_obs_trans[agent], + target_core_state[agent][0] + ) + max_action = tf.argmax(q_t_selector, axis=-1) + max_action_q_value = trfl.batched_index( + q_t_values, + max_action + ) + + + # Append agent values to lists + chosen_action_q_value_all_agents.append(chosen_action_q_value) + max_action_q_value_all_agents.append(max_action_q_value) + reward_all_agents.append(rewards[agent]) + env_discount_all_agents.append(discounts[agent]) + + # Stack list of tensors into tensor with trailing agent dim + chosen_action_q_value_all_agents = tf.stack( + chosen_action_q_value_all_agents, axis=-1 + ) # shape=(T,B, Num_Agents) + max_action_q_value_all_agents = tf.stack( + max_action_q_value_all_agents, axis=-1 + ) # shape=(T,B, Num_Agents) + reward_all_agents = tf.stack(reward_all_agents, axis=-1) + env_discount_all_agents = tf.stack(env_discount_all_agents, axis=-1) + + # Mixing + chosen_action_q_value_all_agents = self._mixer( + chosen_action_q_value_all_agents, + states=global_env_state, + ) + max_action_q_value_all_agents = self._target_mixer( + max_action_q_value_all_agents, + states=global_env_state + ) + # NOTE Team reward is just the mean over agents indevidual rewards + reward_all_agents = tf.reduce_mean( + reward_all_agents, axis=-1, keepdims=True + ) + # NOTE We 
assume all agents have the same env discount since + # it is a team game + env_discount_all_agents = tf.reduce_mean( + env_discount_all_agents, axis=-1, keepdims=True + ) + + # Cast the additional discount to match + # the environment discount dtype. + discount = tf.cast(self._discount, dtype=discounts[agent].dtype) + pcont = discount * env_discount_all_agents + + # Bellman target + target = tf.stop_gradient( + reward_all_agents[:-1] + pcont[:-1] * max_action_q_value_all_agents[1:] + ) + + # Temporal difference error and loss. + td_error = target - chosen_action_q_value_all_agents[:-1] + + # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error. + value_loss = 0.5 * tf.square(td_error) + value_loss = tf.reduce_mean(value_loss) + + # TODO zero padding mask + + self.value_losses = {agent: value_loss for agent in self._agents} + self.mixer_loss = value_loss + + self.tape = tape + + # Backward pass that calculates gradients and updates network. + def _backward(self) -> None: + """Trainer backward pass updating network parameters""" + + # Calculate the gradients and update the networks + value_losses = self.value_losses + mixer_loss = self.mixer_loss + tape = self.tape + for agent in self._trainer_agent_list: + agent_key = self._agent_net_keys[agent] + + # Get trainable variables. + variables = ( + self._observation_networks[agent_key].trainable_variables + + self._value_networks[agent_key].trainable_variables + ) + + # Compute gradients. + # Note: Warning "WARNING:tensorflow:Calling GradientTape.gradient + # on a persistent tape inside its context is significantly less efficient + # than calling it outside the context." caused by losses.dpg, which calls + # tape.gradient. + gradients = tape.gradient(value_losses[agent], variables) + + # Maybe clip gradients. + gradients = tf.clip_by_global_norm( + gradients, self._max_gradient_norm + )[0] + + # Apply gradients. + self._optimizers[agent_key].apply(gradients, variables) + + # TODO (Claude) what happens when there are multiple trainers @Dries + # Mixer + mixer_variables = self._mixer.trainable_variables + + gradients = tape.gradient(mixer_loss, mixer_variables) + + # Clip gradients. + gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0] + + # Apply gradients. + if mixer_variables: + tf.print("OPTIMIZING MIXER") + self._mixer_optimizer.apply(gradients, mixer_variables) + + train_utils.safe_del(self, "tape") \ No newline at end of file diff --git a/mava/wrappers/env_preprocess_wrappers.py b/mava/wrappers/env_preprocess_wrappers.py index eb2994cba..700a3579a 100644 --- a/mava/wrappers/env_preprocess_wrappers.py +++ b/mava/wrappers/env_preprocess_wrappers.py @@ -19,7 +19,8 @@ import numpy as np from pettingzoo.utils import BaseParallelWraper from supersuit.utils.base_aec_wrapper import BaseWrapper - +import dm_env +from mava.types import OLT from mava.types import Action, Observation, Reward from mava.utils.wrapper_utils import RunningMeanStd from mava.wrappers.env_wrappers import ParallelEnvWrapper, SequentialEnvWrapper @@ -368,14 +369,77 @@ class ConcatAgentIdToObservation: self.possible_agents. 
""" - def __init__(self): - pass + def __init__(self, environment): + self._environment = environment + self._num_agents = len(environment.possible_agents) def reset(self): - pass + timestep, extras = self._environment.reset() + old_observations = timestep.observation + + new_observations = {} + + for agent_id, agent in enumerate(self._environment.possible_agents): + agent_olt = old_observations[agent] + + agent_observation = agent_olt.observation + agent_one_hot = np.zeros(self._num_agents, dtype = agent_observation.dtype) + agent_one_hot[agent_id] = 1 + + new_observations[agent] = OLT( + observation = np.concatenate([agent_one_hot, agent_observation]), + legal_actions= agent_olt.legal_actions, + terminal=agent_olt.terminal + ) + + return dm_env.TimeStep(timestep.step_type, timestep.reward, timestep.discount, new_observations), extras + def step(self, actions: Dict) -> Any: - pass + timestep, extras = self._environment.step(actions) + + old_observations = timestep.observation + new_observations = {} + for agent_id, agent in enumerate(self._environment.possible_agents): + agent_olt = old_observations[agent] + + agent_observation = agent_olt.observation + agent_one_hot = np.zeros(self._num_agents, dtype = agent_observation.dtype) + agent_one_hot[agent_id] = 1 + + new_observations[agent] = OLT( + observation = np.concatenate([agent_one_hot, agent_observation]), + legal_actions=agent_olt.legal_actions, + terminal=agent_olt.terminal + ) + + + return dm_env.TimeStep(timestep.step_type, timestep.reward, timestep.discount, new_observations), extras + + def observation_spec(self) -> Dict[str, OLT]: + """Observation spec. + + Returns: + types.Observation: spec for environment. + """ + timestep, extras = self.reset() + observations = timestep.observation + return observations + + def __getattr__(self, name: str) -> Any: + """Expose any other attributes of the underlying environment. + + Args: + name (str): attribute. + + Returns: + Any: return attribute from env or underlying env. + """ + if hasattr(self.__class__, name): + return self.__getattribute__(name) + else: + return getattr(self._environment, name) + class ConcatPrevActionToObservation: """Concat one-hot vector of agent prev_action to obs. @@ -384,12 +448,21 @@ class ConcatPrevActionToObservation: TODO support continuous actions. """ - - def __init__(self): + + def __init__(self, environment): pass + self._env + self._prev_act + self._num_agents + def reset(self): + prev_act = zero_vector pass def step(self, actions: Dict) -> Any: - pass + timestep, extras = self.env.step(actions) + obs = concat(self.prev_actions, obs) + + self.prev_actions = actions + return timestep, extras diff --git a/mava/wrappers/smac.py b/mava/wrappers/smac.py index 9dc19f6f2..a9b06627d 100644 --- a/mava/wrappers/smac.py +++ b/mava/wrappers/smac.py @@ -35,7 +35,7 @@ class SMACWrapper(ParallelEnvWrapper): def __init__( self, environment: StarCraft2Env, - return_state_info: bool = False, + return_state_info: bool = True, ): """Constructor for parallel PZ wrapper. @@ -279,8 +279,7 @@ def get_stats(self) -> Optional[Dict]: Returns: extra stats to be logged. 
""" - pass - # return {"win_rate": self._info["win_rate"]} + return self._info @property def agents(self) -> List: From 7d2be2ec3537b73064197b633a5fd0e26492358d Mon Sep 17 00:00:00 2001 From: RuanJohn Date: Thu, 20 Jan 2022 15:49:12 +0200 Subject: [PATCH 06/56] Added previous agent actions to wrapper --- .../smac/feedforward/decentralised/test.py | 8 +- mava/wrappers/env_preprocess_wrappers.py | 86 ++++++++++++++++--- 2 files changed, 79 insertions(+), 15 deletions(-) diff --git a/examples/smac/feedforward/decentralised/test.py b/examples/smac/feedforward/decentralised/test.py index 6cf7c8aa7..f5cf84d45 100644 --- a/examples/smac/feedforward/decentralised/test.py +++ b/examples/smac/feedforward/decentralised/test.py @@ -1,7 +1,7 @@ from smac.env import StarCraft2Env from mava.wrappers import SMACWrapper from mava.wrappers.env_preprocess_wrappers import ConcatAgentIdToObservation - +from mava.wrappers.env_preprocess_wrappers import ConcatPrevActionToObservation import numpy as np env = StarCraft2Env(map_name="3m") @@ -10,12 +10,16 @@ env = ConcatAgentIdToObservation(env) +env = ConcatPrevActionToObservation(env) + spec = env.action_spec() +# for agent in spec: +# print(spec[agent].num_values) spec = env.observation_spec() res = env.reset() -actions = {"agent_0": 1, "agent_1": 1, "agent_2": 1} +actions = {"agent_0": 1, "agent_1": 2, "agent_2": 3} res = env.step(actions) diff --git a/mava/wrappers/env_preprocess_wrappers.py b/mava/wrappers/env_preprocess_wrappers.py index 700a3579a..9f462d2a6 100644 --- a/mava/wrappers/env_preprocess_wrappers.py +++ b/mava/wrappers/env_preprocess_wrappers.py @@ -448,21 +448,81 @@ class ConcatPrevActionToObservation: TODO support continuous actions. """ - - def __init__(self, environment): - pass - - self._env - self._prev_act - self._num_agents + # Need to get the size of the action space of each agent + def __init__(self, environment): + self._environment = environment + def reset(self): - prev_act = zero_vector - pass + # Previous actions needs to be somethings like a dictionary containing zero vectors of + # length of the permitted action space per agent + self._prev_actions = {} + action_spec = self._environment.action_spec() + for agent in action_spec: + self._prev_actions[agent] = np.zeros(action_spec[agent].num_values) + + timestep, extras = self._environment.reset() + old_observations = timestep.observation + + new_observations = {} + #TODO double check this, because possible agents could shrink + for agent in self._environment.possible_agents: + agent_olt = old_observations[agent] + + agent_observation = agent_olt.observation + agent_one_hot_action = self._prev_actions[agent] + + new_observations[agent] = OLT( + observation = np.concatenate([agent_one_hot_action, agent_observation]), + legal_actions= agent_olt.legal_actions, + terminal=agent_olt.terminal + ) + + return dm_env.TimeStep(timestep.step_type, timestep.reward, timestep.discount, new_observations), extras def step(self, actions: Dict) -> Any: - timestep, extras = self.env.step(actions) - obs = concat(self.prev_actions, obs) + timestep, extras = self._environment.step(actions) + old_observations = timestep.observation + + new_observations = {} + + for agent in self._environment.possible_agents: + agent_olt = old_observations[agent] + + agent_observation = agent_olt.observation + agent_one_hot_action = self._prev_actions[agent] + agent_one_hot_action[actions[agent]] = 1 + + new_observations[agent] = OLT( + observation = np.concatenate([agent_one_hot_action, agent_observation]), + 
legal_actions= agent_olt.legal_actions, + terminal=agent_olt.terminal + ) - self.prev_actions = actions - return timestep, extras + + self._prev_actions = actions + return dm_env.TimeStep(timestep.step_type, timestep.reward, timestep.discount, new_observations), extras + + def observation_spec(self) -> Dict[str, OLT]: + """Observation spec. + + Returns: + types.Observation: spec for environment. + """ + timestep, extras = self.reset() + observations = timestep.observation + return observations + + def __getattr__(self, name: str) -> Any: + """Expose any other attributes of the underlying environment. + + Args: + name (str): attribute. + + Returns: + Any: return attribute from env or underlying env. + """ + if hasattr(self.__class__, name): + return self.__getattribute__(name) + else: + return getattr(self._environment, name) \ No newline at end of file From 23b438a6f2b6cf862fbae1990ab6a1c182041f71 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Fri, 21 Jan 2022 11:25:13 +0200 Subject: [PATCH 07/56] Working QMIX. --- .../smac/recurrent/decentralised/run_qmix.py | 8 +++--- .../smac/recurrent/decentralised/run_vdn.py | 2 +- mava/systems/tf/madqn/builder.py | 5 ++++ mava/systems/tf/madqn/system.py | 5 +++- mava/systems/tf/madqn/training.py | 26 +++++++++++-------- mava/systems/tf/value_decomposition/system.py | 6 +++++ .../tf/value_decomposition/training.py | 1 - mava/wrappers/env_preprocess_wrappers.py | 22 ++++------------ 8 files changed, 41 insertions(+), 34 deletions(-) diff --git a/examples/smac/recurrent/decentralised/run_qmix.py b/examples/smac/recurrent/decentralised/run_qmix.py index fe0b8b3fa..00e1cd4c4 100644 --- a/examples/smac/recurrent/decentralised/run_qmix.py +++ b/examples/smac/recurrent/decentralised/run_qmix.py @@ -30,7 +30,7 @@ from mava.utils.loggers import logger_utils from mava.systems.tf.value_decomposition.mixer import QMIX - +from mava.wrappers.env_preprocess_wrappers import ConcatAgentIdToObservation, ConcatPrevActionToObservation from smac.env import StarCraft2Env from mava.wrappers import SMACWrapper @@ -38,7 +38,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - "3m", + "8m_vs_9m", "Starcraft 2 micromanagement map name (str).", ) @@ -52,6 +52,8 @@ def smac_env_factory(env_name="3m", evaluation = False): env = StarCraft2Env(map_name=env_name) env = SMACWrapper(env) + env = ConcatPrevActionToObservation(env) + env = ConcatAgentIdToObservation(env) return env @@ -92,7 +94,7 @@ def main(_: Any) -> None: logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=3e-5 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 diff --git a/examples/smac/recurrent/decentralised/run_vdn.py b/examples/smac/recurrent/decentralised/run_vdn.py index 7c486533d..99f196c77 100644 --- a/examples/smac/recurrent/decentralised/run_vdn.py +++ b/examples/smac/recurrent/decentralised/run_vdn.py @@ -90,7 +90,7 @@ def main(_: Any) -> None: logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=6e-5 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 diff --git a/mava/systems/tf/madqn/builder.py b/mava/systems/tf/madqn/builder.py index 1b5b79521..dce3b3f33 100644 --- a/mava/systems/tf/madqn/builder.py +++ 
b/mava/systems/tf/madqn/builder.py @@ -519,6 +519,7 @@ def make_trainer( target_averaging = self._config.target_averaging target_update_rate = self._config.target_update_rate + print("4") # Create variable client variables = {} set_keys = [] @@ -535,6 +536,7 @@ def make_trainer( else: get_keys.append(f"{net_key}_{net_type_key}") + print("5") variables = self.create_counter_variables(variables) count_names = [ @@ -548,6 +550,7 @@ def make_trainer( get_keys.extend(count_names) counts = {name: variables[name] for name in count_names} + print("6") variable_client = variable_utils.VariableClient( client=variable_source, variables=variables, @@ -556,9 +559,11 @@ def make_trainer( update_period=10, ) + print("7") # Get all the initial variables variable_client.get_all_and_wait() + print("8") # Convert network keys for the trainer. trainer_agents = self._agents[: len(trainer_table_entry)] trainer_agent_net_keys = { diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index 065c7ac7e..5babe2295 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -605,7 +605,7 @@ def trainer( Returns: system trainer. """ - + print("$$$$$$$$$$$$$$$$$") # create logger trainer_logger_config = {} if self._logger_config and "trainer" in self._logger_config: @@ -614,11 +614,14 @@ def trainer( trainer_id, **trainer_logger_config ) + print("1") # Create the system networks = self.create_system() + print("2") dataset = self._builder.make_dataset_iterator(replay, trainer_id) + print("3") return self._builder.make_trainer( networks=networks, trainer_networks=self._trainer_networks[trainer_id], diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index 4046d64b8..f7a138310 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -237,7 +237,6 @@ def _transform_observations( o_t[agent] = tree.map_structure(tf.stop_gradient, o_t[agent]) return o_tm1, o_t - @tf.function def _step( self, ) -> Dict[str, Dict[str, Any]]: @@ -247,21 +246,24 @@ def _step( losses """ - # Update the target networks - self._update_target_networks() - # Draw a batch of data from replay. sample: reverb.ReplaySample = next(self._iterator) - self._forward(sample) + self._forward_backward(sample) - self._backward() + # Update the target networks + self._update_target_networks() # Log losses per agent return train_utils.map_losses_per_agent_value( self.value_losses ) + @tf.function + def _forward_backward(self, inputs: Any) -> Dict[str, Dict[str, Any]]: + self._forward(inputs) + self._backward() + # Forward pass that calculates loss. def _forward(self, inputs: reverb.ReplaySample) -> None: """Trainer forward pass @@ -445,8 +447,6 @@ def __init__( one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. """ - #self._bootstrap_n = bootstrap_n - self._agents = agents self._agent_type = agent_types self._agent_net_keys = agent_net_keys @@ -522,6 +522,11 @@ def __init__( # fill the replay buffer. self._timestamp: Optional[float] = None + def step(self) -> None: + """trainer step to update the parameters of the agents in the system""" + + raise NotImplementedError("A trainer statistics wrapper should overwrite this.") + def _transform_observations( self, observations: Dict[str, mava_types.OLT] ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: @@ -602,7 +607,7 @@ def _step( Returns: losses """ - # Draw a batch of data from replay. + # # Draw a batch of data from replay. 
sample: reverb.ReplaySample = next(self._iterator) self._forward(sample) @@ -702,7 +707,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # TODO zero padding mask - self.value_losses[agent] = tf.reduce_mean(value_loss, axis=0) + self.value_losses[agent] = tf.reduce_mean(value_loss) self.tape = tape @@ -722,7 +727,6 @@ def _backward(self) -> None: + self._value_networks[agent_key].trainable_variables ) - # Compute gradients. # Note: Warning "WARNING:tensorflow:Calling GradientTape.gradient # on a persistent tape inside its context is significantly less efficient diff --git a/mava/systems/tf/value_decomposition/system.py b/mava/systems/tf/value_decomposition/system.py index 929a8a608..171d177d1 100644 --- a/mava/systems/tf/value_decomposition/system.py +++ b/mava/systems/tf/value_decomposition/system.py @@ -249,6 +249,8 @@ def trainer( system trainer. """ + print("##########") + print("1") # create logger trainer_logger_config = {} if self._logger_config and "trainer" in self._logger_config: @@ -257,11 +259,14 @@ def trainer( trainer_id, **trainer_logger_config ) + print("2") # Create the system networks = self.create_system() + print("3") dataset = self._builder.make_dataset_iterator(replay, trainer_id) + print("4") trainer = self._builder.make_trainer( networks=networks, trainer_networks=self._trainer_networks[trainer_id], @@ -271,6 +276,7 @@ def trainer( variable_source=variable_source, ) + print("5") trainer.setup_mixer(self._mixer, self._mixer_optimizer) return trainer \ No newline at end of file diff --git a/mava/systems/tf/value_decomposition/training.py b/mava/systems/tf/value_decomposition/training.py index aa9c5f363..c499fc7ff 100644 --- a/mava/systems/tf/value_decomposition/training.py +++ b/mava/systems/tf/value_decomposition/training.py @@ -356,7 +356,6 @@ def _backward(self) -> None: # Apply gradients. 
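        # Only the mixer's own parameters are updated in this block: for QMIX these
        # are the state-conditioned mixing/hypernetwork weights, while a VDN-style
        # mixer has no trainable variables, so `mixer_variables` is empty and the
        # apply step is simply skipped.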
if mixer_variables: - tf.print("OPTIMIZING MIXER") self._mixer_optimizer.apply(gradients, mixer_variables) train_utils.safe_del(self, "tape") \ No newline at end of file diff --git a/mava/wrappers/env_preprocess_wrappers.py b/mava/wrappers/env_preprocess_wrappers.py index 9f462d2a6..b2a34fe6a 100644 --- a/mava/wrappers/env_preprocess_wrappers.py +++ b/mava/wrappers/env_preprocess_wrappers.py @@ -453,24 +453,16 @@ class ConcatPrevActionToObservation: def __init__(self, environment): self._environment = environment - def reset(self): - # Previous actions needs to be somethings like a dictionary containing zero vectors of - # length of the permitted action space per agent - self._prev_actions = {} - action_spec = self._environment.action_spec() - for agent in action_spec: - self._prev_actions[agent] = np.zeros(action_spec[agent].num_values) - + def reset(self): timestep, extras = self._environment.reset() old_observations = timestep.observation - + action_spec = self._environment.action_spec() new_observations = {} #TODO double check this, because possible agents could shrink for agent in self._environment.possible_agents: agent_olt = old_observations[agent] - agent_observation = agent_olt.observation - agent_one_hot_action = self._prev_actions[agent] + agent_one_hot_action = np.zeros(action_spec[agent].num_values, dtype=np.float32) new_observations[agent] = OLT( observation = np.concatenate([agent_one_hot_action, agent_observation]), @@ -483,14 +475,12 @@ def reset(self): def step(self, actions: Dict) -> Any: timestep, extras = self._environment.step(actions) old_observations = timestep.observation - + action_spec = self._environment.action_spec() new_observations = {} - for agent in self._environment.possible_agents: agent_olt = old_observations[agent] - agent_observation = agent_olt.observation - agent_one_hot_action = self._prev_actions[agent] + agent_one_hot_action = np.zeros(action_spec[agent].num_values, dtype=np.float32) agent_one_hot_action[actions[agent]] = 1 new_observations[agent] = OLT( @@ -499,8 +489,6 @@ def step(self, actions: Dict) -> Any: terminal=agent_olt.terminal ) - - self._prev_actions = actions return dm_env.TimeStep(timestep.step_type, timestep.reward, timestep.discount, new_observations), extras def observation_spec(self) -> Dict[str, OLT]: From ee7b6fdac60147203bdd824f5a42d99acf9a5c44 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Fri, 21 Jan 2022 11:35:54 +0200 Subject: [PATCH 08/56] Small fixes in trainer. --- mava/systems/tf/madqn/training.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index f7a138310..bf9cc2788 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -249,21 +249,25 @@ def _step( # Draw a batch of data from replay. sample: reverb.ReplaySample = next(self._iterator) - self._forward_backward(sample) - - # Update the target networks - self._update_target_networks() + losses = self._forward_backward(sample) # Log losses per agent return train_utils.map_losses_per_agent_value( - self.value_losses + losses ) @tf.function def _forward_backward(self, inputs: Any) -> Dict[str, Dict[str, Any]]: + self._forward(inputs) + self._backward() + # Update the target networks + self._update_target_networks() + + return self.value_losses + # Forward pass that calculates loss. 
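    # The value loss assembled in _forward below is standard double Q-learning:
    # the online network picks the greedy next action (masking illegal actions out
    # of the argmax where a legal_actions mask is used in this module) and the
    # target network evaluates it. A minimal sketch of that target with
    # illustrative [batch, num_actions] tensors (names are not the trainer's own):

    import tensorflow as tf

    def double_q_target(r_t, d_t, discount, q_t_online, q_t_target, legals_t):
        # Mask illegal next actions before the online-network argmax.
        masked = tf.where(
            tf.cast(legals_t, tf.bool), q_t_online, -1e9 * tf.ones_like(q_t_online)
        )
        # Select with the online network, evaluate with the target network.
        best_a = tf.argmax(masked, axis=-1, output_type=tf.int32)
        q_next = tf.gather(q_t_target, best_a, batch_dims=1)
        return r_t + discount * d_t * tf.stop_gradient(q_next)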
def _forward(self, inputs: reverb.ReplaySample) -> None: """Trainer forward pass From 85cc19ed2229e2cdfcccbbaa903134c15740e117 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Fri, 21 Jan 2022 11:39:52 +0200 Subject: [PATCH 09/56] More small fixes in trainer. --- mava/systems/tf/madqn/training.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index bf9cc2788..bb884be07 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -258,7 +258,7 @@ def _step( @tf.function def _forward_backward(self, inputs: Any) -> Dict[str, Dict[str, Any]]: - + self._forward(inputs) self._backward() @@ -602,7 +602,6 @@ def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray] """Depricated""" pass - # @tf.function def _step( self, ) -> Dict[str, Dict[str, Any]]: @@ -611,20 +610,28 @@ def _step( Returns: losses """ - # # Draw a batch of data from replay. + + # Draw a batch of data from replay. sample: reverb.ReplaySample = next(self._iterator) - self._forward(sample) + losses = self._forward_backward(sample) + + # Log losses per agent + return train_utils.map_losses_per_agent_value( + losses + ) + + @tf.function + def _forward_backward(self, inputs: Any) -> Dict[str, Dict[str, Any]]: + + self._forward(inputs) self._backward() # Update the target networks self._update_target_networks() - # Log losses per agent - return train_utils.map_losses_per_agent_value( - self.value_losses - ) + return self.value_losses # Forward pass that calculates loss. def _forward(self, inputs: reverb.ReplaySample) -> None: From 737b020bef6db055cd0f5af9e6b3db60901cd1d5 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Mon, 24 Jan 2022 11:41:59 +0200 Subject: [PATCH 10/56] Benchmarking ready. --- .../recurrent/decentralised/run_madqn.py | 3 +- .../smac/recurrent/decentralised/run_madqn.py | 23 +++++------ .../smac/recurrent/decentralised/run_qmix.py | 30 ++++++-------- .../smac/recurrent/decentralised/run_vdn.py | 24 +++++------- mava/systems/tf/madqn/networks.py | 2 +- mava/systems/tf/madqn/system.py | 4 -- mava/systems/tf/madqn/training.py | 23 +++++------ mava/systems/tf/value_decomposition/system.py | 39 +++++++++---------- .../tf/value_decomposition/training.py | 18 ++------- mava/utils/environments/smac_utils.py | 28 +++++++++++++ mava/wrappers/debugging_envs.py | 29 +++++++++++--- mava/wrappers/smac.py | 4 +- 12 files changed, 120 insertions(+), 107 deletions(-) create mode 100644 mava/utils/environments/smac_utils.py diff --git a/examples/debugging/simple_spread/recurrent/decentralised/run_madqn.py b/examples/debugging/simple_spread/recurrent/decentralised/run_madqn.py index 1a9381997..889534147 100644 --- a/examples/debugging/simple_spread/recurrent/decentralised/run_madqn.py +++ b/examples/debugging/simple_spread/recurrent/decentralised/run_madqn.py @@ -57,11 +57,12 @@ def main(_: Any) -> None: debugging_utils.make_environment, env_name=FLAGS.env_name, action_space=FLAGS.action_space, + num_agents=10 ) # Networks. 
network_factory = lp_utils.partial_kwargs( - madqn.make_default_networks, archecture_type=ArchitectureType.recurrent + madqn.make_default_networks, architecture_type=ArchitectureType.recurrent ) # Checkpointer appends "Checkpoints" to checkpoint_dir diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index b2391c24f..63a91022c 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -29,15 +29,15 @@ from mava.components.tf.modules.exploration.exploration_scheduling import ( LinearExplorationScheduler, ) +from mava.wrappers.env_preprocess_wrappers import ConcatAgentIdToObservation, ConcatPrevActionToObservation from mava.components.tf.networks.epsilon_greedy import EpsilonGreedy from mava.systems.tf import madqn from mava.utils import lp_utils from mava.utils.environments import pettingzoo_utils from mava.utils.loggers import logger_utils from mava.utils.enums import ArchitectureType +from mava.utils.environments.smac_utils import make_environment -from smac.env import StarCraft2Env -from mava.wrappers import SMACWrapper FLAGS = flags.FLAGS @@ -54,18 +54,12 @@ ) flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") -def smac_env_factory(env_name="3m", evaluation = False): - env = StarCraft2Env(map_name=env_name) - env = SMACWrapper(env) - - return env - def main(_: Any) -> None: """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" # environment environment_factory = functools.partial( - smac_env_factory, env_name=FLAGS.map_name + make_environment, env_name=FLAGS.map_name ) # Networks. @@ -95,7 +89,7 @@ def main(_: Any) -> None: logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=8e-6 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 @@ -103,12 +97,15 @@ def main(_: Any) -> None: checkpoint_subpath=checkpoint_dir, batch_size=32, executor_variable_update_period=200, - target_update_period=200, + target_update_period=100, max_gradient_norm=20.0, sequence_length=60, period=60, - min_replay_size=100, - max_replay_size=4000, + min_replay_size=32, + max_replay_size=5000, + samples_per_insert=1, + evaluator_interval={"executor_episodes": 2}, + termination_condition={"executor_steps": 3_000_000}, trainer_fn=madqn.training.MADQNRecurrentTrainer, executor_fn=madqn.execution.MADQNRecurrentExecutor, ).build() diff --git a/examples/smac/recurrent/decentralised/run_qmix.py b/examples/smac/recurrent/decentralised/run_qmix.py index 00e1cd4c4..032caeb69 100644 --- a/examples/smac/recurrent/decentralised/run_qmix.py +++ b/examples/smac/recurrent/decentralised/run_qmix.py @@ -28,17 +28,13 @@ from mava.systems.tf import value_decomposition from mava.utils import lp_utils from mava.utils.loggers import logger_utils - -from mava.systems.tf.value_decomposition.mixer import QMIX -from mava.wrappers.env_preprocess_wrappers import ConcatAgentIdToObservation, ConcatPrevActionToObservation -from smac.env import StarCraft2Env -from mava.wrappers import SMACWrapper +from mava.utils.environments.smac_utils import make_environment FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - "8m_vs_9m", + "3m", "Starcraft 2 micromanagement map name (str).", ) @@ -49,20 +45,14 @@ ) flags.DEFINE_string("base_dir", "~/mava", "Base dir to store 
experiments.") -def smac_env_factory(env_name="3m", evaluation = False): - env = StarCraft2Env(map_name=env_name) - env = SMACWrapper(env) - env = ConcatPrevActionToObservation(env) - env = ConcatAgentIdToObservation(env) - return env def main(_: Any) -> None: """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" # environment environment_factory = functools.partial( - smac_env_factory, env_name=FLAGS.map_name + make_environment, map_name=FLAGS.map_name ) # Networks. @@ -84,17 +74,15 @@ def main(_: Any) -> None: time_delta=log_every, ) - num_agents = len(environment_factory().possible_agents) - # distributed program program = value_decomposition.ValueDecomposition( environment_factory=environment_factory, network_factory=network_factory, - mixer=QMIX(num_agents=num_agents), + mixer="qmix", logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=8e-6 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 @@ -102,12 +90,16 @@ def main(_: Any) -> None: checkpoint_subpath=checkpoint_dir, batch_size=32, executor_variable_update_period=200, - target_update_period=200, + target_update_period=100, max_gradient_norm=20.0, sequence_length=60, period=60, - min_replay_size=100, + min_replay_size=32, max_replay_size=5000, + samples_per_insert=1, + evaluator_interval={"executor_episodes": 2}, + termination_condition={"executor_steps": 3_000_000} + ).build() # launch diff --git a/examples/smac/recurrent/decentralised/run_vdn.py b/examples/smac/recurrent/decentralised/run_vdn.py index 99f196c77..e279148d0 100644 --- a/examples/smac/recurrent/decentralised/run_vdn.py +++ b/examples/smac/recurrent/decentralised/run_vdn.py @@ -28,11 +28,8 @@ from mava.systems.tf import value_decomposition from mava.utils import lp_utils from mava.utils.loggers import logger_utils +from mava.utils.environments.smac_utils import make_environment -from mava.systems.tf.value_decomposition.mixer import VDN - -from smac.env import StarCraft2Env -from mava.wrappers import SMACWrapper FLAGS = flags.FLAGS @@ -49,18 +46,12 @@ ) flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") -def smac_env_factory(env_name="3m", evaluation = False): - env = StarCraft2Env(map_name=env_name) - env = SMACWrapper(env) - - return env - def main(_: Any) -> None: """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" # environment environment_factory = functools.partial( - smac_env_factory, env_name=FLAGS.map_name + make_environment, env_name=FLAGS.map_name ) # Networks. 
@@ -86,11 +77,11 @@ def main(_: Any) -> None: program = value_decomposition.ValueDecomposition( environment_factory=environment_factory, network_factory=network_factory, - mixer=VDN(), + mixer="vdn", logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=8e-6 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 @@ -98,12 +89,15 @@ def main(_: Any) -> None: checkpoint_subpath=checkpoint_dir, batch_size=32, executor_variable_update_period=200, - target_update_period=200, + target_update_period=100, max_gradient_norm=20.0, sequence_length=60, period=60, - min_replay_size=100, + min_replay_size=32, max_replay_size=5000, + samples_per_insert=1, + termination_condition={"executor_steps": 3_000_000}, + evaluator_interval={"executor_episodes": 2} ).build() # launch diff --git a/mava/systems/tf/madqn/networks.py b/mava/systems/tf/madqn/networks.py index 4a6e7a516..f2922c397 100644 --- a/mava/systems/tf/madqn/networks.py +++ b/mava/systems/tf/madqn/networks.py @@ -74,7 +74,7 @@ def make_default_networks( value_network_func = snt.Sequential elif architecture_type == ArchitectureType.recurrent: if not value_networks_layer_sizes: - value_networks_layer_sizes = (128, 64) + value_networks_layer_sizes = (64, 64) value_network_func = snt.DeepRNN assert value_networks_layer_sizes is not None diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index 5babe2295..fc67b580b 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -605,7 +605,6 @@ def trainer( Returns: system trainer. """ - print("$$$$$$$$$$$$$$$$$") # create logger trainer_logger_config = {} if self._logger_config and "trainer" in self._logger_config: @@ -614,14 +613,11 @@ def trainer( trainer_id, **trainer_logger_config ) - print("1") # Create the system networks = self.create_system() - print("2") dataset = self._builder.make_dataset_iterator(replay, trainer_id) - print("3") return self._builder.make_trainer( networks=networks, trainer_networks=self._trainer_networks[trainer_id], diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index bb884be07..943d72aa1 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -602,6 +602,10 @@ def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray] """Depricated""" pass + # @tf.function + # NOTE (Claude) The recurrent trainer does not start with tf.function + # It does start on SMAC 3m and debug env(num_agents=3) but not on any other SMAC maps. + # TODO (Claude) get tf.function to work. def _step( self, ) -> Dict[str, Dict[str, Any]]: @@ -614,25 +618,18 @@ def _step( # Draw a batch of data from replay. sample: reverb.ReplaySample = next(self._iterator) - losses = self._forward_backward(sample) - - # Log losses per agent - return train_utils.map_losses_per_agent_value( - losses - ) - - @tf.function - def _forward_backward(self, inputs: Any) -> Dict[str, Dict[str, Any]]: - - self._forward(inputs) + self._forward(sample) self._backward() # Update the target networks self._update_target_networks() - return self.value_losses - + # Log losses per agent + return train_utils.map_losses_per_agent_value( + self.value_losses + ) + # Forward pass that calculates loss. 
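    # The recurrent forward pass below unrolls the value network over the whole
    # stored sequence, then aligns current and next timesteps by slicing before
    # flattening time and batch for the double-Q loss. A toy sketch of that
    # pattern (the GRU core and all shapes are illustrative, not the exact network
    # built by make_default_networks):

    import sonnet as snt
    import tensorflow as tf

    T, B, F, A = 6, 2, 4, 3  # toy time, batch, feature and action sizes
    core = snt.DeepRNN([snt.GRU(8), snt.Linear(A)])
    obs_seq = tf.random.normal([T, B, F])  # time-major observations
    q_seq, _ = snt.static_unroll(core, obs_seq, core.initial_state(B))  # [T, B, A]

    q_tm1 = q_seq[:-1]        # Q(o_t) for t = 0 .. T-2
    q_t_selector = q_seq[1:]  # online Q(o_{t+1}), used to pick the argmax action
    # The target-network unroll is sliced the same way, and time and batch are
    # then combined into a single [(T-1) * B, A] axis before trfl.double_qlearning.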
def _forward(self, inputs: reverb.ReplaySample) -> None: """Trainer forward pass diff --git a/mava/systems/tf/value_decomposition/system.py b/mava/systems/tf/value_decomposition/system.py index 171d177d1..84964e335 100644 --- a/mava/systems/tf/value_decomposition/system.py +++ b/mava/systems/tf/value_decomposition/system.py @@ -35,15 +35,14 @@ from mava.utils import enums from mava.utils.loggers import MavaLogger from mava.systems.tf.madqn import MADQN +from mava.systems.tf.value_decomposition.mixer import QMIX, VDN class ValueDecomposition(MADQN): """Value Decomposition systems.""" - """TODO: Implement faster adders to speed up training times when - using multiple trainers with non-shared weights.""" - def __init__( # noqa + def __init__( self, environment_factory: Callable[[bool], dm_env.Environment], network_factory: Callable[[acme_specs.BoundedArray], Dict[str, snt.Module]], @@ -60,12 +59,6 @@ def __init__( # noqa trainer_fn: Type[ValueDecompositionRecurrentTrainer] = ValueDecompositionRecurrentTrainer, executor_fn: Type[MADQNRecurrentExecutor] = MADQNRecurrentExecutor, num_executors: int = 1, - trainer_networks: Union[ - Dict[str, List], enums.Trainer - ] = enums.Trainer.single_trainer, - network_sampling_setup: Union[ - List, enums.NetworkSampler - ] = enums.NetworkSampler.fixed_agent_networks, shared_weights: bool = True, environment_spec: mava_specs.MAEnvironmentSpec = None, discount: float = 0.99, @@ -77,7 +70,7 @@ def __init__( # noqa executor_variable_update_period: int = 200, min_replay_size: int = 100, max_replay_size: int = 5000, - samples_per_insert: Optional[float] = 32.0, + samples_per_insert: Optional[float] = 2.0, optimizer: Union[ snt.Optimizer, Dict[str, snt.Optimizer] ] = snt.optimizers.Adam(learning_rate=1e-4), @@ -94,7 +87,7 @@ def __init__( # noqa train_loop_fn_kwargs: Dict = {}, eval_loop_fn_kwargs: Dict = {}, termination_condition: Optional[Dict[str, int]] = None, - evaluator_interval: Optional[dict] = None, + evaluator_interval: Optional[dict] = {"executor_episodes": 2}, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): """Initialise the system @@ -199,8 +192,8 @@ def __init__( # noqa trainer_fn=trainer_fn, executor_fn=executor_fn, num_executors=num_executors, - trainer_networks=trainer_networks, - network_sampling_setup=network_sampling_setup, + trainer_networks=enums.Trainer.single_trainer, + network_sampling_setup=enums.NetworkSampler.fixed_agent_networks, shared_weights=shared_weights, environment_spec=environment_spec, discount=discount, @@ -230,6 +223,19 @@ def __init__( # noqa learning_rate_scheduler_fn=learning_rate_scheduler_fn, ) + if isinstance(mixer, str): + if mixer == "qmix": + env = environment_factory() + num_agents = len(env.possible_agents) + mixer = QMIX(num_agents) + del env + elif mixer == "vdn": + mixer = VDN() + else: + raise ValueError( + "Mixer not recognised. Should be either 'vdn' or 'qmix'" + ) + self._mixer = mixer self._mixer_optimizer = mixer_optimizer @@ -248,9 +254,6 @@ def trainer( Returns: system trainer. 
""" - - print("##########") - print("1") # create logger trainer_logger_config = {} if self._logger_config and "trainer" in self._logger_config: @@ -259,14 +262,11 @@ def trainer( trainer_id, **trainer_logger_config ) - print("2") # Create the system networks = self.create_system() - print("3") dataset = self._builder.make_dataset_iterator(replay, trainer_id) - print("4") trainer = self._builder.make_trainer( networks=networks, trainer_networks=self._trainer_networks[trainer_id], @@ -276,7 +276,6 @@ def trainer( variable_source=variable_source, ) - print("5") trainer.setup_mixer(self._mixer, self._mixer_optimizer) return trainer \ No newline at end of file diff --git a/mava/systems/tf/value_decomposition/training.py b/mava/systems/tf/value_decomposition/training.py index c499fc7ff..d9f3723aa 100644 --- a/mava/systems/tf/value_decomposition/training.py +++ b/mava/systems/tf/value_decomposition/training.py @@ -196,7 +196,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: data.extras, ) - # Global environment state + # Global environment state for mixer if "s_t" in extras: global_env_state = extras["s_t"] else: @@ -207,15 +207,9 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: core_state = tree.map_structure(lambda s: s[0, :, :], extras["core_states"]) target_core_state = tree.map_structure(lambda s: s[0, :, :], extras["core_states"]) - # TODO (dries): Take out all the data_points that does not need - # to be processed here at the start. Therefore it does not have - # to be done later on and saves processing time. - # NOTE (Claude) or do zeropadding mask - - self.value_losses: Dict[str, tf.Tensor] = {} # Do forward passes through the networks and calculate the losses with tf.GradientTape(persistent=True) as tape: - # Note (dries): We are assuming that only the policy network + # NOTE (Dries): We are assuming that only the valu network # is recurrent and not the observation network. obs_trans, target_obs_trans = self._transform_observations(observations) @@ -255,6 +249,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # Append agent values to lists + # NOTE (Claude) appending to a list does not work in tf.function chosen_action_q_value_all_agents.append(chosen_action_q_value) max_action_q_value_all_agents.append(max_action_q_value) reward_all_agents.append(rewards[agent]) @@ -321,7 +316,7 @@ def _backward(self) -> None: value_losses = self.value_losses mixer_loss = self.mixer_loss tape = self.tape - for agent in self._trainer_agent_list: + for agent in self._agents: agent_key = self._agent_net_keys[agent] # Get trainable variables. @@ -331,10 +326,6 @@ def _backward(self) -> None: ) # Compute gradients. - # Note: Warning "WARNING:tensorflow:Calling GradientTape.gradient - # on a persistent tape inside its context is significantly less efficient - # than calling it outside the context." caused by losses.dpg, which calls - # tape.gradient. gradients = tape.gradient(value_losses[agent], variables) # Maybe clip gradients. @@ -345,7 +336,6 @@ def _backward(self) -> None: # Apply gradients. self._optimizers[agent_key].apply(gradients, variables) - # TODO (Claude) what happens when there are multiple trainers @Dries # Mixer mixer_variables = self._mixer.trainable_variables diff --git a/mava/utils/environments/smac_utils.py b/mava/utils/environments/smac_utils.py new file mode 100644 index 000000000..c94d4e00e --- /dev/null +++ b/mava/utils/environments/smac_utils.py @@ -0,0 +1,28 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mava.wrappers.env_preprocess_wrappers import ConcatAgentIdToObservation, ConcatPrevActionToObservation +from smac.env import StarCraft2Env +from mava.wrappers import SMACWrapper + +def make_environment(map_name="3m", concat_prev_actions=True, concat_agent_id=True, evaluation = False, random_seed=None): + env = StarCraft2Env(map_name=map_name, seed=random_seed) + env = SMACWrapper(env) + if concat_prev_actions: + env = ConcatPrevActionToObservation(env) + + if concat_agent_id: + env = ConcatAgentIdToObservation(env) + + return env \ No newline at end of file diff --git a/mava/wrappers/debugging_envs.py b/mava/wrappers/debugging_envs.py index b7d59b0c1..16ce51673 100644 --- a/mava/wrappers/debugging_envs.py +++ b/mava/wrappers/debugging_envs.py @@ -123,10 +123,16 @@ def _convert_observations( # TODO Handle legal actions better for continuous envs, # maybe have min and max for each action and clip the agents actions # accordingly - legals = np.ones( - _convert_to_spec(self._environment.action_spaces[agent]).shape, - dtype=self._environment.action_spaces[agent].dtype, - ) + if isinstance(self._environment.action_spaces[agent], spaces.Discrete): + legals = np.ones( + _convert_to_spec(self._environment.action_spaces[agent]).num_values, + dtype=self._environment.action_spaces[agent].dtype, + ) + else: + legals = np.ones( + _convert_to_spec(self._environment.action_spaces[agent]).shape, + dtype=self._environment.action_spaces[agent].dtype, + ) observation = np.array(observation, dtype=np.float32) observations[agent] = OLT( @@ -140,11 +146,24 @@ def _convert_observations( def observation_spec(self) -> Dict[str, OLT]: observation_specs = {} for agent in self._environment.agent_ids: + + # Legals spec + if isinstance(self._environment.action_spaces[agent], spaces.Discrete): + legals = np.ones( + _convert_to_spec(self._environment.action_spaces[agent]).num_values, + dtype=self._environment.action_spaces[agent].dtype, + ) + else: + legals = np.ones( + _convert_to_spec(self._environment.action_spaces[agent]).shape, + dtype=self._environment.action_spaces[agent].dtype, + ) + observation_specs[agent] = OLT( observation=_convert_to_spec( self._environment.observation_spaces[agent] ), - legal_actions=_convert_to_spec(self._environment.action_spaces[agent]), + legal_actions=legals, terminal=specs.Array((1,), np.float32), ) return observation_specs diff --git a/mava/wrappers/smac.py b/mava/wrappers/smac.py index a9b06627d..1c904da99 100644 --- a/mava/wrappers/smac.py +++ b/mava/wrappers/smac.py @@ -279,8 +279,8 @@ def get_stats(self) -> Optional[Dict]: Returns: extra stats to be logged. """ - return self._info - + return self._environment.get_stats() + @property def agents(self) -> List: """Agents still alive in env (not done). From ef8592c824e2ab64baf09589b05215ab6e1b3ee7 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Mon, 24 Jan 2022 13:43:56 +0200 Subject: [PATCH 11/56] Remove flatland import. 
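For reference, the make_environment helper added in mava/utils/environments/smac_utils.py above simply stacks the SMAC wrapper with the two observation-augmenting wrappers. A minimal usage sketch, assuming the SMAC package and the 3m map are installed (the printed shapes are illustrative):

from mava.utils.environments.smac_utils import make_environment

# Per-agent observations become
# [one-hot agent id | one-hot previous action | raw SMAC observation].
env = make_environment(map_name="3m", concat_prev_actions=True, concat_agent_id=True)

timestep, extras = env.reset()
for agent, olt in timestep.observation.items():
    print(agent, olt.observation.shape, olt.legal_actions.shape)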
--- .../smac/recurrent/decentralised/run_madqn.py | 25 +++--- .../smac/recurrent/decentralised/run_qmix.py | 21 +++-- .../smac/recurrent/decentralised/run_vdn.py | 17 ++-- mava/systems/tf/madqn/training.py | 86 ++++++++++--------- mava/wrappers/__init__.py | 5 +- 5 files changed, 80 insertions(+), 74 deletions(-) diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index 63a91022c..3d55dcc56 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -29,21 +29,25 @@ from mava.components.tf.modules.exploration.exploration_scheduling import ( LinearExplorationScheduler, ) -from mava.wrappers.env_preprocess_wrappers import ConcatAgentIdToObservation, ConcatPrevActionToObservation from mava.components.tf.networks.epsilon_greedy import EpsilonGreedy from mava.systems.tf import madqn from mava.utils import lp_utils -from mava.utils.environments import pettingzoo_utils -from mava.utils.loggers import logger_utils from mava.utils.enums import ArchitectureType +from mava.utils.environments import pettingzoo_utils from mava.utils.environments.smac_utils import make_environment +from mava.utils.loggers import logger_utils +from mava.wrappers.env_preprocess_wrappers import ( + ConcatAgentIdToObservation, + ConcatPrevActionToObservation, +) - +SEQUENCE_LENGTH = 60 +MAP_NAME = "3m" FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - "3m", + MAP_NAME, "Starcraft 2 micromanagement map name (str).", ) @@ -58,14 +62,11 @@ def main(_: Any) -> None: """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" # environment - environment_factory = functools.partial( - make_environment, env_name=FLAGS.map_name - ) + environment_factory = functools.partial(make_environment, map_name=FLAGS.map_name) # Networks. 
network_factory = lp_utils.partial_kwargs( - madqn.make_default_networks, - architecture_type=ArchitectureType.recurrent + madqn.make_default_networks, architecture_type=ArchitectureType.recurrent ) # Checkpointer appends "Checkpoints" to checkpoint_dir @@ -99,8 +100,8 @@ def main(_: Any) -> None: executor_variable_update_period=200, target_update_period=100, max_gradient_norm=20.0, - sequence_length=60, - period=60, + sequence_length=SEQUENCE_LENGTH, + period=SEQUENCE_LENGTH, min_replay_size=32, max_replay_size=5000, samples_per_insert=1, diff --git a/examples/smac/recurrent/decentralised/run_qmix.py b/examples/smac/recurrent/decentralised/run_qmix.py index 032caeb69..ead655dde 100644 --- a/examples/smac/recurrent/decentralised/run_qmix.py +++ b/examples/smac/recurrent/decentralised/run_qmix.py @@ -27,14 +27,16 @@ ) from mava.systems.tf import value_decomposition from mava.utils import lp_utils -from mava.utils.loggers import logger_utils from mava.utils.environments.smac_utils import make_environment +from mava.utils.loggers import logger_utils +SEQUENCE_LENGTH = 60 +MAP_NAME = "3m" FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - "3m", + MAP_NAME, "Starcraft 2 micromanagement map name (str).", ) @@ -46,14 +48,16 @@ flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") +log_msg = ( + f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"Executing: {command}" + os.linesep +) def main(_: Any) -> None: """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" # environment - environment_factory = functools.partial( - make_environment, map_name=FLAGS.map_name - ) + environment_factory = functools.partial(make_environment, map_name=FLAGS.map_name) # Networks. network_factory = lp_utils.partial_kwargs( @@ -92,14 +96,13 @@ def main(_: Any) -> None: executor_variable_update_period=200, target_update_period=100, max_gradient_norm=20.0, - sequence_length=60, - period=60, + sequence_length=SEQUENCE_LENGTH, + period=SEQUENCE_LENGTH, min_replay_size=32, max_replay_size=5000, samples_per_insert=1, evaluator_interval={"executor_episodes": 2}, - termination_condition={"executor_steps": 3_000_000} - + termination_condition={"executor_steps": 3_000_000}, ).build() # launch diff --git a/examples/smac/recurrent/decentralised/run_vdn.py b/examples/smac/recurrent/decentralised/run_vdn.py index e279148d0..20458240d 100644 --- a/examples/smac/recurrent/decentralised/run_vdn.py +++ b/examples/smac/recurrent/decentralised/run_vdn.py @@ -27,15 +27,16 @@ ) from mava.systems.tf import value_decomposition from mava.utils import lp_utils -from mava.utils.loggers import logger_utils from mava.utils.environments.smac_utils import make_environment +from mava.utils.loggers import logger_utils - +SEQUENCE_LENGTH = 60 +MAP_NAME = "3m" FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - "3m", + MAP_NAME, "Starcraft 2 micromanagement map name (str).", ) @@ -50,9 +51,7 @@ def main(_: Any) -> None: """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" # environment - environment_factory = functools.partial( - make_environment, env_name=FLAGS.map_name - ) + environment_factory = functools.partial(make_environment, map_name=FLAGS.map_name) # Networks. 
network_factory = lp_utils.partial_kwargs( @@ -91,13 +90,13 @@ def main(_: Any) -> None: executor_variable_update_period=200, target_update_period=100, max_gradient_norm=20.0, - sequence_length=60, - period=60, + sequence_length=SEQUENCE_LENGTH, + period=SEQUENCE_LENGTH, min_replay_size=32, max_replay_size=5000, samples_per_insert=1, termination_condition={"executor_steps": 3_000_000}, - evaluator_interval={"executor_episodes": 2} + evaluator_interval={"executor_episodes": 2}, ).build() # launch diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index 943d72aa1..397cee0bb 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -172,9 +172,9 @@ def __init__( self._system_network_variables["observations"][ agent_key ] = self._target_observation_networks[agent_key].variables - self._system_network_variables["values"][ + self._system_network_variables["values"][agent_key] = self._value_networks[ agent_key - ] = self._value_networks[agent_key].variables + ].variables # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and @@ -252,9 +252,7 @@ def _step( losses = self._forward_backward(sample) # Log losses per agent - return train_utils.map_losses_per_agent_value( - losses - ) + return train_utils.map_losses_per_agent_value(losses) @tf.function def _forward_backward(self, inputs: Any) -> Dict[str, Dict[str, Any]]: @@ -318,7 +316,12 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # Value loss. value_loss, _ = trfl.double_qlearning( - q_tm1, a_tm1[agent], r_t[agent], discount * d_t[agent], q_t_value, q_t_selector + q_tm1, + a_tm1[agent], + r_t[agent], + discount * d_t[agent], + q_t_value, + q_t_selector, ) self.value_losses[agent] = tf.reduce_mean(value_loss, axis=0) @@ -341,7 +344,6 @@ def _backward(self) -> None: + self._value_networks[agent_key].trainable_variables ) - # Compute gradients. # Note: Warning "WARNING:tensorflow:Calling GradientTape.gradient # on a persistent tape inside its context is significantly less efficient @@ -350,9 +352,7 @@ def _backward(self) -> None: gradients = tape.gradient(value_losses[agent], variables) # Maybe clip gradients. - gradients = tf.clip_by_global_norm( - gradients, self._max_gradient_norm - )[0] + gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0] # Apply gradients. self._optimizers[agent_key].apply(gradients, variables) @@ -389,7 +389,7 @@ def _decay_lr(self, trainer_step: int) -> None: class MADQNRecurrentTrainer: """Recurrent MADQN trainer. - + This is the trainer component of a MADQN system. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ @@ -463,7 +463,7 @@ def __init__( # Store online and target networks. self._value_networks = value_networks self._target_value_networks = target_value_networks - + # Ensure obs and target networks are sonnet modules self._observation_networks = { k: tf2_utils.to_sonnet_module(v) for k, v in observation_networks.items() @@ -604,7 +604,7 @@ def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray] # @tf.function # NOTE (Claude) The recurrent trainer does not start with tf.function - # It does start on SMAC 3m and debug env(num_agents=3) but not on any other SMAC maps. + # It does start on SMAC 3m and debug env but not on any other SMAC maps. # TODO (Claude) get tf.function to work. 
def _step( self, @@ -626,10 +626,8 @@ def _step( self._update_target_networks() # Log losses per agent - return train_utils.map_losses_per_agent_value( - self.value_losses - ) - + return train_utils.map_losses_per_agent_value(self.value_losses) + # Forward pass that calculates loss. def _forward(self, inputs: reverb.ReplaySample) -> None: """Trainer forward pass @@ -655,14 +653,16 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # Get initial state for the LSTM from replay and # extract the first state in the sequence. core_state = tree.map_structure(lambda s: s[0, :, :], extras["core_states"]) - target_core_state = tree.map_structure(lambda s: s[0, :, :], extras["core_states"]) + target_core_state = tree.map_structure( + lambda s: s[0, :, :], extras["core_states"] + ) # TODO (dries): Take out all the data_points that does not need # to be processed here at the start. Therefore it does not have # to be done later on and saves processing time. self.value_losses: Dict[str, tf.Tensor] = {} - + # Do forward passes through the networks and calculate the losses with tf.GradientTape(persistent=True) as tape: # Note (dries): We are assuming that only the policy network @@ -674,44 +674,48 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # Double Q-learning q, _ = snt.static_unroll( - self._value_networks[agent_key], obs_trans[agent], core_state[agent][0] + self._value_networks[agent_key], + obs_trans[agent], + core_state[agent][0], ) - q_tm1 = q[:-1] # Chop off last timestep - q_t_selector = q[1:] # Chop off first timestep + q_tm1 = q[:-1] # Chop off last timestep + q_t_selector = q[1:] # Chop off first timestep q_t_value, _ = snt.static_unroll( - self._target_value_networks[agent_key], target_obs_trans[agent], target_core_state[agent][0] + self._target_value_networks[agent_key], + target_obs_trans[agent], + target_core_state[agent][0], ) - q_t_value = q_t_value[1:] # Chop off first timestep + q_t_value = q_t_value[1:] # Chop off first timestep # Legal action masking - q_t_selector = tf.where(tf.cast(observations[agent].legal_actions[1:], 'bool'), q_t_selector, -999999999) + q_t_selector = tf.where( + tf.cast(observations[agent].legal_actions[1:], "bool"), + q_t_selector, + -999999999, + ) # Cast the additional discount to match # the environment discount dtype. discount = tf.cast(self._discount, dtype=discounts[agent].dtype) # Flatten out time and batch dim - q_tm1, _ = train_utils.combine_dim( - q_tm1 - ) - q_t_selector, _ = train_utils.combine_dim( - q_t_selector - ) - q_t_value, _ = train_utils.combine_dim( - q_t_value - ) + q_tm1, _ = train_utils.combine_dim(q_tm1) + q_t_selector, _ = train_utils.combine_dim(q_t_selector) + q_t_value, _ = train_utils.combine_dim(q_t_value) a_tm1, _ = train_utils.combine_dim( - actions[agent][:-1] # Chop off last timestep + actions[agent][:-1] # Chop off last timestep ) r_t, _ = train_utils.combine_dim( - rewards[agent][:-1] # Chop off last timestep + rewards[agent][:-1] # Chop off last timestep ) d_t, _ = train_utils.combine_dim( - discounts[agent][:-1] # Chop off last timestep + discounts[agent][:-1] # Chop off last timestep ) # Value loss - value_loss, _ = trfl.double_qlearning(q_tm1, a_tm1, r_t, discount * d_t, q_t_value, q_t_selector) + value_loss, _ = trfl.double_qlearning( + q_tm1, a_tm1, r_t, discount * d_t, q_t_value, q_t_selector + ) # TODO zero padding mask @@ -743,9 +747,7 @@ def _backward(self) -> None: gradients = tape.gradient(value_losses[agent], variables) # Maybe clip gradients. 
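            # tf.clip_by_global_norm returns a (clipped_gradients, global_norm)
            # tuple, hence the trailing [0] below to keep only the clipped list.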
- gradients = tf.clip_by_global_norm( - gradients, self._max_gradient_norm - )[0] + gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0] # Apply gradients. self._optimizers[agent_key].apply(gradients, variables) @@ -777,4 +779,4 @@ def _decay_lr(self, trainer_step: int) -> None: """ train_utils.decay_lr( self._learning_rate_scheduler_fn, self._optimizers, trainer_step - ) \ No newline at end of file + ) diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index fcb03c1d6..d90c841a1 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -25,11 +25,12 @@ PettingZooParallelEnvWrapper, ) from mava.wrappers.robocup import RoboCupWrapper + +# from mava.wrappers.flatland import FlatlandEnvWrapper +from mava.wrappers.smac import SMACWrapper from mava.wrappers.system_trainer_statistics import ( DetailedTrainerStatistics, NetworkStatisticsActorCritic, NetworkStatisticsMixing, ScaledDetailedTrainerStatistics, ) -from mava.wrappers.flatland import FlatlandEnvWrapper -from mava.wrappers.smac import SMACWrapper From bb2ad1194d2047ceb96196d5e6fa4be270f14d82 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Mon, 24 Jan 2022 13:48:55 +0200 Subject: [PATCH 12/56] Small fix --- examples/smac/recurrent/decentralised/run_qmix.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/examples/smac/recurrent/decentralised/run_qmix.py b/examples/smac/recurrent/decentralised/run_qmix.py index ead655dde..40c005408 100644 --- a/examples/smac/recurrent/decentralised/run_qmix.py +++ b/examples/smac/recurrent/decentralised/run_qmix.py @@ -27,8 +27,8 @@ ) from mava.systems.tf import value_decomposition from mava.utils import lp_utils -from mava.utils.environments.smac_utils import make_environment from mava.utils.loggers import logger_utils +from mava.utils.environments.smac_utils import make_environment SEQUENCE_LENGTH = 60 MAP_NAME = "3m" @@ -48,16 +48,12 @@ flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") -log_msg = ( - f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " - f"Executing: {command}" + os.linesep -) - - def main(_: Any) -> None: """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" # environment - environment_factory = functools.partial(make_environment, map_name=FLAGS.map_name) + environment_factory = functools.partial( + make_environment, map_name=FLAGS.map_name + ) # Networks. network_factory = lp_utils.partial_kwargs( @@ -102,7 +98,8 @@ def main(_: Any) -> None: max_replay_size=5000, samples_per_insert=1, evaluator_interval={"executor_episodes": 2}, - termination_condition={"executor_steps": 3_000_000}, + termination_condition={"executor_steps": 3_000_000} + ).build() # launch From 7435936c2613e077d559ea09844f6631f748c4f9 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Tue, 25 Jan 2022 11:21:43 +0200 Subject: [PATCH 13/56] Small fixes and clean-up. 
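The clean-up below consolidates the mixers into mava/components/tf/modules/mixing/mixers.py, which is what the mixer="vdn" / mixer="qmix" options of the value-decomposition system resolve to. As a reference point, a VDN-style mixer simply sums the per-agent chosen-action Q-values into Q_tot, while QMIX mixes them monotonically through hypernetworks conditioned on the global state. A minimal sketch of the VDN case (a sketch of the standard algorithm, not necessarily the exact class in mixers.py):

import sonnet as snt
import tensorflow as tf

class VDNSketch(snt.Module):
    """Value decomposition: Q_tot is the sum of per-agent Q-values."""

    def __call__(self, agent_qs: tf.Tensor, states: tf.Tensor) -> tf.Tensor:
        # agent_qs holds chosen-action Q-values with the agent axis assumed last;
        # the global state is ignored (only QMIX conditions on it).
        return tf.reduce_sum(agent_qs, axis=-1, keepdims=True)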
--- .../recurrent/decentralised/run_madqn.py | 16 +- .../smac/recurrent/decentralised/run_qmix.py | 16 +- mava/components/tf/modules/mixing/__init__.py | 7 +- mava/components/tf/modules/mixing/additive.py | 49 - mava/components/tf/modules/mixing/base.py | 50 - .../tf/modules/mixing/mixers.py} | 7 - .../components/tf/modules/mixing/monotonic.py | 92 -- .../tf/modules/stabilising/__init__.py | 7 +- .../components/tf/modules/stabilising/base.py | 33 - .../tf/modules/stabilising/fingerprints.py | 63 +- mava/components/tf/networks/__init__.py | 5 +- mava/components/tf/networks/additive.py | 30 - mava/components/tf/networks/fingerprints.py | 52 - mava/components/tf/networks/hypernetwork.py | 99 -- mava/components/tf/networks/monotonic.py | 86 -- mava/systems/tf/madqn/execution.py | 19 +- mava/systems/tf/madqn/system.py | 3 +- mava/systems/tf/madqn/training.py | 7 +- mava/systems/tf/qmix/README.md | 9 - mava/systems/tf/qmix/__init__.py | 19 - mava/systems/tf/qmix/builder.py | 179 --- mava/systems/tf/qmix/execution.py | 81 -- mava/systems/tf/qmix/networks.py | 66 - mava/systems/tf/qmix/system.py | 362 ------ mava/systems/tf/qmix/training.py | 305 ----- .../tf/value_decomposition/networks.py | 2 +- mava/systems/tf/value_decomposition/system.py | 2 +- .../tf/value_decomposition/training.py | 10 +- mava/systems/tf/vdn/README.md | 9 - mava/systems/tf/vdn/__init__.py | 19 - mava/systems/tf/vdn/builder.py | 172 --- mava/systems/tf/vdn/execution.py | 81 -- mava/systems/tf/vdn/networks.py | 66 - mava/systems/tf/vdn/system.py | 310 ----- mava/systems/tf/vdn/training.py | 245 ---- mava/utils/environments/flatland_utils.py | 58 +- mava/utils/environments/smac_utils.py | 2 + mava/wrappers/flatland.py | 1063 +++++++++-------- 38 files changed, 630 insertions(+), 3071 deletions(-) delete mode 100644 mava/components/tf/modules/mixing/additive.py delete mode 100644 mava/components/tf/modules/mixing/base.py rename mava/{systems/tf/value_decomposition/mixer.py => components/tf/modules/mixing/mixers.py} (93%) delete mode 100644 mava/components/tf/modules/mixing/monotonic.py delete mode 100644 mava/components/tf/modules/stabilising/base.py delete mode 100644 mava/components/tf/networks/additive.py delete mode 100644 mava/components/tf/networks/fingerprints.py delete mode 100644 mava/components/tf/networks/hypernetwork.py delete mode 100644 mava/components/tf/networks/monotonic.py delete mode 100644 mava/systems/tf/qmix/README.md delete mode 100644 mava/systems/tf/qmix/__init__.py delete mode 100644 mava/systems/tf/qmix/builder.py delete mode 100644 mava/systems/tf/qmix/execution.py delete mode 100644 mava/systems/tf/qmix/networks.py delete mode 100644 mava/systems/tf/qmix/system.py delete mode 100644 mava/systems/tf/qmix/training.py delete mode 100644 mava/systems/tf/vdn/README.md delete mode 100644 mava/systems/tf/vdn/__init__.py delete mode 100644 mava/systems/tf/vdn/builder.py delete mode 100644 mava/systems/tf/vdn/execution.py delete mode 100644 mava/systems/tf/vdn/networks.py delete mode 100644 mava/systems/tf/vdn/system.py delete mode 100644 mava/systems/tf/vdn/training.py diff --git a/examples/flatland/recurrent/decentralised/run_madqn.py b/examples/flatland/recurrent/decentralised/run_madqn.py index 0dca5c70b..e30090327 100644 --- a/examples/flatland/recurrent/decentralised/run_madqn.py +++ b/examples/flatland/recurrent/decentralised/run_madqn.py @@ -27,7 +27,7 @@ ) from mava.systems.tf import madqn from mava.utils import lp_utils -from mava.utils.environments.flatland_utils import flatland_env_factory +from 
mava.utils.environments.flatland_utils import make_environment from mava.utils.loggers import logger_utils from mava.utils.enums import ArchitectureType @@ -43,7 +43,7 @@ # flatland environment config -flatland_env_config: Dict = { +env_config: Dict = { "n_agents": 3, "x_dim": 30, "y_dim": 30, @@ -63,14 +63,14 @@ def main(_: Any) -> None: # Environment. environment_factory = functools.partial( - flatland_env_factory, env_config=flatland_env_config, include_agent_info=False + make_environment, **env_config ) # Networks. network_factory = lp_utils.partial_kwargs( madqn.make_default_networks, architecture_type=ArchitectureType.recurrent - ) + ) # Checkpointer appends "Checkpoints" to checkpoint_dir checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" @@ -100,13 +100,15 @@ def main(_: Any) -> None: executor_variable_update_period=200, target_update_period=200, max_gradient_norm=20.0, - sequence_length=20, - period=10, - min_replay_size=100, + sequence_length=70, + period=70, + min_replay_size=32, max_replay_size=5000, trainer_fn=madqn.training.MADQNRecurrentTrainer, executor_fn=madqn.execution.MADQNRecurrentExecutor, checkpoint_subpath=checkpoint_dir, + evaluator_interval={"executor_episodes": 2}, + termination_condition={"executor_steps": 3_000_000} ).build() # Ensure only trainer runs on gpu, while other processes run on cpu. diff --git a/examples/smac/recurrent/decentralised/run_qmix.py b/examples/smac/recurrent/decentralised/run_qmix.py index 40c005408..4293672ed 100644 --- a/examples/smac/recurrent/decentralised/run_qmix.py +++ b/examples/smac/recurrent/decentralised/run_qmix.py @@ -30,8 +30,8 @@ from mava.utils.loggers import logger_utils from mava.utils.environments.smac_utils import make_environment -SEQUENCE_LENGTH = 60 -MAP_NAME = "3m" +SEQUENCE_LENGTH = 120 +MAP_NAME = "2s3z" FLAGS = flags.FLAGS flags.DEFINE_string( @@ -82,7 +82,7 @@ def main(_: Any) -> None: logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=8e-6 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 @@ -90,13 +90,13 @@ def main(_: Any) -> None: checkpoint_subpath=checkpoint_dir, batch_size=32, executor_variable_update_period=200, - target_update_period=100, + target_update_period=200, max_gradient_norm=20.0, - sequence_length=SEQUENCE_LENGTH, - period=SEQUENCE_LENGTH, + sequence_length=20, + period=10, min_replay_size=32, - max_replay_size=5000, - samples_per_insert=1, + max_replay_size=10_000, + samples_per_insert=32, evaluator_interval={"executor_episodes": 2}, termination_condition={"executor_steps": 3_000_000} diff --git a/mava/components/tf/modules/mixing/__init__.py b/mava/components/tf/modules/mixing/__init__.py index b1fea8417..425ad888d 100644 --- a/mava/components/tf/modules/mixing/__init__.py +++ b/mava/components/tf/modules/mixing/__init__.py @@ -13,8 +13,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""MARL system mixing modules.""" - -from mava.components.tf.modules.mixing.additive import AdditiveMixing -from mava.components.tf.modules.mixing.base import BaseMixingModule -from mava.components.tf.modules.mixing.monotonic import MonotonicMixing +"""Value decomposition mixing modules.""" +from mava.components.tf.modules.mixing.mixers import BaseMixer, VDN, QMIX diff --git a/mava/components/tf/modules/mixing/additive.py b/mava/components/tf/modules/mixing/additive.py deleted file mode 100644 index 1b3d8196b..000000000 --- a/mava/components/tf/modules/mixing/additive.py +++ /dev/null @@ -1,49 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict - -import sonnet as snt - -from mava.components.tf.architectures import BaseArchitecture -from mava.components.tf.modules.mixing.base import BaseMixingModule -from mava.components.tf.networks.additive import AdditiveMixingNetwork - - -class AdditiveMixing(BaseMixingModule): - """Multi-agent monotonic mixing architecture.""" - - def __init__(self, architecture: BaseArchitecture) -> None: - """Initializes the mixer.""" - super(AdditiveMixing, self).__init__() - - self._architecture = architecture - self._agent_networks: Dict[str, snt.Module] = {} - - def _create_mixing_layer(self, name: str = "mixing") -> snt.Module: - # Instantiate additive mixing network - self._mixed_network = AdditiveMixingNetwork(name) - return self._mixed_network - - def create_system(self) -> Dict[str, Dict[str, snt.Module]]: - self._agent_networks[ - "agent_networks" - ] = self._architecture.create_actor_variables() - self._agent_networks["mixing"] = self._create_mixing_layer("mixing") - self._agent_networks["target_mixing"] = self._create_mixing_layer( - "target_mixing" - ) - - return self._agent_networks diff --git a/mava/components/tf/modules/mixing/base.py b/mava/components/tf/modules/mixing/base.py deleted file mode 100644 index 946786c2d..000000000 --- a/mava/components/tf/modules/mixing/base.py +++ /dev/null @@ -1,50 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -from typing import Dict, Optional - -import sonnet as snt - -from mava import specs as mava_specs -from mava.components.tf.architectures import BaseArchitecture - -"""Base mixing interface for multi-agent RL systems""" - - -class BaseMixingModule: - """Base class for MARL mixing. 
- Objects which implement this interface provide a set of functions - to create systems that can perform value decomposition via a mixing - strategy between agents in a multi-agent RL system. - """ - - @abc.abstractmethod - def __init__( - self, - architecture: Optional[BaseArchitecture] = None, - environment_spec: Optional[mava_specs.MAEnvironmentSpec] = None, - agent_networks: Optional[Dict[str, snt.Module]] = None, - ) -> None: - """Initialise the mixer.""" - - @abc.abstractmethod - def _create_mixing_layer(self, name: str) -> snt.Module: - """Abstract function for adding an arbitrary mixing layer to a - given architecture.""" - - @abc.abstractmethod - def create_system(self) -> Dict[str, Dict[str, snt.Module]]: - """Create/update system architecture with specified mixing.""" diff --git a/mava/systems/tf/value_decomposition/mixer.py b/mava/components/tf/modules/mixing/mixers.py similarity index 93% rename from mava/systems/tf/value_decomposition/mixer.py rename to mava/components/tf/modules/mixing/mixers.py index 6565c7944..1e191859a 100644 --- a/mava/systems/tf/value_decomposition/mixer.py +++ b/mava/components/tf/modules/mixing/mixers.py @@ -13,13 +13,6 @@ def __init__(self): def __call__(self, agent_qs: tf.Tensor , states: tf.Tensor): return agent_qs - """Initialize Base Mixer class - Args: - agent_qs: Tensor containing the q-values of actions chosen by agents - states: Tensor containing global environment state. - """ - -@snt.allow_empty_variables class VDN(BaseMixer): """VDN mixing network.""" diff --git a/mava/components/tf/modules/mixing/monotonic.py b/mava/components/tf/modules/mixing/monotonic.py deleted file mode 100644 index 4d51e92b3..000000000 --- a/mava/components/tf/modules/mixing/monotonic.py +++ /dev/null @@ -1,92 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Mixing for multi-agent RL systems""" -from typing import Dict, Union - -import sonnet as snt -import tensorflow as tf -from acme.tf import utils as tf2_utils - -from mava import specs as mava_specs -from mava.components.tf.architectures.decentralised import ( - DecentralisedValueActor, - DecentralisedValueActorCritic, -) -from mava.components.tf.modules.mixing import BaseMixingModule -from mava.components.tf.networks.monotonic import MonotonicMixingNetwork - - -class MonotonicMixing(BaseMixingModule): - """Multi-agent monotonic mixing architecture. - This is the component which can be used to add monotonic mixing to an underlying - agent architecture. It currently supports generalised monotonic mixing using - hypernetworks (1 or 2 layers) for control of decomposition parameters (QMix).""" - - def __init__( - self, - environment_spec: mava_specs.MAEnvironmentSpec, - architecture: Union[DecentralisedValueActor, DecentralisedValueActorCritic], - qmix_hidden_dim: int = 32, - num_hypernet_layers: int = 2, - hypernet_hidden_dim: int = 64, # Defaults to qmix_hidden_dim - ) -> None: - """Initializes the mixer. 
- Args: - architecture: the BaseArchitecture used. - """ - super(MonotonicMixing, self).__init__() - - assert hasattr( - architecture, "_n_agents" - ), "Architecture doesn't have _n_agents." - self._environment_spec = environment_spec - self._qmix_hidden_dim = qmix_hidden_dim - self._num_hypernet_layers = num_hypernet_layers - self._hypernet_hidden_dim = hypernet_hidden_dim - self._n_agents = architecture._n_agents - self._architecture = architecture - self._agent_networks: Dict[str, snt.Module] = {} - - def _create_mixing_layer(self, name: str = "mixing") -> snt.Module: - """Modify and return system architecture given mixing structure.""" - state_specs = self._environment_spec.get_extra_specs() - state_specs = state_specs["s_t"] - - q_value_dim = tf.TensorSpec(self._n_agents) - - # Implement method from base class - self._mixed_network = MonotonicMixingNetwork( - n_agents=self._n_agents, - qmix_hidden_dim=self._qmix_hidden_dim, - num_hypernet_layers=self._num_hypernet_layers, - hypernet_hidden_dim=self._hypernet_hidden_dim, - name=name, - ) - - tf2_utils.create_variables(self._mixed_network, [q_value_dim, state_specs]) - return self._mixed_network - - def create_system(self) -> Dict[str, Dict[str, snt.Module]]: - # Implement method from base class - self._agent_networks[ - "agent_networks" - ] = self._architecture.create_actor_variables() - self._agent_networks["mixing"] = self._create_mixing_layer(name="mixing") - self._agent_networks["target_mixing"] = self._create_mixing_layer( - name="target_mixing" - ) - - return self._agent_networks diff --git a/mava/components/tf/modules/stabilising/__init__.py b/mava/components/tf/modules/stabilising/__init__.py index 9d6685ec8..6ccea8f93 100644 --- a/mava/components/tf/modules/stabilising/__init__.py +++ b/mava/components/tf/modules/stabilising/__init__.py @@ -12,9 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -"""MARL system stabilising modules.""" - -from mava.components.tf.modules.stabilising.base import BaseStabilisationModule -from mava.components.tf.modules.stabilising.fingerprints import FingerPrintStabalisation +"""MARL experience replay stabilising modules.""" \ No newline at end of file diff --git a/mava/components/tf/modules/stabilising/base.py b/mava/components/tf/modules/stabilising/base.py deleted file mode 100644 index a6b4602d9..000000000 --- a/mava/components/tf/modules/stabilising/base.py +++ /dev/null @@ -1,33 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -from typing import Dict - -import sonnet as snt - -"""Base stabilising interface for multi-agent RL systems""" - - -class BaseStabilisationModule: - """Base class for MARL stabilising. - Objects which implement this interface provide a set of functions - to create systems that can stabilise training for agents in a - multi-agent RL system. 
- """ - - @abc.abstractmethod - def create_system(self) -> Dict[str, Dict[str, snt.Module]]: - """Create system architecture with stabilisation.""" diff --git a/mava/components/tf/modules/stabilising/fingerprints.py b/mava/components/tf/modules/stabilising/fingerprints.py index 4c91df65b..33d599bb2 100644 --- a/mava/components/tf/modules/stabilising/fingerprints.py +++ b/mava/components/tf/modules/stabilising/fingerprints.py @@ -12,65 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Experience replay stabilisation with fingerprints""" - -"""Stabilising for multi-agent RL systems""" -from typing import Dict - -import sonnet as snt -import tensorflow as tf -from acme.tf import utils as tf2_utils - -from mava.components.tf.architectures import DecentralisedValueActor -from mava.components.tf.modules.stabilising import BaseStabilisationModule - - -class FingerPrintStabalisation(BaseStabilisationModule): - """Multi-agent stabalisation architecture.""" - - def __init__( - self, - architecture: DecentralisedValueActor, - ) -> None: - self._architecture = architecture - self._fingerprint_spec = tf.ones((2,), dtype="float32") - - def create_actor_variables_with_fingerprints( - self, - ) -> Dict[str, Dict[str, snt.Module]]: - - actor_networks: Dict[str, Dict[str, snt.Module]] = { - "values": {}, - "target_values": {}, - } - - # get actor specs - actor_obs_specs = self._architecture._get_actor_specs() - - # create policy variables for each agent - for agent_key in self._architecture._agents: - agent_net_key = self._architecture._agent_net_keys[agent_key] - obs_spec = actor_obs_specs[agent_key] - - # Create variables for value and policy networks. - tf2_utils.create_variables( - self._architecture._value_networks[agent_net_key], - [obs_spec, self._fingerprint_spec], - ) - - # create target value network variables - tf2_utils.create_variables( - self._architecture._target_value_networks[agent_net_key], - [obs_spec, self._fingerprint_spec], - ) - - actor_networks["values"] = self._architecture._value_networks - actor_networks["target_values"] = self._architecture._target_value_networks - - return actor_networks - - def create_system( - self, - ) -> Dict[str, Dict[str, snt.Module]]: - networks = self.create_actor_variables_with_fingerprints() - return networks +# TODO (Claude) implement fingerprints for new MADQN system. 
\ No newline at end of file diff --git a/mava/components/tf/networks/__init__.py b/mava/components/tf/networks/__init__.py index 301251cd1..304bac090 100644 --- a/mava/components/tf/networks/__init__.py +++ b/mava/components/tf/networks/__init__.py @@ -20,7 +20,6 @@ from acme.tf.networks.noise import ClippedGaussian from acme.tf.networks.rescaling import ClipToSpec, RescaleToSpec, TanhToSpec -from mava.components.tf.networks.additive import AdditiveMixingNetwork from mava.components.tf.networks.communication import CommunicationNetwork from mava.components.tf.networks.continuous import ( LayerNormAndResidualMLP, @@ -28,9 +27,7 @@ NearZeroInitializedLinear, ) from mava.components.tf.networks.convolution import Conv1DNetwork -from mava.components.tf.networks.fingerprints import ObservationNetworkWithFingerprint from mava.components.tf.networks.mad4pg import ( DiscreteValuedDistribution, DiscreteValuedHead, -) -from mava.components.tf.networks.monotonic import MonotonicMixingNetwork +) \ No newline at end of file diff --git a/mava/components/tf/networks/additive.py b/mava/components/tf/networks/additive.py deleted file mode 100644 index b00bf328c..000000000 --- a/mava/components/tf/networks/additive.py +++ /dev/null @@ -1,30 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sonnet as snt -import tensorflow as tf - - -class AdditiveMixingNetwork(snt.Module): - """Multi-agent monotonic mixing architecture.""" - - def __init__(self, name: str = "mixing") -> None: - """Initializes the mixer.""" - super(AdditiveMixingNetwork, self).__init__(name=name) - - def __call__(self, q_values: tf.Tensor) -> tf.Tensor: - """Monotonic mixing logic.""" - # return tf.math.reduce_sum(q_values, axis=1) - return tf.math.reduce_sum(q_values, axis=1) diff --git a/mava/components/tf/networks/fingerprints.py b/mava/components/tf/networks/fingerprints.py deleted file mode 100644 index 741e6e683..000000000 --- a/mava/components/tf/networks/fingerprints.py +++ /dev/null @@ -1,52 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Sonnet module that takes two inputs -[observation, fingerprint]""" - -import sonnet as snt -import tensorflow as tf - - -class ObservationNetworkWithFingerprint(snt.Module): - """Sonnet module that takes two inputs - [observation, fingerprint] and returns an observation - embedding and concatenates the fingerprint to the - embedding. 
Downstream layers can then be trained - on the embedding+fingerprint.""" - - def __init__( - self, - observation_network: snt.Module, - ) -> None: - """Initializes network. - Args: - observation_network: ... - """ - super(ObservationNetworkWithFingerprint, self).__init__() - self._observation_network = observation_network - self._flatten_layer = tf.keras.layers.Flatten() - - def __call__( - self, - obs: tf.Tensor, - fingerprint: tf.Tensor, - ) -> tf.Tensor: - - hidden = self._observation_network(obs) - flatten = self._flatten_layer(hidden) - hidden_with_fingerprint = tf.concat([flatten, fingerprint], axis=1) - - return hidden_with_fingerprint diff --git a/mava/components/tf/networks/hypernetwork.py b/mava/components/tf/networks/hypernetwork.py deleted file mode 100644 index e72b526d2..000000000 --- a/mava/components/tf/networks/hypernetwork.py +++ /dev/null @@ -1,99 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict - -import sonnet as snt -import tensorflow as tf -from tensorflow import Tensor - - -class HyperNetwork(snt.Module): - def __init__( - self, - qmix_hidden_dim: int, # qmix_hidden_dim - n_agents: int, - num_hypernet_layers: int = 2, - hypernet_hidden_dim: int = 0, # qmix_hidden_dim - ): - """Initializes the mixer. - Args: - qmix_hidden_dim: Mixing layers hidden dimensions. - i.e. What size the mixing network takes as input. - num_hypernet_layers: Number of hypernetwork layers. Currently 1 or 2. - hypernet_hidden_dim: The number of nodes in the hypernetwork hidden - layer. Relevant for num_hypernet_layers > 1. - """ - super(HyperNetwork, self).__init__() - self._qmix_hidden_dim = qmix_hidden_dim - self._num_hypernet_layers = num_hypernet_layers - self._n_agents = n_agents - - # Let the user define the hidden dim but default it to qmix_hidden_dim. 
- if hypernet_hidden_dim == 0: - self._hypernet_hidden_dim = qmix_hidden_dim - else: - self._hypernet_hidden_dim = hypernet_hidden_dim - - # Set up hypernetwork configuration - if self._num_hypernet_layers == 1: - self.hyper_w1 = snt.nets.MLP( - output_sizes=[self._qmix_hidden_dim * self._n_agents] - ) - self.hyper_w2 = snt.nets.MLP(output_sizes=[self._qmix_hidden_dim]) - - # Default - elif self._num_hypernet_layers == 2: - self.hyper_w1 = snt.nets.MLP( - output_sizes=[ - self._hypernet_hidden_dim, - self._qmix_hidden_dim * self._n_agents, - ] - ) - self.hyper_w2 = snt.nets.MLP( - output_sizes=[self._hypernet_hidden_dim, self._qmix_hidden_dim] - ) - - # State dependent bias for hidden layer - self.hyper_b1 = snt.nets.MLP(output_sizes=[self._qmix_hidden_dim]) - self.hyper_b2 = snt.nets.MLP(output_sizes=[self._qmix_hidden_dim, 1]) - - def __call__(self, states: Tensor) -> Dict[str, float]: # [batch_size=B, state_dim] - w1 = tf.abs( - self.hyper_w1(states) - ) # [B, qmix_hidden_dim] = [B, qmix_hidden_dim] - w1 = tf.reshape( - w1, - (-1, self._n_agents, self._qmix_hidden_dim), - ) # [B, n_agents, qmix_hidden_dim] - - b1 = self.hyper_b1(states) # [B, qmix_hidden_dim] = [B, qmix_hidden_dim] - b1 = tf.reshape(b1, [-1, 1, self._qmix_hidden_dim]) # [B, 1, qmix_hidden_dim] - - w2 = tf.abs(self.hyper_w2(states)) - w2 = tf.reshape( - w2, shape=(-1, self._qmix_hidden_dim, 1) - ) # [B, qmix_hidden_dim, 1] - - b2 = self.hyper_b2(states) # [B, 1] - b2 = tf.reshape(b2, shape=(-1, 1, 1)) # [B, 1, 1] - - hyperparams = {} - hyperparams["w1"] = w1 # [B, n_agents, qmix_hidden_dim] - hyperparams["b1"] = b1 # [B, 1, qmix_hidden_dim] - hyperparams["w2"] = w2 # [B, qmix_hidden_dim] - hyperparams["b2"] = b2 # [B, 1] - - return hyperparams diff --git a/mava/components/tf/networks/monotonic.py b/mava/components/tf/networks/monotonic.py deleted file mode 100644 index b98f4e90a..000000000 --- a/mava/components/tf/networks/monotonic.py +++ /dev/null @@ -1,86 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Code inspired by PyMARL framework implementation -# https://github.com/oxwhirl/pymarl/blob/master/src/modules/mixers/qmix.py - -"""Mixing for multi-agent RL systems""" - - -import sonnet as snt -import tensorflow as tf - -from mava.components.tf.networks.hypernetwork import HyperNetwork - - -class MonotonicMixingNetwork(snt.Module): - """Multi-agent monotonic mixing architecture. - This is the component which can be used to add monotonic mixing to an underlying - agent architecture. It currently supports generalised monotonic mixing using - hypernetworks (1 or 2 layers) for control of decomposition parameters (QMix).""" - - def __init__( - self, - n_agents: int, - name: str = "mixing", - qmix_hidden_dim: int = 64, - num_hypernet_layers: int = 2, - hypernet_hidden_dim: int = 0, - ) -> None: - """Initializes the mixer. - Args: - state_shape: The state shape as defined by the environment. - n_agents: The number of agents (i.e. Q-values) to mix. 
- qmix_hidden_dim: Mixing layers hidden dimensions. - num_hypernet_layers: Number of hypernetwork layers. Currently 1 or 2. - hypernet_hidden_dim: The number of nodes in the hypernetwork hidden - layer. Relevant for num_hypernet_layers > 1. - """ - super(MonotonicMixingNetwork, self).__init__(name=name) - self._n_agents = n_agents - self._qmix_hidden_dim = qmix_hidden_dim - self._num_hypernet_layers = num_hypernet_layers - self._hypernet_hidden_dim = hypernet_hidden_dim - - # Create hypernetwork - self._hypernetworks = HyperNetwork( - self._qmix_hidden_dim, - self._n_agents, - self._num_hypernet_layers, - self._hypernet_hidden_dim, - ) - - def __call__( - self, - q_values: tf.Tensor, # [batch_size, n_agents] - states: tf.Tensor, # [batch_size, state_dim] - ) -> tf.Tensor: - """Monotonic mixing logic.""" - - # Create hypernetwork - self._hyperparams = self._hypernetworks(states) - - # Extract hypernetwork layers - # TODO: make more general -> this assumes two layer hypernetwork - w1 = self._hyperparams["w1"] # [B, n_agents, qmix_hidden_dim] - b1 = self._hyperparams["b1"] # [B, 1, qmix_hidden_dim] - w2 = self._hyperparams["w2"] # [B, qmix_hidden_dim, 1] - b2 = self._hyperparams["b2"] # [B, 1, 1] - - # ELU -> Exp. linear unit - hidden = tf.nn.elu(tf.matmul(q_values, w1) + b1) # [B, 1, qmix_hidden_dim] - - # Qtot: [B, 1, 1] - return tf.matmul(hidden, w2) + b2 diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index a4533d9dc..4ed085149 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -453,11 +453,19 @@ def observe_first( ) if self._store_recurrent_state: + # Core states numpy_states = { agent: tf2_utils.to_numpy_squeeze(_state) for agent, _state in self._states.items() } - extras.update({"core_states": numpy_states}) + + extras.update( + { + "core_states": numpy_states, + "zero_padding_mask": np.array(1) + } + ) + extras["network_int_keys"] = self._network_int_keys_extras self._adder.add_first(timestep, extras) @@ -483,7 +491,14 @@ def observe( agent: tf2_utils.to_numpy_squeeze(_state) for agent, _state in self._states.items() } - next_extras.update({"core_states": numpy_states}) + + next_extras.update( + { + "core_states": numpy_states, + "zero_padding_mask": np.array(1) + } + ) + next_extras["network_int_keys"] = self._network_int_keys_extras self._adder.add(actions, next_timestep, next_extras) # type: ignore diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index fc67b580b..4f0d38ecf 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -18,6 +18,7 @@ import functools from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, Mapping +import numpy as np import acme import dm_env import launchpad as lp @@ -442,7 +443,7 @@ def _get_extra_specs(self) -> Any: networks["values"][agent_net_key].initial_state(1) ), ) - return {"core_states": core_state_specs} + return {"core_states": core_state_specs, "zero_padding_mask": np.array(1)} def replay(self) -> Any: """Step counter diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index 397cee0bb..ea8b6b38c 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -717,9 +717,10 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: q_tm1, a_tm1, r_t, discount * d_t, q_t_value, q_t_selector ) - # TODO zero padding mask - - self.value_losses[agent] = tf.reduce_mean(value_loss) + # Zero-padding mask + zero_padding_mask = 
tf.cast(extras["zero_padding_mask"], dtype=value_loss.dtype)[:-1] + masked_loss = value_loss * zero_padding_mask + self.value_losses[agent] = tf.reduce_sum(masked_loss) / tf.reduce_sum(zero_padding_mask) self.tape = tape diff --git a/mava/systems/tf/qmix/README.md b/mava/systems/tf/qmix/README.md deleted file mode 100644 index 0fa390943..000000000 --- a/mava/systems/tf/qmix/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# QMIX (Q-value function factorisation) - -An implementaiton of the QMIX MARL system ([Rashid et al., 2018]). QMIX is based on the idea of factorising the joint Q-value function for a team of agents and learning the weightings for each component using a monotonic mixing network whose weights are itself learned using a hypernetwork. 🔺 NOTE: our current implementation of QMIX has been not able to reproduce results demonstrated in the original paper. - -
- -[Rashid et al., 2018]: https://arxiv.org/pdf/1803.11485 diff --git a/mava/systems/tf/qmix/__init__.py b/mava/systems/tf/qmix/__init__.py deleted file mode 100644 index 47039da20..000000000 --- a/mava/systems/tf/qmix/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from mava.systems.tf.qmix.execution import QMIXFeedForwardExecutor -from mava.systems.tf.qmix.networks import make_default_networks -from mava.systems.tf.qmix.system import QMIX -from mava.systems.tf.qmix.training import QMIXTrainer diff --git a/mava/systems/tf/qmix/builder.py b/mava/systems/tf/qmix/builder.py deleted file mode 100644 index 990d419f2..000000000 --- a/mava/systems/tf/qmix/builder.py +++ /dev/null @@ -1,179 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""QMIX system builder implementation.""" - -import dataclasses -from typing import Any, Dict, Iterator, Optional, Type - -import reverb -import sonnet as snt -from acme.utils import counting - -from mava import core, types -from mava.components.tf.modules.communication import BaseCommunicationModule -from mava.components.tf.modules.mixing import MonotonicMixing -from mava.components.tf.modules.stabilising import FingerPrintStabalisation -from mava.systems.tf.madqn.builder import MADQNBuilder, MADQNConfig -from mava.systems.tf.qmix import execution, training -from mava.wrappers import DetailedTrainerStatistics - - -@dataclasses.dataclass -class QMIXConfig(MADQNConfig): - """Configuration options for the QMIX system. - - Args: - environment_spec: description of the action and observation spaces etc. for - each agent in the system. - epsilon_min: final minimum value for epsilon at the end of a decay schedule. - epsilon_decay: the rate at which epislon decays. - shared_weights: boolean indicating whether agents should share weights. - target_update_period: number of learner steps to perform before updating - the target networks. - executor_variable_update_period: the rate at which executors sync their - paramters with the trainer. - max_gradient_norm: value to specify the maximum clipping value for the gradient - norm during optimization. - min_replay_size: minimum replay size before updating. - max_replay_size: maximum replay size. - samples_per_insert: number of samples to take from replay for every insert - that is made. - prefetch_size: size to prefetch from replay. 
- batch_size: batch size for updates. - n_step: number of steps to include prior to boostrapping. - sequence_length: recurrent sequence rollout length. - period: consecutive starting points for overlapping rollouts across a sequence. - discount: discount to use for TD updates. - checkpoint: boolean to indicate whether to checkpoint models. - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. - optimizer: type of optimizer to use for updating the parameters of models. - replay_table_name: string indicating what name to give the replay table. - checkpoint_subpath: subdirectory specifying where to store checkpoints. - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. - """ - - -class QMIXBuilder(MADQNBuilder): - """Builder for QMIX which constructs individual components of the system.""" - - def __init__( - self, - config: QMIXConfig, - trainer_fn: Type[training.QMIXTrainer] = training.QMIXTrainer, - executor_fn: Type[core.Executor] = execution.QMIXFeedForwardExecutor, - mixer: Type[MonotonicMixing] = MonotonicMixing, - extra_specs: Dict[str, Any] = {}, - replay_stabilisation_fn: Optional[Type[FingerPrintStabalisation]] = None, - ) -> None: - """Initialise the system. - - Args: - config (QMIXConfig): system configuration specifying hyperparameters and - additional information for constructing the system. - trainer_fn (Type[training.QMIXTrainer], optional): Trainer function, of a - correpsonding type to work with the selected system architecture. - Defaults to training.QMIXTrainer. - executor_fn (Type[core.Executor], optional): Executor function, of a - corresponding type to work with the selected system architecture. - Defaults to execution.QMIXFeedForwardExecutor. - mixer (Type[MonotonicMixing], optional): mixer module type, e.g. additive or - monotonic mixing. Defaults to MonotonicMixing. - extra_specs (Dict[str, Any], optional): defines the specifications of extra - information used by the system. Defaults to {}. - replay_stabilisation_fn (Optional[Type[FingerPrintStabalisation]], - optional): optional function to stabilise experience replay. Defaults - to None. - """ - - super(QMIXBuilder, self).__init__( - config=config, - trainer_fn=trainer_fn, - executor_fn=executor_fn, - extra_specs=extra_specs, - replay_stabilisation_fn=replay_stabilisation_fn, - ) - self._mixer = mixer - - def make_trainer( - self, - networks: Dict[str, Dict[str, snt.Module]], - dataset: Iterator[reverb.ReplaySample], - counter: Optional[counting.Counter] = None, - logger: Optional[types.NestedLogger] = None, - communication_module: Optional[BaseCommunicationModule] = None, - replay_client: Optional[reverb.TFClient] = None, - ) -> core.Trainer: - """Create a trainer instance. - - Args: - networks (Dict[str, Dict[str, snt.Module]]): system networks. - dataset (Iterator[reverb.ReplaySample]): dataset iterator to feed data to - the trainer networks. - counter (Optional[counting.Counter], optional): a Counter which allows for - recording of counts, e.g. trainer steps. Defaults to None. - logger (Optional[types.NestedLogger], optional): Logger object for logging - metadata.. Defaults to None. - communication_module (BaseCommunicationModule): module to enable - agent communication. Defaults to None. - replay_client (reverb.TFClient): Used for importance sampling. - Not implemented yet. 
- - Returns: - core.Trainer: system trainer, that uses the collected data from the - executors to update the parameters of the agent networks in the system. - """ - - agents = self._config.environment_spec.get_agent_ids() - agent_types = self._config.environment_spec.get_agent_types() - - q_networks = networks["agent_networks"]["values"] - target_q_networks = networks["agent_networks"]["target_values"] - - mixing_network = networks["mixing"] - target_mixing_network = networks["target_mixing"] - - # Check if we should use fingerprints - fingerprint = True if self._replay_stabiliser_fn is not None else False - - # The learner updates the parameters (and initializes them). - trainer = self._trainer_fn( # type:ignore - agents=agents, - agent_types=agent_types, - discount=self._config.discount, - q_networks=q_networks, - target_q_networks=target_q_networks, - mixing_network=mixing_network, - target_mixing_network=target_mixing_network, - agent_net_keys=self._config.agent_net_keys, - optimizer=self._config.optimizer, - target_update_period=self._config.target_update_period, - max_gradient_norm=self._config.max_gradient_norm, - communication_module=communication_module, - dataset=dataset, - counter=counter, - fingerprint=fingerprint, - logger=logger, - checkpoint_minute_interval=self._config.checkpoint_minute_interval, - checkpoint=self._config.checkpoint, - checkpoint_subpath=self._config.checkpoint_subpath, - learning_rate_scheduler_fn=self._config.learning_rate_scheduler_fn, - ) - - trainer = DetailedTrainerStatistics(trainer) # type:ignore - - return trainer diff --git a/mava/systems/tf/qmix/execution.py b/mava/systems/tf/qmix/execution.py deleted file mode 100644 index 56209ede0..000000000 --- a/mava/systems/tf/qmix/execution.py +++ /dev/null @@ -1,81 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""QMIX system executor implementation.""" - -from typing import Dict, Optional - -import sonnet as snt -from acme.tf import variable_utils as tf2_variable_utils - -from mava import adders -from mava.components.tf.modules.communication import BaseCommunicationModule -from mava.systems.tf.madqn.execution import MADQNFeedForwardExecutor -from mava.systems.tf.madqn.training import MADQNTrainer - - -class QMIXFeedForwardExecutor(MADQNFeedForwardExecutor): - """A feed-forward executor. - An executor based on a feed-forward policy for each agent in the system. 
- """ - - def __init__( - self, - q_networks: Dict[str, snt.Module], - action_selectors: Dict[str, snt.Module], - trainer: MADQNTrainer, - agent_net_keys: Dict[str, str], - adder: Optional[adders.ParallelAdder] = None, - variable_client: Optional[tf2_variable_utils.VariableClient] = None, - communication_module: Optional[BaseCommunicationModule] = None, - fingerprint: bool = False, - evaluator: bool = False, - interval: Optional[dict] = None, - ): - """Initialise the system executor - - Args: - q_networks (Dict[str, snt.Module]): q-value networks for each agent in the - system. - action_selectors (Dict[str, Any]): policy action selector method, e.g. - epsilon greedy. - trainer (MADQNTrainer, optional): system trainer. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - adder (Optional[adders.ParallelAdder], optional): adder which sends data - to a replay buffer. Defaults to None. - variable_client (Optional[tf2_variable_utils.VariableClient], optional): - client to copy weights from the trainer. Defaults to None. - communication_module (BaseCommunicationModule): module for enabling - communication protocols between agents. Defaults to None. - fingerprint (bool, optional): whether to use fingerprint stabilisation to - stabilise experience replay. Defaults to False. - evaluator (bool, optional): whether the executor will be used for - evaluation. Defaults to False. - interval: interval that evaluations are run at. - """ - - super(QMIXFeedForwardExecutor, self).__init__( - q_networks=q_networks, - action_selectors=action_selectors, - agent_net_keys=agent_net_keys, - adder=adder, - variable_client=variable_client, - communication_module=communication_module, - fingerprint=fingerprint, - trainer=trainer, - evaluator=evaluator, - interval=interval, - ) diff --git a/mava/systems/tf/qmix/networks.py b/mava/systems/tf/qmix/networks.py deleted file mode 100644 index 913a36d7e..000000000 --- a/mava/systems/tf/qmix/networks.py +++ /dev/null @@ -1,66 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Dict, Mapping, Optional, Sequence, Union - -from acme import types - -from mava import specs as mava_specs -from mava.systems.tf.madqn.networks import ( - make_default_networks as make_default_networks_madqn, -) -from mava.utils.enums import ArchitectureType, Network - - -# Default networks for qmix -def make_default_networks( - environment_spec: mava_specs.MAEnvironmentSpec, - agent_net_keys: Dict[str, str], - policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (128, 128), - archecture_type: ArchitectureType = ArchitectureType.feedforward, - network_type: Network = Network.mlp, - fingerprints: bool = False, - seed: Optional[int] = None, -) -> Mapping[str, types.TensorTransformation]: - """Default networks for qmix. - - Args: - environment_spec (mava_specs.MAEnvironmentSpec): description of the action and - observation spaces etc. for each agent in the system. 
- agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - policy_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional): - size of policy networks. Defaults to (128,128). - archecture_type (ArchitectureType, optional): archecture used for - agent networks. Can be feedforward or recurrent. - Defaults to ArchitectureType.recurrent. - network_type (Network, optional): Agent network type. Can be mlp, - atari_dqn_network or coms_network. Defaults to Network.coms_network. - fingerprints (bool, optional): whether to apply replay stabilisation using - policy fingerprints. Defaults to False. - seed (int, optional): random seed for network initialization. - - Returns: - Mapping[str, types.TensorTransformation]: returned agent networks. - """ - - return make_default_networks_madqn( - environment_spec=environment_spec, - policy_networks_layer_sizes=policy_networks_layer_sizes, - agent_net_keys=agent_net_keys, - archecture_type=archecture_type, - network_type=network_type, - fingerprints=fingerprints, - seed=seed, - ) diff --git a/mava/systems/tf/qmix/system.py b/mava/systems/tf/qmix/system.py deleted file mode 100644 index 3700a7ce7..000000000 --- a/mava/systems/tf/qmix/system.py +++ /dev/null @@ -1,362 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""QMIX system implementation.""" - -from typing import Any, Callable, Dict, Optional, Type, Union - -import dm_env -import reverb -import sonnet as snt -from acme import specs as acme_specs -from acme.utils import counting - -import mava -from mava import core -from mava import specs as mava_specs -from mava.components.tf.architectures import DecentralisedValueActor -from mava.components.tf.modules.communication import BaseCommunicationModule -from mava.components.tf.modules.mixing import MonotonicMixing -from mava.components.tf.modules.stabilising import FingerPrintStabalisation -from mava.environment_loop import ParallelEnvironmentLoop -from mava.systems.tf import executors -from mava.systems.tf.madqn.system import MADQN -from mava.systems.tf.qmix import builder, execution, training -from mava.types import EpsilonScheduler -from mava.utils.loggers import MavaLogger - - -# TODO Implement recurrent QMIX -class QMIX(MADQN): - """QMIX system.""" - - def __init__( - self, - environment_factory: Callable[[bool], dm_env.Environment], - network_factory: Callable[[acme_specs.BoundedArray], Dict[str, snt.Module]], - exploration_scheduler_fn: Union[ - EpsilonScheduler, - Dict[str, EpsilonScheduler], - Dict[str, Dict[str, EpsilonScheduler]], - ], - logger_factory: Callable[[str], MavaLogger] = None, - architecture: Type[DecentralisedValueActor] = DecentralisedValueActor, - trainer_fn: Type[training.QMIXTrainer] = training.QMIXTrainer, - executor_fn: Type[core.Executor] = execution.QMIXFeedForwardExecutor, - mixer: Type[MonotonicMixing] = MonotonicMixing, - communication_module: Type[BaseCommunicationModule] = None, - replay_stabilisation_fn: Optional[Type[FingerPrintStabalisation]] = None, - num_executors: int = 1, - num_caches: int = 0, - environment_spec: mava_specs.MAEnvironmentSpec = None, - shared_weights: bool = True, - agent_net_keys: Dict[str, str] = {}, - batch_size: int = 256, - prefetch_size: int = 4, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: Optional[float] = 32.0, - n_step: int = 5, - sequence_length: int = 20, - importance_sampling_exponent: Optional[float] = None, - max_priority_weight: float = 0.9, - period: int = 20, - max_gradient_norm: float = None, - discount: float = 0.99, - optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]] = snt.optimizers.Adam( - learning_rate=1e-4 - ), - target_update_period: int = 100, - executor_variable_update_period: int = 1000, - max_executor_steps: int = None, - checkpoint: bool = True, - checkpoint_subpath: str = "~/mava/", - checkpoint_minute_interval: int = 5, - logger_config: Dict = {}, - train_loop_fn: Callable = ParallelEnvironmentLoop, - eval_loop_fn: Callable = ParallelEnvironmentLoop, - train_loop_fn_kwargs: Dict = {}, - eval_loop_fn_kwargs: Dict = {}, - qmix_hidden_dim: int = 32, - num_hypernet_layers: int = 1, - hypernet_hidden_dim: int = 32, - learning_rate_scheduler_fn: Optional[Callable[[int], None]] = None, - seed: Optional[int] = None, - evaluator_interval: Optional[dict] = None, - ): - """Initialise the system - - Args: - environment_factory (Callable[[bool], dm_env.Environment]): function to - instantiate an environment. - network_factory (Callable[[acme_specs.BoundedArray], - Dict[str, snt.Module]]): function to instantiate system networks. - logger_factory (Callable[[str], MavaLogger], optional): function to - instantiate a system logger. Defaults to None. - architecture (Type[DecentralisedValueActor], optional): system architecture, - e.g. decentralised or centralised. 
Defaults to DecentralisedValueActor. - trainer_fn (Type[training.QMIXTrainer], optional): training type associated - with executor and architecture, e.g. centralised training. Defaults to - training.QMIXTrainer. - executor_fn (Type[core.Executor], optional): executor type, e.g. - feedforward or recurrent. Defaults to execution.QMIXFeedForwardExecutor. - mixer (Type[MonotonicMixing], optional): mixer module type, e.g. additive or - monotonic mixing. Defaults to MonotonicMixing. - communication_module (Type[BaseCommunicationModule], optional): - module for enabling communication protocols between agents. Defaults to - None. - exploration_scheduler_fn (Type[ LinearExplorationScheduler ], optional): - function specifying a decaying scheduler for epsilon exploration. - See mava/systems/tf/madqn/system.py for details. - replay_stabilisation_fn (Optional[Type[FingerPrintStabalisation]], - optional): replay buffer stabilisaiton function, e.g. fingerprints. - Defaults to None. - num_executors (int, optional): number of executor processes to run in - parallel. Defaults to 1. - num_caches (int, optional): number of trainer node caches. Defaults to 0. - environment_spec (mava_specs.MAEnvironmentSpec, optional): description of - the action, observation spaces etc. for each agent in the system. - Defaults to None. - shared_weights (bool, optional): whether agents should share weights or not. - When agent_net_keys are provided the value of shared_weights is ignored. - Defaults to True. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - batch_size (int, optional): sample batch size for updates. Defaults to 256. - prefetch_size (int, optional): size to prefetch from replay. Defaults to 4. - min_replay_size (int, optional): minimum replay size before updating. - Defaults to 1000. - max_replay_size (int, optional): maximum replay size. Defaults to 1000000. - samples_per_insert (Optional[float], optional): number of samples to take - from replay for every insert that is made. Defaults to 32.0. - n_step (int, optional): number of steps to include prior to boostrapping. - Defaults to 5. - sequence_length (int, optional): recurrent sequence rollout length. Defaults - to 20. - period (int, optional): The period with which we add sequences. See `period` - in `acme.SequenceAdder.period` for more info. Defaults to 20. - max_gradient_norm (float, optional): maximum allowed norm for gradients - before clipping is applied. Defaults to None. - discount (float, optional): discount factor to use for TD updates. Defaults - to 0.99. - optimizer (Union[snt.Optimizer, Dict[str, snt.Optimizer]], optional): - type of optimizer to use to update network parameters. Defaults to - snt.optimizers.Adam( learning_rate=1e-4 ). - target_update_period (int, optional): number of steps before target - networks are updated. Defaults to 100. - executor_variable_update_period (int, optional): number of steps before - updating executor variables from the variable source. Defaults to 1000. - max_executor_steps (int, optional): maximum number of steps and executor - can in an episode. Defaults to None. - checkpoint (bool, optional): whether to checkpoint models. Defaults to - False. - checkpoint_subpath (str, optional): subdirectory specifying where to store - checkpoints. Defaults to "~/mava/". - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. - logger_config (Dict, optional): additional configuration settings for the - logger factory. Defaults to {}. 
- train_loop_fn (Callable, optional): function to instantiate a train loop. - Defaults to ParallelEnvironmentLoop. - eval_loop_fn (Callable, optional): function to instantiate an evaluation - loop. Defaults to ParallelEnvironmentLoop. - train_loop_fn_kwargs (Dict, optional): possible keyword arguments to send - to the training loop. Defaults to {}. - eval_loop_fn_kwargs (Dict, optional): possible keyword arguments to send to - the evaluation loop. Defaults to {}. - qmix_hidden_dim (int, optional): mixing network hidden dimension. Defaults - to 32. - num_hypernet_layers (int, optional): number of layers for hypernetwork. - Defaults to 1. - hypernet_hidden_dim (int, optional): hypernetwork hidden dimension. Defaults - to 32. - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. - seed: seed for reproducible sampling (for epsilon greedy action selection). - evaluator_interval: An optional condition that is used to - evaluate/test system performance after [evaluator_interval] - condition has been met. If None, evaluation will - happen at every timestep. - E.g. to evaluate a system after every 100 executor episodes, - evaluator_interval = {"executor_episodes": 100}. - """ - - self._mixer = mixer - self._qmix_hidden_dim = qmix_hidden_dim - self._num_hypernet_layers = num_hypernet_layers - self._hypernet_hidden_dim = hypernet_hidden_dim - - super(QMIX, self).__init__( - environment_factory=environment_factory, - network_factory=network_factory, - logger_factory=logger_factory, - architecture=architecture, - trainer_fn=trainer_fn, - communication_module=communication_module, - executor_fn=executor_fn, - replay_stabilisation_fn=replay_stabilisation_fn, - num_executors=num_executors, - num_caches=num_caches, - environment_spec=environment_spec, - agent_net_keys=agent_net_keys, - shared_weights=shared_weights, - batch_size=batch_size, - prefetch_size=prefetch_size, - min_replay_size=min_replay_size, - max_replay_size=max_replay_size, - samples_per_insert=samples_per_insert, - n_step=n_step, - sequence_length=sequence_length, - period=period, - discount=discount, - optimizer=optimizer, - target_update_period=target_update_period, - executor_variable_update_period=executor_variable_update_period, - max_executor_steps=max_executor_steps, - checkpoint=checkpoint, - checkpoint_subpath=checkpoint_subpath, - checkpoint_minute_interval=checkpoint_minute_interval, - logger_config=logger_config, - train_loop_fn=train_loop_fn, - eval_loop_fn=eval_loop_fn, - train_loop_fn_kwargs=train_loop_fn_kwargs, - eval_loop_fn_kwargs=eval_loop_fn_kwargs, - evaluator_interval=evaluator_interval, - exploration_scheduler_fn=exploration_scheduler_fn, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, - seed=seed, - ) - - if issubclass(executor_fn, executors.RecurrentExecutor): - extra_specs = self._get_extra_specs() - else: - extra_specs = {} - - self._builder = builder.QMIXBuilder( - builder.QMIXConfig( - environment_spec=self._environment_spec, - agent_net_keys=self._agent_net_keys, - discount=discount, - batch_size=batch_size, - prefetch_size=prefetch_size, - target_update_period=target_update_period, - executor_variable_update_period=executor_variable_update_period, - min_replay_size=min_replay_size, - max_replay_size=max_replay_size, - samples_per_insert=samples_per_insert, - n_step=n_step, - sequence_length=sequence_length, - importance_sampling_exponent=importance_sampling_exponent, - max_priority_weight=max_priority_weight, - period=period, - 
max_gradient_norm=max_gradient_norm, - checkpoint=checkpoint, - optimizer=optimizer, - checkpoint_subpath=checkpoint_subpath, - checkpoint_minute_interval=checkpoint_minute_interval, - evaluator_interval=evaluator_interval, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, - ), - trainer_fn=trainer_fn, - executor_fn=executor_fn, - extra_specs=extra_specs, - replay_stabilisation_fn=replay_stabilisation_fn, - mixer=mixer, - ) - - def trainer( - self, - replay: reverb.Client, - counter: counting.Counter, - ) -> mava.core.Trainer: - """System trainer - - Args: - replay (reverb.Client): replay data table to pull data from. - counter (counting.Counter): step counter object. - - Raises: - Exception: no option for recurrence (yet). - - Returns: - mava.core.Trainer: system trainer. - """ - - # Create the networks to optimize (online) - networks = self._network_factory( # type: ignore - environment_spec=self._environment_spec, - agent_net_keys=self._agent_net_keys, - ) - - # Create system architecture - architecture = self._architecture( - environment_spec=self._environment_spec, - value_networks=networks["q_networks"], - agent_net_keys=self._agent_net_keys, - ) - - # Fingerprint module - if self._builder._replay_stabiliser_fn is not None: - architecture = self._builder._replay_stabiliser_fn( # type: ignore - architecture - ) - - # Communication module - # NOTE: this is currently not expected to work with qmix - # since we do not have a recurrent version. - if self._communication_module_fn is not None: - raise Exception( - "QMIX currently does not support recurrence and \ - therefore cannot use a communication module." - ) - - # Mixing module - system_networks = self._mixer( - environment_spec=self._environment_spec, - architecture=architecture, - num_hypernet_layers=self._num_hypernet_layers, - qmix_hidden_dim=self._qmix_hidden_dim, - hypernet_hidden_dim=self._hypernet_hidden_dim, - ).create_system() - - # Create logger - trainer_logger_config = {} - if self._logger_config and "trainer" in self._logger_config: - trainer_logger_config = self._logger_config["trainer"] - trainer_logger = self._logger_factory( # type: ignore - "trainer", **trainer_logger_config - ) - - dataset = self._builder.make_dataset_iterator(replay) - counter = counting.Counter(counter, "trainer") - - return self._builder.make_trainer( - networks=system_networks, - dataset=dataset, - counter=counter, - communication_module=None, - logger=trainer_logger, - ) - - def build(self, name: str = "qmix") -> Any: - """Build the distributed system as a graph program. - - Args: - name (str, optional): system name. Defaults to "qmix". - - Returns: - Any: graph program for distributed system training. - """ - return super().build(name=name) diff --git a/mava/systems/tf/qmix/training.py b/mava/systems/tf/qmix/training.py deleted file mode 100644 index 63a457abf..000000000 --- a/mava/systems/tf/qmix/training.py +++ /dev/null @@ -1,305 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""QMIX system trainer implementation.""" - -from typing import Any, Callable, Dict, List, Optional, Sequence - -import numpy as np -import reverb -import sonnet as snt -import tensorflow as tf -from acme.tf import utils as tf2_utils -from acme.utils import counting, loggers -from trfl.indexing_ops import batched_index - -from mava import types as mava_types -from mava.components.tf.modules.communication import BaseCommunicationModule -from mava.systems.tf.madqn.training import MADQNTrainer -from mava.utils import training_utils as train_utils - -train_utils.set_growing_gpu_memory() - - -class QMIXTrainer(MADQNTrainer): - """QMIX trainer. - This is the trainer component of a QMIX system. i.e. it takes a dataset as input - and implements update functionality to learn from this dataset. - """ - - def __init__( - self, - agents: List[str], - agent_types: List[str], - q_networks: Dict[str, snt.Module], - target_q_networks: Dict[str, snt.Module], - mixing_network: snt.Module, - target_mixing_network: snt.Module, - target_update_period: int, - dataset: tf.data.Dataset, - optimizer: snt.Optimizer, - discount: float, - agent_net_keys: Dict[str, str], - checkpoint_minute_interval: int, - communication_module: Optional[BaseCommunicationModule] = None, - max_gradient_norm: float = None, - counter: counting.Counter = None, - fingerprint: bool = False, - logger: loggers.Logger = None, - checkpoint: bool = True, - checkpoint_subpath: str = "~/mava/", - learning_rate_scheduler_fn: Optional[Callable[[int], None]] = None, - ) -> None: - """Initialise QMIX trainer - - Args: - agents (List[str]): agent ids, e.g. "agent_0". - agent_types (List[str]): agent types, e.g. "speaker" or "listener". - q_networks (Dict[str, snt.Module]): q-value networks. - target_q_networks (Dict[str, snt.Module]): target q-value networks. - mixing_network (snt.Module): mixing networks learning factorised q-value - weights. - target_mixing_network (snt.Module): target mixing networks. - target_update_period (int): number of steps before updating target networks. - dataset (tf.data.Dataset): training dataset. - optimizer (Union[snt.Optimizer, Dict[str, snt.Optimizer]]): type of - optimizer for updating the parameters of the networks. - discount (float): discount factor for TD updates. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. - communication_module (BaseCommunicationModule): module for communication - between agents. Defaults to None. - max_gradient_norm (float, optional): maximum allowed norm for gradients - before clipping is applied. Defaults to None. - counter (counting.Counter, optional): step counter object. Defaults to None. - fingerprint (bool, optional): whether to apply replay stabilisation using - policy fingerprints. Defaults to False. - logger (loggers.Logger, optional): logger object for logging trainer - statistics. Defaults to None. - checkpoint (bool, optional): whether to checkpoint networks. Defaults to - True. - checkpoint_subpath (str, optional): subdirectory for storing checkpoints. - Defaults to "~/mava/". - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. 
- """ - - self._mixing_network = mixing_network - self._target_mixing_network = target_mixing_network - self._optimizer = optimizer - - super(QMIXTrainer, self).__init__( - agents=agents, - agent_types=agent_types, - q_networks=q_networks, - target_q_networks=target_q_networks, - target_update_period=target_update_period, - dataset=dataset, - optimizer=optimizer, - discount=discount, - agent_net_keys=agent_net_keys, - checkpoint_minute_interval=checkpoint_minute_interval, - communication_module=communication_module, - max_gradient_norm=max_gradient_norm, - counter=counter, - fingerprint=fingerprint, - logger=logger, - checkpoint=checkpoint, - checkpoint_subpath=checkpoint_subpath, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, - ) - - # Checkpoint the mixing networks - # TODO add checkpointing for mixing networks - - def _update_target_networks(self) -> None: - """Sync the target network parameters with the latest online network - parameters""" - - # Update target networks (incl. mixing networks). - if tf.math.mod(self._num_steps, self._target_update_period) == 0: - for key in self.unique_net_keys: - online_variables = [ - *self._q_networks[key].variables, - ] - target_variables = [ - *self._target_q_networks[key].variables, - ] - - # Make online -> target network update ops. - for src, dest in zip(online_variables, target_variables): - dest.assign(src) - - # NOTE These shouldn't really be in the agent for loop. - online_variables = [*self._mixing_network.variables] - target_variables = [*self._target_mixing_network.variables] - - # Make online -> target network update ops. - for src, dest in zip(online_variables, target_variables): - dest.assign(src) - - self._num_steps.assign_add(1) - - @tf.function - def _step( - self, - ) -> Dict[str, Dict[str, Any]]: - """Trainer forward and backward passes.""" - - # Update the target networks - self._update_target_networks() - - # Get data from replay (dropping extras if any). Note there is no - # extra data here because we do not insert any into Reverb. - inputs = next(self._iterator) - - self._forward(inputs) - - self._backward() - - # Log losses per agent - return {agent: {"policy_loss": self.loss} for agent in self._agents} - - def _forward(self, inputs: reverb.ReplaySample) -> None: - """Trainer forward pass - - Args: - inputs (Any): input data from the data table (transitions) - """ - - # Unpack input data as follows: - # o_tm1 = dictionary of observations one for each agent - # a_tm1 = dictionary of actions taken from obs in o_tm1 - # e_tm1 [Optional] = extra data that the agents persist in replay. - # r_t = dictionary of rewards or rewards sequences - # (if using N step transitions) ensuing from actions a_tm1 - # d_t = environment discount ensuing from actions a_tm1. - # This discount is applied to future rewards after r_t. - # o_t = dictionary of next observations or next observation sequences - # e_t = [Optional] = extra data that the agents persist in replay. 
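# --- Illustrative aside (not part of the patch): the per-agent structure the
# comment above describes, shown for a hypothetical two-agent transition.
# Observation and state shapes are placeholders, not taken from any environment.
import numpy as np

def _obs() -> np.ndarray:
    return np.zeros(10, dtype=np.float32)  # placeholder per-agent observation

def _state() -> np.ndarray:
    return np.zeros(20, dtype=np.float32)  # placeholder global state

example_transition = dict(
    o_tm1={"agent_0": _obs(), "agent_1": _obs()},  # observations at t-1
    a_tm1={"agent_0": 2, "agent_1": 0},            # actions taken from o_tm1
    r_t={"agent_0": 1.0, "agent_1": 1.0},          # rewards following a_tm1
    d_t={"agent_0": 1.0, "agent_1": 1.0},          # environment discounts
    o_t={"agent_0": _obs(), "agent_1": _obs()},    # next observations
    e_tm1={"s_t": _state()},                       # extras at t-1, e.g. state for the mixer
    e_t={"s_t": _state()},                         # extras at t
)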
- trans = mava_types.Transition(*inputs.data) - - o_tm1, o_t, a_tm1, r_t, d_t, e_tm1, e_t = ( - trans.observations, - trans.next_observations, - trans.actions, - trans.rewards, - trans.discounts, - trans.extras, - trans.next_extras, - ) - - s_tm1 = e_tm1["s_t"] - s_t = e_t["s_t"] - - # Do forward passes through the networks and calculate the losses - with tf.GradientTape(persistent=True) as tape: - q_acts = [] # Q vals - q_targets = [] # Target Q vals - for agent in self._agents: - agent_key = self._agent_net_keys[agent] - - o_tm1_feed, o_t_feed, a_tm1_feed = self._get_feed( - o_tm1, o_t, a_tm1, agent - ) - q_tm1 = self._q_networks[agent_key](o_tm1_feed) - q_t_value = self._target_q_networks[agent_key](o_t_feed) - q_t_selector = self._q_networks[agent_key](o_t_feed) - best_action = tf.argmax(q_t_selector, axis=1, output_type=tf.int32) - - # TODO Make use of q_t_selector for fingerprinting. Speak to Claude. - q_act = batched_index(q_tm1, a_tm1_feed, keepdims=True) # [B, 1] - q_target = batched_index( - q_t_value, best_action, keepdims=True - ) # [B, 1] - - q_acts.append(q_act) - q_targets.append(q_target) - - rewards = tf.concat( - [tf.reshape(val, (-1, 1)) for val in list(r_t.values())], axis=1 - ) - rewards = tf.reduce_mean(rewards, axis=1) # [B] - - pcont = tf.concat( - [tf.reshape(val, (-1, 1)) for val in list(d_t.values())], axis=1 - ) - pcont = tf.reduce_mean(pcont, axis=1) - discount = tf.cast(self._discount, list(d_t.values())[0].dtype) - pcont = discount * pcont # [B] - - q_acts = tf.concat(q_acts, axis=1) # [B, num_agents] - q_targets = tf.concat(q_targets, axis=1) # [B, num_agents] - - q_tot_mixed = self._mixing_network(q_acts, s_tm1) # [B, 1, 1] - q_tot_target_mixed = self._target_mixing_network( - q_targets, s_t - ) # [B, 1, 1] - - # Calculate Q loss. - targets = rewards + pcont * q_tot_target_mixed - td_error = targets - q_tot_mixed - - # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error. 
- self.loss = 0.5 * tf.reduce_mean(tf.square(td_error)) - self.tape = tape - - def _backward(self) -> None: - """Trainer backward pass updating network parameters""" - - for agent in self._agents: - agent_key = self._agent_net_keys[agent] - # Update agent networks - variables = [*self._q_networks[agent_key].trainable_variables] - gradients = self.tape.gradient(self.loss, variables) - gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0] - self._optimizers[agent_key].apply(gradients, variables) - - # Update mixing network - variables = [*self._mixing_network.trainable_variables] - gradients = self.tape.gradient(self.loss, variables) - - gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0] - self._optimizer.apply(gradients, variables) - - train_utils.safe_del(self, "tape") - - # TODO(Kale-ab): Ini _system_network_variables - def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray]]: - """get network variables - - Args: - names (Sequence[str]): network names - - Returns: - Dict[str, Dict[str, np.ndarray]]: network variables - """ - - variables: Dict[str, Dict[str, np.ndarray]] = {} - variables = {} - for network_type in names: - if network_type == "mixing": - # Includes the hypernet variables - variables[network_type] = self._mixing_network.variables - else: # Collect variables for each agent network - variables[network_type] = { - key: tf2_utils.to_numpy( - self._system_network_variables[network_type][key] - ) - for key in self.unique_net_keys - } - return variables diff --git a/mava/systems/tf/value_decomposition/networks.py b/mava/systems/tf/value_decomposition/networks.py index 6250b1c59..fb661d5fd 100644 --- a/mava/systems/tf/value_decomposition/networks.py +++ b/mava/systems/tf/value_decomposition/networks.py @@ -63,7 +63,7 @@ def make_default_networks( """ if not value_networks_layer_sizes: - value_networks_layer_sizes = (128, 64) + value_networks_layer_sizes = (64, 64) value_network_func = snt.DeepRNN diff --git a/mava/systems/tf/value_decomposition/system.py b/mava/systems/tf/value_decomposition/system.py index 84964e335..2f79e483e 100644 --- a/mava/systems/tf/value_decomposition/system.py +++ b/mava/systems/tf/value_decomposition/system.py @@ -35,7 +35,7 @@ from mava.utils import enums from mava.utils.loggers import MavaLogger from mava.systems.tf.madqn import MADQN -from mava.systems.tf.value_decomposition.mixer import QMIX, VDN +from mava.components.tf.modules.mixing.mixers import QMIX, VDN class ValueDecomposition(MADQN): diff --git a/mava/systems/tf/value_decomposition/training.py b/mava/systems/tf/value_decomposition/training.py index d9f3723aa..186911574 100644 --- a/mava/systems/tf/value_decomposition/training.py +++ b/mava/systems/tf/value_decomposition/training.py @@ -299,12 +299,14 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error. 
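# --- Illustrative aside (not part of the patch): the remark above can be
# checked directly. Differentiating 0.5 * (target - q)^2 with respect to the
# prediction q gives -(target - q), i.e. the TD error up to sign.
import tensorflow as tf

q = tf.Variable(1.0)       # predicted value
target = tf.constant(3.0)  # bootstrapped target, treated as a constant
with tf.GradientTape() as tape:
    loss = 0.5 * tf.square(target - q)
# d/dq [0.5 * (target - q)^2] = -(target - q)
print(tape.gradient(loss, q).numpy())  # -2.0, i.e. minus the TD error of 2.0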
value_loss = 0.5 * tf.square(td_error) - value_loss = tf.reduce_mean(value_loss) - # TODO zero padding mask + # Zero-padding mask + zero_padding_mask = tf.cast(extras["zero_padding_mask"], dtype=value_loss.dtype)[:-1] + masked_loss = value_loss * tf.expand_dims(zero_padding_mask, axis=-1) + masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(zero_padding_mask) - self.value_losses = {agent: value_loss for agent in self._agents} - self.mixer_loss = value_loss + self.value_losses = {agent: masked_loss for agent in self._agents} + self.mixer_loss = masked_loss self.tape = tape diff --git a/mava/systems/tf/vdn/README.md b/mava/systems/tf/vdn/README.md deleted file mode 100644 index acfdf349e..000000000 --- a/mava/systems/tf/vdn/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# Value Decomposition Networks (VDN) - -An implementaiton of a VDN system ([Sunehag et al., 2017]). VDN learns a decomposed joint Q-value function for a team of agents in a cooperative multi-agent setting. - -
- -[Sunehag et al., 2017]: https://arxiv.org/pdf/1706.05296 diff --git a/mava/systems/tf/vdn/__init__.py b/mava/systems/tf/vdn/__init__.py deleted file mode 100644 index 3a36aa133..000000000 --- a/mava/systems/tf/vdn/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from mava.systems.tf.vdn.execution import VDNFeedForwardExecutor -from mava.systems.tf.vdn.networks import make_default_networks -from mava.systems.tf.vdn.system import VDN -from mava.systems.tf.vdn.training import VDNTrainer diff --git a/mava/systems/tf/vdn/builder.py b/mava/systems/tf/vdn/builder.py deleted file mode 100644 index 065ba78f2..000000000 --- a/mava/systems/tf/vdn/builder.py +++ /dev/null @@ -1,172 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""VDN system builder implementation.""" - -import dataclasses -from typing import Any, Dict, Iterator, Optional, Type - -import reverb -import sonnet as snt -from acme.utils import counting - -from mava import core, types -from mava.components.tf.modules.communication import BaseCommunicationModule -from mava.components.tf.modules.stabilising import FingerPrintStabalisation -from mava.systems.tf.madqn.builder import MADQNBuilder, MADQNConfig -from mava.systems.tf.vdn import execution, training -from mava.wrappers import DetailedTrainerStatistics - - -@dataclasses.dataclass -class VDNConfig(MADQNConfig): - """Configuration options for the VDN system. - - environment_spec: description of the action and observation spaces etc. for - each agent in the system. - epsilon_min: final minimum value for epsilon at the end of a decay schedule. - epsilon_decay: the rate at which epislon decays. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - target_update_period: number of learner steps to perform before updating - the target networks. - executor_variable_update_period: the rate at which executors sync their - paramters with the trainer. - max_gradient_norm: value to specify the maximum clipping value for the gradient - norm during optimization. - min_replay_size: minimum replay size before updating. - max_replay_size: maximum replay size. - samples_per_insert: number of samples to take from replay for every insert - that is made. - prefetch_size: size to prefetch from replay. - batch_size: batch size for updates. 
- n_step: number of steps to include prior to boostrapping. - sequence_length: recurrent sequence rollout length. - period: consecutive starting points for overlapping rollouts across a sequence. - discount: discount to use for TD updates. - checkpoint: boolean to indicate whether to checkpoint models. - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. - optimizer: type of optimizer to use for updating the parameters of models. - replay_table_name: string indicating what name to give the replay table. - checkpoint_subpath: subdirectory specifying where to store checkpoints. - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. - """ - - -class VDNBuilder(MADQNBuilder): - """Builder for VDN which constructs individual components of the system.""" - - def __init__( - self, - config: VDNConfig, - trainer_fn: Type[training.VDNTrainer] = training.VDNTrainer, - executor_fn: Type[core.Executor] = execution.VDNFeedForwardExecutor, - extra_specs: Dict[str, Any] = {}, - replay_stabilisation_fn: Optional[Type[FingerPrintStabalisation]] = None, - ) -> None: - """Initialise the system. - - Args: - config (VDNConfig): system configuration specifying hyperparameters and - additional information for constructing the system. - trainer_fn (Type[training.VDNTrainer], optional): Trainer function, of a - correpsonding type to work with the selected system architecture. - Defaults to training.VDNTrainer. - executor_fn (Type[core.Executor], optional): Executor function, of a - corresponding type to work with the selected system architecture. - Defaults to execution.VDNFeedForwardExecutor. - extra_specs (Dict[str, Any], optional): defines the specifications of extra - information used by the system. Defaults to {}. - replay_stabilisation_fn (Optional[Type[FingerPrintStabalisation]], - optional): optional function to stabilise experience replay. Defaults - to None. - """ - super(VDNBuilder, self).__init__( - config=config, - trainer_fn=trainer_fn, - executor_fn=executor_fn, - extra_specs=extra_specs, - replay_stabilisation_fn=replay_stabilisation_fn, - ) - - def make_trainer( - self, - networks: Dict[str, Dict[str, snt.Module]], - dataset: Iterator[reverb.ReplaySample], - counter: Optional[counting.Counter] = None, - logger: Optional[types.NestedLogger] = None, - communication_module: Optional[BaseCommunicationModule] = None, - replay_client: Optional[reverb.TFClient] = None, - ) -> core.Trainer: - """Create a trainer instance. - - Args: - networks (Dict[str, Dict[str, snt.Module]]): system networks. - dataset (Iterator[reverb.ReplaySample]): dataset iterator to feed data to - the trainer networks. - counter (Optional[counting.Counter], optional): a Counter which allows for - recording of counts, e.g. trainer steps. Defaults to None. - logger (Optional[types.NestedLogger], optional): Logger object for logging - metadata.. Defaults to None. - communication_module (BaseCommunicationModule): module to enable - agent communication. Defaults to None. - replay_client (reverb.TFClient): Used for importance sampling. - Not implemented yet. - - Returns: - core.Trainer: system trainer, that uses the collected data from the - executors to update the parameters of the agent networks in the system. 
- """ - - agents = self._config.environment_spec.get_agent_ids() - agent_types = self._config.environment_spec.get_agent_types() - - q_networks = networks["agent_networks"]["values"] - target_q_networks = networks["agent_networks"]["target_values"] - mixing_network = networks["mixing"] - target_mixing_network = networks["target_mixing"] - - # Check if we should use fingerprints - fingerprint = True if self._replay_stabiliser_fn is not None else False - - # The learner updates the parameters (and initializes them). - trainer = self._trainer_fn( # type:ignore - agents=agents, - agent_types=agent_types, - discount=self._config.discount, - q_networks=q_networks, - target_q_networks=target_q_networks, - mixing_network=mixing_network, - target_mixing_network=target_mixing_network, - agent_net_keys=self._config.agent_net_keys, - optimizer=self._config.optimizer, - target_update_period=self._config.target_update_period, - max_gradient_norm=self._config.max_gradient_norm, - communication_module=communication_module, - dataset=dataset, - counter=counter, - fingerprint=fingerprint, - logger=logger, - checkpoint_minute_interval=self._config.checkpoint_minute_interval, - checkpoint=self._config.checkpoint, - checkpoint_subpath=self._config.checkpoint_subpath, - learning_rate_scheduler_fn=self._config.learning_rate_scheduler_fn, - ) - - trainer = DetailedTrainerStatistics(trainer) # type:ignore - - return trainer diff --git a/mava/systems/tf/vdn/execution.py b/mava/systems/tf/vdn/execution.py deleted file mode 100644 index d798c8bff..000000000 --- a/mava/systems/tf/vdn/execution.py +++ /dev/null @@ -1,81 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""VDN system executor implementation.""" - -from typing import Dict, Optional - -import sonnet as snt -from acme.tf import variable_utils as tf2_variable_utils - -from mava import adders -from mava.components.tf.modules.communication import BaseCommunicationModule -from mava.systems.tf.madqn.execution import MADQNFeedForwardExecutor -from mava.systems.tf.madqn.training import MADQNTrainer - - -class VDNFeedForwardExecutor(MADQNFeedForwardExecutor): - """A feed-forward executor. - An executor based on a feed-forward policy for each agent in the system. - """ - - def __init__( - self, - q_networks: Dict[str, snt.Module], - action_selectors: Dict[str, snt.Module], - trainer: MADQNTrainer, - agent_net_keys: Dict[str, str], - adder: Optional[adders.ParallelAdder] = None, - variable_client: Optional[tf2_variable_utils.VariableClient] = None, - communication_module: Optional[BaseCommunicationModule] = None, - fingerprint: bool = False, - evaluator: bool = False, - interval: Optional[dict] = None, - ): - """Initialise the system executor - - Args: - q_networks (Dict[str, snt.Module]): q-value networks for each agent in the - system. - action_selectors (Dict[str, Any]): policy action selector method, e.g. - epsilon greedy. - trainer (MADQNTrainer, optional): system trainer. 
- agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - adder (Optional[adders.ParallelAdder], optional): adder which sends data - to a replay buffer. Defaults to None. - variable_client (Optional[tf2_variable_utils.VariableClient], optional): - client to copy weights from the trainer. Defaults to None. - communication_module (BaseCommunicationModule): module for enabling - communication protocols between agents. Defaults to None. - fingerprint (bool, optional): whether to use fingerprint stabilisation to - stabilise experience replay. Defaults to False. - evaluator (bool, optional): whether the executor will be used for - evaluation. Defaults to False. - interval: interval that evaluations are run at. - """ - - super(VDNFeedForwardExecutor, self).__init__( - q_networks=q_networks, - action_selectors=action_selectors, - agent_net_keys=agent_net_keys, - adder=adder, - variable_client=variable_client, - communication_module=communication_module, - fingerprint=fingerprint, - trainer=trainer, - evaluator=evaluator, - interval=interval, - ) diff --git a/mava/systems/tf/vdn/networks.py b/mava/systems/tf/vdn/networks.py deleted file mode 100644 index cf0c0023c..000000000 --- a/mava/systems/tf/vdn/networks.py +++ /dev/null @@ -1,66 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Dict, Mapping, Optional, Sequence, Union - -from acme import types - -from mava import specs as mava_specs -from mava.systems.tf.madqn.networks import ( - make_default_networks as make_default_networks_madqn, -) -from mava.utils.enums import ArchitectureType, Network - - -# Default networks for vdn -def make_default_networks( - environment_spec: mava_specs.MAEnvironmentSpec, - agent_net_keys: Dict[str, str], - policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (128, 128), - archecture_type: ArchitectureType = ArchitectureType.feedforward, - network_type: Network = Network.mlp, - fingerprints: bool = False, - seed: Optional[int] = None, -) -> Mapping[str, types.TensorTransformation]: - """Default networks for vdn. - - Args: - environment_spec (mava_specs.MAEnvironmentSpec): description of the action and - observation spaces etc. for each agent in the system. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - policy_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional): - size of policy networks. Defaults to (128,128). - archecture_type (ArchitectureType, optional): archecture used - for agent networks. Can be feedforward or recurrent. - Defaults to ArchitectureType.recurrent. - network_type (Network, optional): Agent network type. Can be mlp, - atari_dqn_network or coms_network. Defaults to Network.coms_network. - fingerprints (bool, optional): whether to apply replay stabilisation using - policy fingerprints. Defaults to False. - seed (int, optional): random seed for network initialization. 
- - Returns: - Mapping[str, types.TensorTransformation]: returned agent networks. - """ - - return make_default_networks_madqn( - environment_spec=environment_spec, - policy_networks_layer_sizes=policy_networks_layer_sizes, - agent_net_keys=agent_net_keys, - archecture_type=archecture_type, - network_type=network_type, - fingerprints=fingerprints, - seed=seed, - ) diff --git a/mava/systems/tf/vdn/system.py b/mava/systems/tf/vdn/system.py deleted file mode 100644 index fe9f55fbd..000000000 --- a/mava/systems/tf/vdn/system.py +++ /dev/null @@ -1,310 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""VDN system implementation.""" - -import functools -from typing import Any, Callable, Dict, Optional, Type, Union - -import dm_env -import reverb -import sonnet as snt -from acme import specs as acme_specs -from acme.utils import counting - -import mava -from mava import core -from mava import specs as mava_specs -from mava.components.tf.architectures import DecentralisedValueActor -from mava.components.tf.modules import mixing -from mava.environment_loop import ParallelEnvironmentLoop -from mava.systems.tf import executors -from mava.systems.tf.madqn.system import MADQN -from mava.systems.tf.vdn import builder, execution, training -from mava.types import EpsilonScheduler -from mava.utils.loggers import MavaLogger, logger_utils - - -# TODO Implement recurrent VDN -class VDN(MADQN): - """VDN system.""" - - def __init__( - self, - environment_factory: Callable[[bool], dm_env.Environment], - network_factory: Callable[[acme_specs.BoundedArray], Dict[str, snt.Module]], - exploration_scheduler_fn: Union[ - EpsilonScheduler, - Dict[str, EpsilonScheduler], - Dict[str, Dict[str, EpsilonScheduler]], - ], - logger_factory: Callable[[str], MavaLogger] = None, - architecture: Type[DecentralisedValueActor] = DecentralisedValueActor, - trainer_fn: Type[training.VDNTrainer] = training.VDNTrainer, - executor_fn: Type[core.Executor] = execution.VDNFeedForwardExecutor, - mixer: Type[mixing.BaseMixingModule] = mixing.AdditiveMixing, - num_executors: int = 1, - num_caches: int = 0, - environment_spec: mava_specs.MAEnvironmentSpec = None, - shared_weights: bool = True, - agent_net_keys: Dict[str, str] = {}, - batch_size: int = 256, - prefetch_size: int = 4, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: Optional[float] = 32.0, - n_step: int = 5, - sequence_length: int = 20, - importance_sampling_exponent: Optional[float] = None, - max_priority_weight: float = 0.9, - period: int = 20, - max_gradient_norm: float = None, - discount: float = 0.99, - optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]] = snt.optimizers.Adam( - learning_rate=1e-4 - ), - target_update_period: int = 100, - executor_variable_update_period: int = 1000, - max_executor_steps: int = None, - checkpoint: bool = True, - checkpoint_subpath: str = "~/mava/", - checkpoint_minute_interval: int = 5, - logger_config: Dict = {}, - 
train_loop_fn: Callable = ParallelEnvironmentLoop, - eval_loop_fn: Callable = ParallelEnvironmentLoop, - train_loop_fn_kwargs: Dict = {}, - eval_loop_fn_kwargs: Dict = {}, - evaluator_interval: Optional[dict] = None, - learning_rate_scheduler_fn: Optional[Callable[[int], None]] = None, - seed: Optional[int] = None, - ): - """Initialise the system - - Args: - environment_factory (Callable[[bool], dm_env.Environment]): function to - instantiate an environment. - network_factory (Callable[[acme_specs.BoundedArray], - Dict[str, snt.Module]]): function to instantiate system networks. - logger_factory (Callable[[str], MavaLogger], optional): function to - instantiate a system logger. Defaults to None. - architecture (Type[DecentralisedValueActor], optional): system architecture, - e.g. decentralised or centralised. Defaults to DecentralisedValueActor. - trainer_fn (Type[training.VDNTrainer], optional): training type associated - with executor and architecture, e.g. centralised training. Defaults - to training.VDNTrainer. - executor_fn (Type[core.Executor], optional): executor type, e.g. - feedforward or recurrent. Defaults to execution.VDNFeedForwardExecutor. - mixer (Type[mixing.BaseMixingModule], optional): mixer module type, e.g. - additive or monotonic mixing. Defaults to mixing.AdditiveMixing. - exploration_scheduler_fn (Type[ LinearExplorationScheduler ], optional): - function specifying a decaying scheduler for epsilon exploration. - See mava/systems/tf/madqn/system.py for details. - num_executors (int, optional): number of executor processes to run in - parallel. Defaults to 1. - num_caches (int, optional): number of trainer node caches. Defaults to 0. - environment_spec (mava_specs.MAEnvironmentSpec, optional): description of - the action, observation spaces etc. for each agent in the system. - Defaults to None. - shared_weights (bool, optional): whether agents should share weights or not. - When agent_net_keys are provided the value of shared_weights is ignored. - Defaults to True. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - batch_size (int, optional): sample batch size for updates. Defaults to 256. - prefetch_size (int, optional): size to prefetch from replay. Defaults to 4. - min_replay_size (int, optional): minimum replay size before updating. - Defaults to 1000. - max_replay_size (int, optional): maximum replay size. Defaults to 1000000. - samples_per_insert (Optional[float], optional): number of samples to take - from replay for every insert that is made. Defaults to 32.0. - n_step (int, optional): number of steps to include prior to boostrapping. - Defaults to 5. - sequence_length (int, optional): recurrent sequence rollout length. Defaults - to 20. - importance_sampling_exponent: (float): Not implemented yet. - max_priority_weight(float): Not implemented yet. - period (int, optional): The period with which we add sequences. See `period` - in `acme.SequenceAdder.period` for more info. Defaults to 20. - max_gradient_norm (float, optional): maximum allowed norm for gradients - before clipping is applied. Defaults to None. - discount (float, optional): discount factor to use for TD updates. Defaults - to 0.99. - optimizer (Union[snt.Optimizer, Dict[str, snt.Optimizer]], optional): - type of optimizer to use to update network parameters. Defaults to - snt.optimizers.Adam( learning_rate=1e-4 ). - target_update_period (int, optional): number of steps before target - networks are updated. Defaults to 100. 
- executor_variable_update_period (int, optional): number of steps before - updating executor variables from the variable source. Defaults to 1000. - max_executor_steps (int, optional): maximum number of steps and executor - can in an episode. Defaults to None. - checkpoint (bool, optional): whether to checkpoint models. Defaults to - False. - checkpoint_subpath (str, optional): subdirectory specifying where to store - checkpoints. Defaults to "~/mava/". - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. - logger_config (Dict, optional): additional configuration settings for the - logger factory. Defaults to {}. - train_loop_fn (Callable, optional): function to instantiate a train loop. - Defaults to ParallelEnvironmentLoop. - eval_loop_fn (Callable, optional): function to instantiate an evaluation - loop. Defaults to ParallelEnvironmentLoop. - train_loop_fn_kwargs (Dict, optional): possible keyword arguments to send - to the training loop. Defaults to {}. - eval_loop_fn_kwargs (Dict, optional): possible keyword arguments to send to - the evaluation loop. Defaults to {}. - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. - seed: seed for reproducible sampling (for epsilon greedy action selection). - evaluator_interval: An optional condition that is used to - evaluate/test system performance after [evaluator_interval] - condition has been met. If None, evaluation will - happen at every timestep. - E.g. to evaluate a system after every 100 executor episodes, - evaluator_interval = {"executor_episodes": 100}. - """ - - self._mixer = mixer - - # set default logger if no logger provided - if not logger_factory: - logger_factory = functools.partial( - logger_utils.make_logger, - directory="~/mava", - to_terminal=True, - time_delta=10, - ) - - super(VDN, self).__init__( - architecture=architecture, - environment_factory=environment_factory, - network_factory=network_factory, - logger_factory=logger_factory, - environment_spec=environment_spec, - shared_weights=shared_weights, - agent_net_keys=agent_net_keys, - num_executors=num_executors, - num_caches=num_caches, - max_executor_steps=max_executor_steps, - checkpoint_subpath=checkpoint_subpath, - checkpoint=checkpoint, - checkpoint_minute_interval=checkpoint_minute_interval, - train_loop_fn=train_loop_fn, - train_loop_fn_kwargs=train_loop_fn_kwargs, - eval_loop_fn=eval_loop_fn, - eval_loop_fn_kwargs=eval_loop_fn_kwargs, - logger_config=logger_config, - exploration_scheduler_fn=exploration_scheduler_fn, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, - seed=seed, - evaluator_interval=evaluator_interval, - ) - - if issubclass(executor_fn, executors.RecurrentExecutor): - extra_specs = self._get_extra_specs() - else: - extra_specs = {} - - self._builder = builder.VDNBuilder( - builder.VDNConfig( - environment_spec=self._environment_spec, - agent_net_keys=self._agent_net_keys, - discount=discount, - batch_size=batch_size, - prefetch_size=prefetch_size, - target_update_period=target_update_period, - executor_variable_update_period=executor_variable_update_period, - min_replay_size=min_replay_size, - max_replay_size=max_replay_size, - samples_per_insert=samples_per_insert, - n_step=n_step, - sequence_length=sequence_length, - importance_sampling_exponent=importance_sampling_exponent, - max_priority_weight=max_priority_weight, - period=period, - max_gradient_norm=max_gradient_norm, - checkpoint=checkpoint, - optimizer=optimizer, - 
checkpoint_subpath=checkpoint_subpath, - checkpoint_minute_interval=checkpoint_minute_interval, - evaluator_interval=evaluator_interval, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, - ), - trainer_fn=trainer_fn, - executor_fn=executor_fn, - extra_specs=extra_specs, - ) - - def trainer( - self, - replay: reverb.Client, - counter: counting.Counter, - ) -> mava.core.Trainer: - """System trainer - - Args: - replay (reverb.Client): replay data table to pull data from. - counter (counting.Counter): step counter object. - - Returns: - mava.core.Trainer: system trainer. - """ - - # Create the networks to optimize (online) - networks = self._network_factory( # type: ignore - environment_spec=self._environment_spec, - agent_net_keys=self._agent_net_keys, - ) - - # Create system architecture - architecture = self._architecture( - environment_spec=self._environment_spec, - value_networks=networks["q_networks"], - agent_net_keys=self._agent_net_keys, - ) - # Augment network architecture by adding mixing layer network. - system_networks = self._mixer( - architecture=architecture, - ).create_system() - - # create logger - trainer_logger_config = {} - if self._logger_config and "trainer" in self._logger_config: - trainer_logger_config = self._logger_config["trainer"] - trainer_logger = self._logger_factory( # type: ignore - "trainer", **trainer_logger_config - ) - - dataset = self._builder.make_dataset_iterator(replay) - counter = counting.Counter(counter, "trainer") - - return self._builder.make_trainer( - networks=system_networks, - dataset=dataset, - counter=counter, - logger=trainer_logger, - ) - - def build(self, name: str = "vdn") -> Any: - """Build the distributed system as a graph program. - - Args: - name (str, optional): system name. Defaults to "vdn". - - Returns: - Any: graph program for distributed system training. - """ - return super().build(name=name) diff --git a/mava/systems/tf/vdn/training.py b/mava/systems/tf/vdn/training.py deleted file mode 100644 index 9233440a1..000000000 --- a/mava/systems/tf/vdn/training.py +++ /dev/null @@ -1,245 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""VDN system trainer implementation.""" - -from typing import Any, Callable, Dict, List, Optional, Union - -import reverb -import sonnet as snt -import tensorflow as tf -from acme.utils import counting, loggers -from trfl.indexing_ops import batched_index - -from mava import types as mava_types -from mava.components.tf.modules.communication import BaseCommunicationModule -from mava.systems.tf.madqn.training import MADQNTrainer -from mava.utils import training_utils as train_utils - -train_utils.set_growing_gpu_memory() - - -class VDNTrainer(MADQNTrainer): - """VDN trainer. - This is the trainer component of a VDN system. i.e. it takes a dataset as input - and implements update functionality to learn from this dataset. 
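# --- Illustrative aside (not part of the patch): VDN's mixer is purely
# additive, so the joint value is the sum of the per-agent chosen Q-values and
# the TD target is formed against the summed target-network values. A minimal
# sketch under that assumption (names and shapes are illustrative, not the
# system's API):
import tensorflow as tf

def vdn_mix(chosen_q_values: tf.Tensor) -> tf.Tensor:
    """Additive mixing: [batch, num_agents] -> joint value [batch]."""
    return tf.reduce_sum(chosen_q_values, axis=-1)

def vdn_td_error(q_acts, q_targets, rewards, discounts):
    """TD error on the joint (summed) value; all inputs are batched tensors."""
    targets = tf.stop_gradient(rewards + discounts * vdn_mix(q_targets))
    return targets - vdn_mix(q_acts)

# Example with a batch of 2 transitions and 3 agents:
td = vdn_td_error(
    q_acts=tf.constant([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]),
    q_targets=tf.constant([[1.5, 1.5, 1.5], [1.0, 1.0, 1.0]]),
    rewards=tf.constant([1.0, 0.0]),
    discounts=tf.constant([0.99, 0.99]),
)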
- """ - - def __init__( - self, - agents: List[str], - agent_types: List[str], - q_networks: Dict[str, snt.Module], - target_q_networks: Dict[str, snt.Module], - mixing_network: snt.Module, - target_mixing_network: snt.Module, - target_update_period: int, - dataset: tf.data.Dataset, - optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], - discount: float, - agent_net_keys: Dict[str, str], - checkpoint_minute_interval: int, - communication_module: Optional[BaseCommunicationModule] = None, - max_gradient_norm: float = None, - counter: counting.Counter = None, - fingerprint: bool = False, - logger: loggers.Logger = None, - checkpoint: bool = True, - checkpoint_subpath: str = "~/mava/", - learning_rate_scheduler_fn: Optional[Callable[[int], None]] = None, - ) -> None: - """Initialise VDN trainer - - Args: - agents (List[str]): agent ids, e.g. "agent_0". - agent_types (List[str]): agent types, e.g. "speaker" or "listener". - q_networks (Dict[str, snt.Module]): q-value networks. - target_q_networks (Dict[str, snt.Module]): target q-value networks. - mixing_network (snt.Module): mixing networks learning factorised q-value - weights. - target_mixing_network (snt.Module): target mixing networks. - target_update_period (int): number of steps before updating target networks. - dataset (tf.data.Dataset): training dataset. - optimizer (Union[snt.Optimizer, Dict[str, snt.Optimizer]]): type of - optimizer for updating the parameters of the networks. - discount (float): discount factor for TD updates. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. - communication_module (BaseCommunicationModule): module for communication - between agents. Defaults to None. - max_gradient_norm (float, optional): maximum allowed norm for gradients - before clipping is applied. Defaults to None. - counter (counting.Counter, optional): step counter object. Defaults to None. - fingerprint (bool, optional): whether to apply replay stabilisation using - policy fingerprints. Defaults to False. - logger (loggers.Logger, optional): logger object for logging trainer - statistics. Defaults to None. - checkpoint (bool, optional): whether to checkpoint networks. Defaults to - True. - checkpoint_subpath (str, optional): subdirectory for storing checkpoints. - Defaults to "~/mava/". - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. - """ - - self._mixing_network = mixing_network - self._target_mixing_network = target_mixing_network - - super(VDNTrainer, self).__init__( - agents=agents, - agent_types=agent_types, - q_networks=q_networks, - target_q_networks=target_q_networks, - target_update_period=target_update_period, - dataset=dataset, - optimizer=optimizer, - discount=discount, - agent_net_keys=agent_net_keys, - checkpoint_minute_interval=checkpoint_minute_interval, - communication_module=communication_module, - max_gradient_norm=max_gradient_norm, - counter=counter, - fingerprint=fingerprint, - logger=logger, - checkpoint=checkpoint, - checkpoint_subpath=checkpoint_subpath, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, - ) - - @tf.function - def _step( - self, - ) -> Dict[str, Dict[str, Any]]: - """Trainer forward and backward passes.""" - - # Update the target networks - self._update_target_networks() - - # Get data from replay (dropping extras if any). 
Note there is no - # extra data here because we do not insert any into Reverb. - inputs = next(self._iterator) - - self._forward(inputs) - - self._backward() - - # Log losses per agent - return {agent: {"policy_loss": self.loss} for agent in self._agents} - - def _forward(self, inputs: reverb.ReplaySample) -> None: - """Trainer forward pass - - Args: - inputs (Any): input data from the data table (transitions) - """ - - # Unpack input data as follows: - # o_tm1 = dictionary of observations one for each agent - # a_tm1 = dictionary of actions taken from obs in o_tm1 - # e_tm1 [Optional] = extra data that the agents persist in replay. - # r_t = dictionary of rewards or rewards sequences - # (if using N step transitions) ensuing from actions a_tm1 - # d_t = environment discount ensuing from actions a_tm1. - # This discount is applied to future rewards after r_t. - # o_t = dictionary of next observations or next observation sequences - # e_t = [Optional] = extra data that the agents persist in replay. - trans = mava_types.Transition(*inputs.data) - - o_tm1, o_t, a_tm1, r_t, d_t, _, _ = ( - trans.observations, - trans.next_observations, - trans.actions, - trans.rewards, - trans.discounts, - trans.extras, - trans.next_extras, - ) - - # Do forward passes through the networks and calculate the losses - with tf.GradientTape(persistent=True) as tape: - q_acts = [] # Q vals - q_targets = [] # Target Q vals - for agent in self._agents: - agent_key = self._agent_net_keys[agent] - - o_tm1_feed, o_t_feed, a_tm1_feed = self._get_feed( - o_tm1, o_t, a_tm1, agent - ) - q_tm1 = self._q_networks[agent_key](o_tm1_feed) - q_t_value = self._target_q_networks[agent_key](o_t_feed) - q_t_selector = self._q_networks[agent_key](o_t_feed) - best_action = tf.argmax(q_t_selector, axis=1, output_type=tf.int32) - - # TODO Make use of q_t_selector for fingerprinting. Speak to Claude. - q_act = batched_index(q_tm1, a_tm1_feed, keepdims=True) # [B, 1] - q_target = batched_index( - q_t_value, best_action, keepdims=True - ) # [B, 1] - - q_acts.append(q_act) - q_targets.append(q_target) - - rewards = tf.concat( - [tf.reshape(val, (-1, 1)) for val in list(r_t.values())], axis=1 - ) - rewards = tf.reduce_mean(rewards, axis=1) # [B] - - pcont = tf.concat( - [tf.reshape(val, (-1, 1)) for val in list(d_t.values())], axis=1 - ) - pcont = tf.reduce_mean(pcont, axis=1) - discount = tf.cast(self._discount, list(d_t.values())[0].dtype) - pcont = discount * pcont # [B] - - q_acts = tf.concat(q_acts, axis=1) # [B, num_agents] - q_targets = tf.concat(q_targets, axis=1) # [B, num_agents] - - q_tot_mixed = self._mixing_network(q_acts) # [B, 1, 1] - q_tot_target_mixed = self._target_mixing_network(q_targets) # [B, 1, 1] - - # q_tot_mixed = tf.reduce_sum(q_acts, axis=1) # [B, 1, 1] - # q_tot_target_mixed = tf.reduce_sum(q_targets, axis=1) # [B, 1, 1] - - # Calculate Q loss. - targets = rewards + pcont * q_tot_target_mixed - targets = tf.stop_gradient(targets) - td_error = targets - q_tot_mixed - - # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error. - self.loss = 0.5 * tf.reduce_mean(tf.square(td_error)) - self.tape = tape - - def _backward(self) -> None: - """Trainer backward pass updating network parameters""" - - # Calculate the gradients and update the networks - for agent in self._agents: - agent_key = self._agent_net_keys[agent] - # Get trainable variables. - trainable_variables = self._q_networks[agent_key].trainable_variables - - # Compute gradients. 
- gradients = self.tape.gradient(self.loss, trainable_variables) - - # Clip gradients. - gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0] - - # Apply gradients. - self._optimizers[agent_key].apply(gradients, trainable_variables) - - # Delete the tape manually because of the persistent=True flag. - train_utils.safe_del(self, "tape") diff --git a/mava/utils/environments/flatland_utils.py b/mava/utils/environments/flatland_utils.py index 3b6dc20bf..b66d77714 100644 --- a/mava/utils/environments/flatland_utils.py +++ b/mava/utils/environments/flatland_utils.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Optional, Tuple, Union - -import numpy as np +from typing import Optional from mava.wrappers.flatland import FlatlandEnvWrapper +from mava.wrappers.env_preprocess_wrappers import ConcatAgentIdToObservation, ConcatPrevActionToObservation try: from flatland.envs.line_generators import sparse_line_generator @@ -33,8 +32,7 @@ except ModuleNotFoundError: pass - -def create_rail_env_with_tree_obs( +def _create_rail_env_with_tree_obs( n_agents: int = 5, x_dim: int = 30, y_dim: int = 30, @@ -47,7 +45,7 @@ def create_rail_env_with_tree_obs( malfunction_max_duration: int = 50, observation_max_path_depth: int = 30, observation_tree_depth: int = 2, -) -> "RailEnv": +) -> RailEnv: """Create a Flatland RailEnv with TreeObservation. Args: @@ -100,23 +98,49 @@ def create_rail_env_with_tree_obs( return rail_env -def flatland_env_factory( +def make_environment( + n_agents: int =10, + x_dim: int = 30, + y_dim: int = 30, + n_cities: int = 2, + max_rails_between_cities: int =2, + max_rails_in_city: int =3, + seed: int = 0, + malfunction_rate:float = 1/200, + malfunction_min_duration: int = 20, + malfunction_max_duration: int = 50, + observation_max_path_depth: int = 30, + observation_tree_depth: int = 2, + concat_prev_actions: bool = True, + concat_agent_id: bool = True, evaluation: bool = False, - env_config: Dict[str, Any] = {}, - preprocessor: Callable[ - [Any], Union[np.ndarray, Tuple[np.ndarray], Dict[str, np.ndarray]] - ] = None, - include_agent_info: bool = False, random_seed: Optional[int] = None, ) -> FlatlandEnvWrapper: """Loads a flatand environment and wraps it using the flatland wrapper""" del evaluation # since it has same behaviour for both train and eval - env = create_rail_env_with_tree_obs(**env_config) - wrapped_env = FlatlandEnvWrapper(env, preprocessor, include_agent_info) + env = _create_rail_env_with_tree_obs( + n_agents=n_agents, + x_dim=x_dim, + y_dim=y_dim, + n_cities=n_cities, + max_rails_between_cities=max_rails_between_cities, + max_rails_in_city=max_rails_in_city, + seed=random_seed, + malfunction_rate=malfunction_rate, + malfunction_min_duration=malfunction_min_duration, + malfunction_max_duration=malfunction_max_duration, + observation_max_path_depth=observation_max_path_depth, + observation_tree_depth=observation_tree_depth, + ) + + env = FlatlandEnvWrapper(env) - if random_seed and hasattr(wrapped_env, "seed"): - wrapped_env.seed(random_seed) + if concat_prev_actions: + env = ConcatPrevActionToObservation(env) + + if concat_agent_id: + env = ConcatAgentIdToObservation(env) - return wrapped_env + return env diff --git a/mava/utils/environments/smac_utils.py b/mava/utils/environments/smac_utils.py index c94d4e00e..66b80b569 100644 --- a/mava/utils/environments/smac_utils.py +++ b/mava/utils/environments/smac_utils.py @@ -18,7 +18,9 @@ def 
make_environment(map_name="3m", concat_prev_actions=True, concat_agent_id=True, evaluation = False, random_seed=None): env = StarCraft2Env(map_name=map_name, seed=random_seed) + env = SMACWrapper(env) + if concat_prev_actions: env = ConcatPrevActionToObservation(env) diff --git a/mava/wrappers/flatland.py b/mava/wrappers/flatland.py index 37e220594..5e32327ae 100644 --- a/mava/wrappers/flatland.py +++ b/mava/wrappers/flatland.py @@ -15,9 +15,7 @@ """Wraps a Flatland MARL environment to be used as a dm_env environment.""" - import types as tp -import typing from functools import partial from typing import Any, Callable, Dict, List, Sequence, Tuple, Union @@ -30,10 +28,8 @@ from flatland.envs.observations import GlobalObsForRailEnv, Node, TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import AgentRenderVariant, RenderTool - - _has_flatland = True + from flatland.envs.step_utils.states import TrainState except ModuleNotFoundError: - _has_flatland = False pass from gym.spaces import Discrete from gym.spaces.box import Box @@ -47,564 +43,571 @@ ) from mava.wrappers.env_wrappers import ParallelEnvWrapper -if _has_flatland: # noqa: C901 - - class FlatlandEnvWrapper(ParallelEnvWrapper): - """Environment wrapper for Flatland environments. - - All environments would require an observation preprocessor, except for - 'GlobalObsForRailEnv'. This is because flatland gives users the - flexibility of designing custom observation builders. 'TreeObsForRailEnv' - would use the normalize_observation function from the flatland baselines - if none is supplied. - - The supplied preprocessor should return either an array, tuple of arrays or - a dictionary of arrays for an observation input. - The obervation, for an agent, returned by this wrapper could consist of both - the agent observation and agent info. This is because flatland also provides - informationn about the agents at each step. This information include; - 'action_required', 'malfunction', 'speed', and 'status', and it can be appended - to the observation, by this wrapper, as an array. action_required is a boolean, - malfunction is an int denoting the number of steps for which the agent would - remain motionless, speed is a float and status can be any of the below; - - READY_TO_DEPART = 0 - ACTIVE = 1 - DONE = 2 - DONE_REMOVED = 3 - - This would be included in the observation if agent_info is set to True +class FlatlandEnvWrapper(ParallelEnvWrapper): + """Environment wrapper for Flatland environments. + All environments would require an observation preprocessor, except for + 'GlobalObsForRailEnv'. This is because flatland gives users the + flexibility of designing custom observation builders. 'TreeObsForRailEnv' + would use the normalize_observation function from the flatland baselines + if none is supplied. + The supplied preprocessor should return either an array, tuple of arrays or + a dictionary of arrays for an observation input. + The obervation, for an agent, returned by this wrapper could consist of both + the agent observation and agent info. This is because flatland also provides + informationn about the agents at each step. This information include; + 'action_required', 'malfunction', 'speed', and 'status', and it can be appended + to the observation, by this wrapper, as an array. 
action_required is a boolean, + malfunction is an int denoting the number of steps for which the agent would + remain motionless, speed is a float and status can be any of the below; + READY_TO_DEPART = 0 + ACTIVE = 1 + DONE = 2 + DONE_REMOVED = 3 + This would be included in the observation if agent_info is set to True + """ + + # Note: we don't inherit from base.EnvironmentWrapper because that class + # assumes that the wrapped environment is a dm_env.Environment. + def __init__( + self, + environment: RailEnv, + preprocessor: Callable[ + [Any], Union[np.ndarray, Tuple[np.ndarray], Dict[str, np.ndarray]] + ] = None, + agent_info: bool = False, + ): + """Wrap Flatland environment. + Args: + environment: underlying RailEnv + preprocessor: optional preprocessor. Defaults to None. + agent_info: include agent info. Defaults to True. """ + self._environment = environment + decorate_step_method(self._environment) + + self._agents = [get_agent_id(i) for i in range(self.num_agents)] + self._possible_agents = self.agents[:] + + self._reset_next_step = True + self._step_type = dm_env.StepType.FIRST + self.num_actions = 5 + + self.action_spaces = { + agent: Discrete(self.num_actions) for agent in self.possible_agents + } + + # preprocessor must be for observation builders other than global obs + # treeobs builders would use the default preprocessor if none is + # supplied + self.preprocessor: Callable[ + [Dict[int, Any]], Dict[int, Any] + ] = self._obtain_preprocessor(preprocessor) + + self._include_agent_info = agent_info + + # observation space: + # flatland defines no observation space for an agent. Here we try + # to define the observation space. All agents are identical and would + # have the same observation space. + # Infer observation space based on returned observation + obs, _ = self._environment.reset() + obs = self.preprocessor(obs) + self.observation_spaces = { + get_agent_id(i): infer_observation_space(ob) for i, ob in obs.items() + } + + self._env_renderer = RenderTool( + self._environment, + agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, + show_debug=False, + screen_height=800, # Adjust these parameters to fit your resolution + screen_width=800, + ) # Adjust these parameters to fit your resolution + + @property + def agents(self) -> List[str]: + """Return list of active agents.""" + return self._agents + + @property + def possible_agents(self) -> List[str]: + """Return list of all possible agents.""" + return self._possible_agents + + def _update_stats(self, info, rewards): + episode_return = sum(list(rewards.values())) + tasks_finished = sum( + [1 if state == TrainState.DONE else 0 for state in info["state"].values()] + ) + completion = tasks_finished / len(self._agents) + normalized_score = episode_return / ( + self._environment._max_episode_steps * len(self._agents) + ) - # Note: we don't inherit from base.EnvironmentWrapper because that class - # assumes that the wrapped environment is a dm_env.Environment. - def __init__( - self, - environment: RailEnv, - preprocessor: Callable[ - [Any], Union[np.ndarray, Tuple[np.ndarray], Dict[str, np.ndarray]] - ] = None, - agent_info: bool = True, - ): - """Wrap Flatland environment. - - Args: - environment: underlying RailEnv - preprocessor: optional preprocessor. Defaults to None. - agent_info: include agent info. Defaults to True. 
- """ - self._environment = environment - decorate_step_method(self._environment) - - self._agents = [get_agent_id(i) for i in range(self.num_agents)] - self._possible_agents = self.agents[:] - - self._reset_next_step = True - self._step_type = dm_env.StepType.FIRST - self.num_actions = 5 - - self.action_spaces = { - agent: Discrete(self.num_actions) for agent in self.possible_agents - } + self._latest_score = normalized_score + self._latest_completion = completion - # preprocessor must be for observation builders other than global obs - # treeobs builders would use the default preprocessor if none is - # supplied - self.preprocessor: Callable[ - [Dict[int, Any]], Dict[int, Any] - ] = self._obtain_preprocessor(preprocessor) - - self._include_agent_info = agent_info - - # observation space: - # flatland defines no observation space for an agent. Here we try - # to define the observation space. All agents are identical and would - # have the same observation space. - # Infer observation space based on returned observation - obs, _ = self._environment.reset() - obs = self.preprocessor(obs) - self.observation_spaces = { - get_agent_id(i): infer_observation_space(ob) for i, ob in obs.items() + def get_stats(self): + if self._latest_completion is not None and self._latest_score is not None: + return { + "score": self._latest_score, + "completion": self._latest_completion, } + else: + return {} - self._env_renderer = RenderTool( - self._environment, - agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, - show_debug=False, - screen_height=600, # Adjust these parameters to fit your resolution - screen_width=800, - ) # Adjust these parameters to fit your resolution - - @property - def agents(self) -> List[str]: - """Return list of active agents.""" - return self._agents - - @property - def possible_agents(self) -> List[str]: - """Return list of all possible agents.""" - return self._possible_agents - - def render(self, mode: str = "human") -> np.ndarray: - """Renders the environment.""" - if mode == "human": - show = True - else: - show = False - - return self._env_renderer.render_env( - show=show, - show_observations=False, - show_predictions=False, - return_image=True, - ) - - def env_done(self) -> bool: - """Checks if the environment is done.""" - return self._environment.dones["__all__"] or not self.agents - - def reset(self) -> dm_env.TimeStep: - """Resets the episode.""" - # Reset the rendering sytem - self._env_renderer.reset() + def render(self, mode: str = "human") -> np.array: + """Renders the environment.""" + if mode == "human": + show = True + else: + show = False - self._reset_next_step = False - self._agents = self.possible_agents[:] + return self._env_renderer.render_env( + show=show, + show_observations=False, + show_predictions=False, + return_image=True, + ) - observe, info = self._environment.reset() - observations = self._create_observations( - observe, info, self._environment.dones - ) - rewards_spec = self.reward_spec() + def env_done(self) -> bool: + """Checks if the environment is done.""" + return self._environment.dones["__all__"] or not self.agents + + def reset(self) -> dm_env.TimeStep: + """Resets the episode.""" + # Reset the rendering sytem + self._env_renderer.reset() + + self._reset_next_step = False + self._agents = self.possible_agents[:] + self._discounts = { + agent: np.dtype("float32").type(1.0) for agent in self.agents + } + observe, info = self._environment.reset() + observations = self._create_observations(observe, info, self._environment.dones) + 
rewards_spec = self.reward_spec() + rewards = { + agent: convert_np_type(rewards_spec[agent].dtype, 0) + for agent in self.possible_agents + } + + discount_spec = self.discount_spec() + self._discounts = { + agent: convert_np_type(discount_spec[agent].dtype, 1) + for agent in self.possible_agents + } + return parameterized_restart(rewards, self._discounts, observations), {} + + def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: + """Steps the environment.""" + self._pre_step() + + if self._reset_next_step: + return self.reset() + + self._agents = [ + agent + for agent in self.agents + if not self._environment.dones[get_agent_handle(agent)] + ] + + observations, rewards, dones, infos = self._environment.step(actions) + + rewards_spec = self.reward_spec() + # Handle empty rewards + if not rewards: rewards = { agent: convert_np_type(rewards_spec[agent].dtype, 0) for agent in self.possible_agents } + else: + rewards = { + get_agent_id(agent): convert_np_type( + rewards_spec[get_agent_id(agent)].dtype, reward + ) + for agent, reward in rewards.items() + } + + if observations: + observations = self._create_observations(observations, infos, dones) - discount_spec = self.discount_spec() - self._discounts = { - agent: convert_np_type(discount_spec[agent].dtype, 1) + if self.env_done(): + self._step_type = dm_env.StepType.LAST + self._reset_next_step = True + discounts = { + agent: convert_np_type( + self.discount_spec()[agent].dtype, 0 + ) # Zero discount on final step for agent in self.possible_agents } - return parameterized_restart(rewards, self._discounts, observations) - - def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: - """Steps the environment.""" - self._pre_step() - - if self._reset_next_step: - return self.reset() - - self._agents = [ - agent - for agent in self.agents - if not self._environment.dones[get_agent_handle(agent)] - ] - - observations, rewards, dones, infos = self._environment.step(actions) - - rewards_spec = self.reward_spec() - # Handle empty rewards - if not rewards: - rewards = { - agent: convert_np_type(rewards_spec[agent].dtype, 0) - for agent in self.possible_agents - } - else: - rewards = { - get_agent_id(agent): convert_np_type( - rewards_spec[get_agent_id(agent)].dtype, reward - ) - for agent, reward in rewards.items() - } - - if observations: - observations = self._create_observations(observations, infos, dones) - - if self.env_done(): - self._step_type = dm_env.StepType.LAST - self._reset_next_step = True - - # Zero discount when env done - discounts = { - agent: convert_np_type( - self.discount_spec()[agent].dtype, 0 - ) # Zero discount on final step - for agent in self.possible_agents - } - else: - self._step_type = dm_env.StepType.MID - discounts = self._discounts - - return dm_env.TimeStep( - observation=observations, - reward=rewards, - discount=discounts, - step_type=self._step_type, - ) + self._update_stats(infos, rewards) + # TODO (Claude) zero discount! + else: + self._step_type = dm_env.StepType.MID + discounts = self._discounts # discount == 1 + + return dm_env.TimeStep( + observation=observations, + reward=rewards, + discount=discounts, + step_type=self._step_type, + ), {} + + # Convert Flatland observation so it's dm_env compatible. Also, the list + # of legal actions must be converted to a legal actions mask. 
+ def _convert_observations( + self, observes: Dict[str, Tuple[np.array, np.ndarray]], dones: Dict[str, bool] + ) -> Observation: + return convert_dm_compatible_observations( + observes, + dones, + self.observation_spec(), + self.env_done(), + self.possible_agents, + ) - # Convert Flatland observation so it's dm_env compatible. Also, the list - # of legal actions must be converted to a legal actions mask. - def _convert_observations( - self, - observes: Dict[str, Tuple[np.ndarray, np.ndarray]], - dones: Dict[str, bool], - ) -> Observation: - return convert_dm_compatible_observations( - observes, # type: ignore - dones, - self.observation_spec(), - self.env_done(), - self.possible_agents, + # collate agent info and observation into a tuple, making the agents obervation to + # be a tuple of the observation from the env and the agent info + def _collate_obs_and_info( + self, observes: Dict[int, np.ndarray], info: Dict[str, Dict[int, Any]] + ) -> Dict[str, Tuple[np.array, np.ndarray]]: + observations: Dict[str, Tuple[np.array, np.ndarray]] = {} + observes = self.preprocessor(observes) + for agent, obs in observes.items(): + agent_id = get_agent_id(agent) + agent_info = np.array( + [info[k][agent] for k in sort_str_num(info.keys())], dtype=np.float32 ) - - # collate agent info and observation into a tuple, making the agents obervation - # to be a tuple of the observation from the env and the agent info - def _collate_obs_and_info( - self, observes: Dict[int, np.ndarray], info: Dict[str, Dict[int, Any]] - ) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: - observations: Dict[str, Tuple[np.ndarray, np.ndarray]] = {} - observes = self.preprocessor(observes) - for agent, obs in observes.items(): - agent_id = get_agent_id(agent) - agent_info = np.array( - [info[k][agent] for k in sort_str_num(info.keys())], - dtype=np.float32, - ) - obs = (obs, agent_info) if self._include_agent_info else obs # type: ignore # noqa: E501 - observations[agent_id] = obs # type: ignore - - return observations - - def _create_observations( - self, - obs: Dict[int, np.ndarray], - info: Dict[str, Dict[int, Any]], - dones: Dict[int, bool], - ) -> Observation: - """Convert observation.""" - observations_ = self._collate_obs_and_info(obs, info) - dones_ = {get_agent_id(k): v for k, v in dones.items()} - observations = self._convert_observations(observations_, dones_) - return observations - - def _obtain_preprocessor( - self, preprocessor: Any - ) -> Callable[[Dict[int, Any]], Dict[int, np.ndarray]]: - """Obtains the actual preprocessor. 
- - Obtains the actual preprocessor to be used based on the supplied - preprocessor and the env's obs_builder object - """ - if not isinstance(self.obs_builder, GlobalObsForRailEnv): - _preprocessor = preprocessor if preprocessor else lambda x: x - if isinstance(self.obs_builder, TreeObsForRailEnv): - _preprocessor = ( - partial( - normalize_observation, tree_depth=self.obs_builder.max_depth - ) - if not preprocessor - else preprocessor - ) - assert _preprocessor is not None - else: - - def _preprocessor( - x: Tuple[np.ndarray, np.ndarray, np.ndarray] - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - return x - - def returned_preprocessor(obs: Dict[int, Any]) -> Dict[int, np.ndarray]: - temp_obs = {} - for agent_id, ob in obs.items(): - temp_obs[agent_id] = _preprocessor(ob) - return temp_obs - - return returned_preprocessor - - # set all parameters that should be available before an environment step - # if no available agent, then environment is done and should be reset - def _pre_step(self) -> None: - if not self.agents: - self._step_type = dm_env.StepType.LAST - - def observation_spec(self) -> Dict[str, OLT]: - """Return observation spec.""" - observation_specs = {} - for agent in self.agents: - observation_specs[agent] = OLT( - observation=tuple( - ( - _convert_to_spec(self.observation_spaces[agent]), - agent_info_spec(), - ) + obs = (obs, agent_info) if self._include_agent_info else obs + observations[agent_id] = obs + + return observations + + def _create_observations( + self, + obs: Dict[int, np.ndarray], + info: Dict[str, Dict[int, Any]], + dones: Dict[int, bool], + ) -> Observation: + """Convert observation.""" + observations_ = self._collate_obs_and_info(obs, info) + dones_ = {get_agent_id(k): v for k, v in dones.items()} + observations = self._convert_observations(observations_, dones_) + return observations + + def _obtain_preprocessor( + self, preprocessor: Any + ) -> Callable[[Dict[int, Any]], Dict[int, np.ndarray]]: + """Obtains the actual preprocessor. 
+ Obtains the actual preprocessor to be used based on the supplied + preprocessor and the env's obs_builder object + """ + if not isinstance(self.obs_builder, GlobalObsForRailEnv): + _preprocessor = preprocessor if preprocessor else lambda x: x + if isinstance(self.obs_builder, TreeObsForRailEnv): + _preprocessor = ( + partial( + normalize_observation, tree_depth=self.obs_builder.max_depth ) - if self._include_agent_info - else _convert_to_spec(self.observation_spaces[agent]), - legal_actions=_convert_to_spec(self.action_spaces[agent]), - terminal=specs.Array((1,), np.float32), - ) - return observation_specs - - def action_spec( - self, - ) -> Dict[str, Union[specs.DiscreteArray, specs.BoundedArray]]: - """Get action spec.""" - action_specs = {} - action_spaces = self.action_spaces - for agent in self.possible_agents: - action_specs[agent] = _convert_to_spec(action_spaces[agent]) - return action_specs - - def reward_spec(self) -> Dict[str, specs.Array]: - """Get the reward spec.""" - reward_specs = {} - for agent in self.possible_agents: - reward_specs[agent] = specs.Array((), np.float32) - return reward_specs - - def discount_spec(self) -> Dict[str, specs.BoundedArray]: - """Get the discount spec.""" - discount_specs = {} - for agent in self.possible_agents: - discount_specs[agent] = specs.BoundedArray( - (), np.float32, minimum=0, maximum=1.0 + if not preprocessor + else preprocessor ) - return discount_specs - - def extra_spec(self) -> Dict[str, specs.BoundedArray]: - """Get the extras spec.""" - return {} + assert _preprocessor is not None + else: - def seed(self, seed: int = None) -> None: - """Seed the environment.""" - self._environment._seed(seed) - - @property - def environment(self) -> RailEnv: - """Returns the wrapped environment.""" - return self._environment - - @property - def num_agents(self) -> int: - """Returns the number of trains/agents in the flatland environment""" - return int(self._environment.number_of_agents) - - def __getattr__(self, name: str) -> Any: - """Expose any other attributes of the underlying environment.""" - return getattr(self._environment, name) - - # Utility functions - - def infer_observation_space( - obs: Union[tuple, np.ndarray, dict] - ) -> Union[Box, tuple, dict]: - """Infer a gym Observation space from a sample observation from flatland""" - if isinstance(obs, np.ndarray): - return Box( - -np.inf, - np.inf, - shape=obs.shape, - dtype=obs.dtype, + def _preprocessor( + x: Tuple[np.ndarray, np.ndarray, np.ndarray] + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + return x + + def returned_preprocessor(obs: Dict[int, Any]) -> Dict[int, np.ndarray]: + temp_obs = {} + for agent_id, ob in obs.items(): + temp_obs[agent_id] = _preprocessor(ob) + return temp_obs + + return returned_preprocessor + + # set all parameters that should be available before an environment step + # if no available agent, then environment is done and should be reset + def _pre_step(self) -> None: + if not self.agents: + self._step_type = dm_env.StepType.LAST + + def observation_spec(self) -> Dict[str, OLT]: + """Return observation spec.""" + observation_specs = {} + for agent in self.agents: + observation_specs[agent] = OLT( + observation=tuple( + ( + _convert_to_spec(self.observation_spaces[agent]), + agent_info_spec(), + ) + ) + if self._include_agent_info + else _convert_to_spec(self.observation_spaces[agent]), + legal_actions=_convert_to_spec(self.action_spaces[agent]), + terminal=specs.Array((1,), np.float32), ) - elif isinstance(obs, tuple): - return 
tuple(infer_observation_space(o) for o in obs) - elif isinstance(obs, dict): - return {key: infer_observation_space(value) for key, value in obs.items()} - else: - raise ValueError( - f"Unexpected observation type: {type(obs)}. " - f"Observation should be of either of this types " - f"(np.ndarray, tuple, or dict)" + return observation_specs + + def action_spec(self) -> Dict[str, Union[specs.DiscreteArray, specs.BoundedArray]]: + """Get action spec.""" + action_specs = {} + action_spaces = self.action_spaces + for agent in self.possible_agents: + action_specs[agent] = _convert_to_spec(action_spaces[agent]) + return action_specs + + def reward_spec(self) -> Dict[str, specs.Array]: + """Get the reward spec.""" + reward_specs = {} + for agent in self.possible_agents: + reward_specs[agent] = specs.Array((), np.float32) + return reward_specs + + def discount_spec(self) -> Dict[str, specs.BoundedArray]: + """Get the discount spec.""" + discount_specs = {} + for agent in self.possible_agents: + discount_specs[agent] = specs.BoundedArray( + (), np.float32, minimum=0, maximum=1.0 ) + return discount_specs + + def extra_spec(self) -> Dict[str, specs.BoundedArray]: + """Get the extras spec.""" + return {} + + def seed(self, seed: int = None) -> None: + """Seed the environment.""" + self._environment._seed(seed) + + @property + def environment(self) -> RailEnv: + """Returns the wrapped environment.""" + return self._environment + + @property + def num_agents(self) -> int: + """Returns the number of trains/agents in the flatland environment""" + print(self._environment.number_of_agents) + return int(self._environment.number_of_agents) + + def __getattr__(self, name: str) -> Any: + """Expose any other attributes of the underlying environment.""" + return getattr(self._environment, name) + + +# Utility functions + + +def infer_observation_space( + obs: Union[tuple, np.ndarray, dict] +) -> Union[Box, tuple, dict]: + """Infer a gym Observation space from a sample observation from flatland""" + if isinstance(obs, np.ndarray): + return Box(-np.inf, np.inf, shape=obs.shape, dtype=obs.dtype,) + elif isinstance(obs, tuple): + return tuple(infer_observation_space(o) for o in obs) + elif isinstance(obs, dict): + return {key: infer_observation_space(value) for key, value in obs.items()} + else: + raise ValueError( + f"Unexpected observation type: {type(obs)}. " + f"Observation should be of either of this types " + f"(np.ndarray, tuple, or dict)" + ) - def agent_info_spec() -> specs.BoundedArray: - """Create the spec for the agent_info part of the observation""" - return specs.BoundedArray((4,), dtype=np.float32, minimum=0.0, maximum=10) - - def get_agent_id(handle: int) -> str: - """Obtain the string that constitutes the agent id from an agent handle""" - return f"train_{handle}" - - def get_agent_handle(id: str) -> int: - """Obtain an agents handle given its id""" - return int(id.split("_")[-1]) - - def decorate_step_method(env: RailEnv) -> None: - """Step method decorator. - - Enable the step method of the env to take action dictionaries where agent keys - are the agent ids. Flatland uses the agent handles as keys instead. This - function decorates the step method so that it accepts an action dict where - the keys are the agent ids. 
- """ - env.step_ = env.step - - def _step( - self: RailEnv, actions: Dict[str, Union[int, float, Any]] - ) -> dm_env.TimeStep: - actions_ = {get_agent_handle(k): int(v) for k, v in actions.items()} - return self.step_(actions_) - - env.step = tp.MethodType(_step, env) - - # The block of code below is obtained from the flatland starter-kit - # at https://gitlab.aicrowd.com/flatland/flatland-starter-kit/-/blob/master/ - # utils/observation_utils.py - # this is done just to obtain the normalize_observation function that would - # serve as the default preprocessor for the Tree obs builder. - def max_lt(seq: Sequence, val: Any) -> Any: - """Get max in sequence. +def agent_info_spec() -> specs.BoundedArray: + """Create the spec for the agent_info part of the observation""" + return specs.BoundedArray((4,), dtype=np.float32, minimum=0.0, maximum=10) + + +def get_agent_id(handle: int) -> str: + """Obtain the string that constitutes the agent id from an agent handle - an int""" + return f"train_{handle}" + + +def get_agent_handle(id: str) -> int: + """Obtain an agents handle given its id""" + return int(id.split("_")[-1]) + + +def decorate_step_method(env: RailEnv) -> None: + """Step method decorator. + Enable the step method of the env to take action dictionaries where agent keys + are the agent ids. Flatland uses the agent handles as keys instead. This function + decorates the step method so that it accepts an action dict where the keys are the + agent ids. + """ + env.step_ = env.step + + def _step( + self: RailEnv, actions: Dict[str, Union[int, float, Any]] + ) -> dm_env.TimeStep: + actions_ = {get_agent_handle(k): int(v) for k, v in actions.items()} + return self.step_(actions_) + + env.step = tp.MethodType(_step, env) + + +# The block of code below is obtained from the flatland starter-kit +# at https://gitlab.aicrowd.com/flatland/flatland-starter-kit/-/blob/master/ +# utils/observation_utils.py +# this is done just to obtain the normalize_observation function that would +# serve as the default preprocessor for the Tree obs builder. + + +def max_lt(seq: Sequence, val: Any) -> Any: + """Get max in sequence. + Return greatest item in seq for which item < val applies. + None is returned if seq was empty or all items in seq were >= val. + """ + max = 0 + idx = len(seq) - 1 + while idx >= 0: + if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max: + max = seq[idx] + idx -= 1 + return max + + +def min_gt(seq: Sequence, val: Any) -> Any: + """Gets min in a sequence. + Return smallest item in seq for which item > val applies. + None is returned if seq was empty or all items in seq were >= val. + """ + min = np.inf + idx = len(seq) - 1 + while idx >= 0: + if seq[idx] >= val and seq[idx] < min: + min = seq[idx] + idx -= 1 + return min + + +def norm_obs_clip( + obs: np.ndarray, + clip_min: int = -1, + clip_max: int = 1, + fixed_radius: int = 0, + normalize_to_range: bool = False, +) -> np.ndarray: + """Normalize observation. 
+ This function returns the difference between min and max value of an observation + :param obs: Observation that should be normalized + :param clip_min: min value where observation will be clipped + :param clip_max: max value where observation will be clipped + :return: returnes normalized and clipped observatoin + """ + if fixed_radius > 0: + max_obs = fixed_radius + else: + max_obs = max(1, max_lt(obs, 1000)) + 1 + + min_obs = 0 # min(max_obs, min_gt(obs, 0)) + if normalize_to_range: + min_obs = min_gt(obs, 0) + if min_obs > max_obs: + min_obs = max_obs + if max_obs == min_obs: + return np.clip(np.array(obs) / max_obs, clip_min, clip_max) + norm = np.abs(max_obs - min_obs) + return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max) + + +def _split_node_into_feature_groups( + node: Node, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Splits node into features.""" + data = np.zeros(6) + distance = np.zeros(1) + agent_data = np.zeros(4) + + data[0] = node.dist_own_target_encountered + data[1] = node.dist_other_target_encountered + data[2] = node.dist_other_agent_encountered + data[3] = node.dist_potential_conflict + data[4] = node.dist_unusable_switch + data[5] = node.dist_to_next_branch + + distance[0] = node.dist_min_to_target + + agent_data[0] = node.num_agents_same_direction + agent_data[1] = node.num_agents_opposite_direction + agent_data[2] = node.num_agents_malfunctioning + agent_data[3] = node.speed_min_fractional + + return data, distance, agent_data + + +def _split_subtree_into_feature_groups( + node: Node, current_tree_depth: int, max_tree_depth: int +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Split subtree.""" + if node == -np.inf: + remaining_depth = max_tree_depth - current_tree_depth + # reference: + # https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure + num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1)) + return ( + [-np.inf] * num_remaining_nodes * 6, + [-np.inf] * num_remaining_nodes, + [-np.inf] * num_remaining_nodes * 4, + ) - Return greatest item in seq for which item < val applies. - None is returned if seq was empty or all items in seq were >= val. - """ - max = 0 - idx = len(seq) - 1 - while idx >= 0: - if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max: - max = seq[idx] - idx -= 1 - return max - - def min_gt(seq: Sequence, val: Any) -> Any: - """Gets min in a sequence. - - Return smallest item in seq for which item > val applies. - None is returned if seq was empty or all items in seq were >= val. - """ - min = np.inf - idx = len(seq) - 1 - while idx >= 0: - if seq[idx] >= val and seq[idx] < min: - min = seq[idx] - idx -= 1 - return min - - @typing.no_type_check - def norm_obs_clip( - obs: np.ndarray, - clip_min: int = -1, - clip_max: int = 1, - fixed_radius: int = 0, - normalize_to_range: bool = False, - ) -> np.ndarray: - """Normalize observation. 
- - This function returns the difference between min and max value of an observation - :param obs: Observation that should be normalized - :param clip_min: min value where observation will be clipped - :param clip_max: max value where observation will be clipped - :return: returnes normalized and clipped observatoin - """ - if fixed_radius > 0: - max_obs = fixed_radius - else: - max_obs = max(1, max_lt(obs, 1000)) + 1 - - min_obs = 0 # min(max_obs, min_gt(obs, 0)) - if normalize_to_range: - min_obs = min_gt(obs, 0) - if min_obs > max_obs: - min_obs = max_obs - if max_obs == min_obs: - return np.clip(np.array(obs) / max_obs, clip_min, clip_max) - norm = np.abs(max_obs - min_obs) - return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max) - - def _split_node_into_feature_groups( - node: Node, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Splits node into features.""" - data = np.zeros(6) - distance = np.zeros(1) - agent_data = np.zeros(4) - - data[0] = node.dist_own_target_encountered - data[1] = node.dist_other_target_encountered - data[2] = node.dist_other_agent_encountered - data[3] = node.dist_potential_conflict - data[4] = node.dist_unusable_switch - data[5] = node.dist_to_next_branch - - distance[0] = node.dist_min_to_target - - agent_data[0] = node.num_agents_same_direction - agent_data[1] = node.num_agents_opposite_direction - agent_data[2] = node.num_agents_malfunctioning - agent_data[3] = node.speed_min_fractional + data, distance, agent_data = _split_node_into_feature_groups(node) + if not node.childs: return data, distance, agent_data - @typing.no_type_check - def _split_subtree_into_feature_groups( - node: Node, current_tree_depth: int, max_tree_depth: int - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Split subtree.""" - if node == -np.inf: - remaining_depth = max_tree_depth - current_tree_depth - # reference: - # https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure - num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1)) - return ( - [-np.inf] * num_remaining_nodes * 6, - [-np.inf] * num_remaining_nodes, - [-np.inf] * num_remaining_nodes * 4, - ) - - data, distance, agent_data = _split_node_into_feature_groups(node) - - if not node.childs: - return data, distance, agent_data + for direction in TreeObsForRailEnv.tree_explored_actions_char: + sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups( + node.childs[direction], current_tree_depth + 1, max_tree_depth + ) + data = np.concatenate((data, sub_data)) + distance = np.concatenate((distance, sub_distance)) + agent_data = np.concatenate((agent_data, sub_agent_data)) - for direction in TreeObsForRailEnv.tree_explored_actions_char: - sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups( - node.childs[direction], current_tree_depth + 1, max_tree_depth - ) - data = np.concatenate((data, sub_data)) - distance = np.concatenate((distance, sub_distance)) - agent_data = np.concatenate((agent_data, sub_agent_data)) + return data, distance, agent_data - return data, distance, agent_data - def split_tree_into_feature_groups( - tree: Node, max_tree_depth: int - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """This function splits the tree into three difference arrays.""" - data, distance, agent_data = _split_node_into_feature_groups(tree) +def split_tree_into_feature_groups( + tree: Node, max_tree_depth: int +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """This function splits the tree into three difference 
arrays.""" + data, distance, agent_data = _split_node_into_feature_groups(tree) - for direction in TreeObsForRailEnv.tree_explored_actions_char: - sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups( - tree.childs[direction], 1, max_tree_depth - ) - data = np.concatenate((data, sub_data)) # type: ignore - distance = np.concatenate((distance, sub_distance)) # type: ignore - agent_data = np.concatenate((agent_data, sub_agent_data)) # type: ignore + for direction in TreeObsForRailEnv.tree_explored_actions_char: + sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups( + tree.childs[direction], 1, max_tree_depth + ) + data = np.concatenate((data, sub_data)) + distance = np.concatenate((distance, sub_distance)) + agent_data = np.concatenate((agent_data, sub_agent_data)) - return data, distance, agent_data + return data, distance, agent_data - def normalize_observation( - observation: Node, tree_depth: int, observation_radius: int = 0 - ) -> np.ndarray: - """This function normalizes the observation used by the RL algorithm.""" - if observation is None: - return np.zeros( - 11 * sum(np.power(4, i) for i in range(tree_depth + 1)), - dtype=np.float32, - ) - data, distance, agent_data = split_tree_into_feature_groups( - observation, tree_depth - ) - data = norm_obs_clip(data, fixed_radius=observation_radius) - distance = norm_obs_clip(distance, normalize_to_range=True) - agent_data = np.clip(agent_data, -1, 1) - normalized_obs = np.array( - np.concatenate( - (np.concatenate((data, distance)), agent_data) - ), # type:ignore - dtype=np.float32, +def normalize_observation( + observation: Node, tree_depth: int, observation_radius: int = 0 +) -> np.ndarray: + """This function normalizes the observation used by the RL algorithm.""" + if observation is None: + return np.zeros( + 11 * sum(np.power(4, i) for i in range(tree_depth + 1)), dtype=np.float32 ) - return normalized_obs + data, distance, agent_data = split_tree_into_feature_groups(observation, tree_depth) + + data = norm_obs_clip(data, fixed_radius=observation_radius) + distance = norm_obs_clip(distance, normalize_to_range=True) + agent_data = np.clip(agent_data, -1, 1) + normalized_obs = np.array( + np.concatenate((np.concatenate((data, distance)), agent_data)), dtype=np.float32 + ) + return normalized_obs \ No newline at end of file From 9dded6ab2344c03e1501b07a91bce3f95fe73365 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 26 Jan 2022 15:18:18 +0200 Subject: [PATCH 14/56] Code clean up. 
--- .../feedforward/decentralised/run_madqn.py | 1 - .../run_madqn_configurable_epsilon.py | 1 - .../run_madqn_custom_lr_schedule.py | 1 - .../decentralised/run_madqn_lr_schedule.py | 1 - ...un_qmix.py => run_madqn_scale_trainers.py} | 29 +- .../feedforward/decentralised/run_vdn.py | 108 ------ .../run_dial.py => centralised/run_vdn.py} | 41 +-- .../recurrent/decentralised/run_madqn.py | 7 +- ...th_coms.py => run_madqn_scale_trainers.py} | 26 +- .../feedforward/decentralised/run_madqn.py | 11 +- .../flatland/recurrent/centralised/run_vdn.py | 119 ++++++ .../recurrent/decentralised/run_madqn.py | 22 +- .../decentralised/run_madqn.py | 21 +- .../smac/feedforward/decentralised/test.py | 26 -- .../run_qmix.py | 20 +- .../{decentralised => centralised}/run_vdn.py | 17 +- .../smac/recurrent/decentralised/run_madqn.py | 30 +- .../run_madqn_scale_trainers.py} | 61 ++-- mava/components/tf/modules/mixing/mixers.py | 58 ++- mava/systems/tf/dial/README.md | 10 - mava/systems/tf/dial/__init__.py | 19 - mava/systems/tf/dial/builder.py | 200 ----------- mava/systems/tf/dial/execution.py | 130 ------- mava/systems/tf/dial/networks.py | 69 ---- mava/systems/tf/dial/system.py | 340 ------------------ mava/systems/tf/dial/training.py | 233 ------------ mava/systems/tf/madqn/builder.py | 26 +- mava/systems/tf/madqn/execution.py | 62 ++-- mava/systems/tf/madqn/networks.py | 11 +- mava/systems/tf/madqn/system.py | 73 ++-- mava/systems/tf/madqn/training.py | 118 +++--- .../tf/value_decomposition/networks.py | 24 +- mava/systems/tf/value_decomposition/system.py | 88 ++--- .../tf/value_decomposition/training.py | 141 ++++---- mava/utils/environments/flatland_utils.py | 26 +- mava/utils/environments/smac_utils.py | 23 +- mava/utils/training_utils.py | 5 +- mava/wrappers/env_preprocess_wrappers.py | 118 +++--- mava/wrappers/flatland.py | 36 +- mava/wrappers/smac.py | 38 +- 40 files changed, 660 insertions(+), 1730 deletions(-) rename examples/debugging/simple_spread/feedforward/decentralised/{run_qmix.py => run_madqn_scale_trainers.py} (78%) delete mode 100644 examples/debugging/simple_spread/feedforward/decentralised/run_vdn.py rename examples/debugging/simple_spread/recurrent/{decentralised/run_dial.py => centralised/run_vdn.py} (73%) rename examples/debugging/simple_spread/recurrent/decentralised/{run_madqn_with_coms.py => run_madqn_scale_trainers.py} (84%) create mode 100644 examples/flatland/recurrent/centralised/run_vdn.py rename examples/petting_zoo/atari/pong/{feedforward => recurrent}/decentralised/run_madqn.py (85%) delete mode 100644 examples/smac/feedforward/decentralised/test.py rename examples/smac/recurrent/{decentralised => centralised}/run_qmix.py (90%) rename examples/smac/recurrent/{decentralised => centralised}/run_vdn.py (88%) rename examples/{debugging/switch/recurrent/decentralised/run_dial.py => smac/recurrent/decentralised/run_madqn_scale_trainers.py} (66%) delete mode 100644 mava/systems/tf/dial/README.md delete mode 100644 mava/systems/tf/dial/__init__.py delete mode 100644 mava/systems/tf/dial/builder.py delete mode 100644 mava/systems/tf/dial/execution.py delete mode 100644 mava/systems/tf/dial/networks.py delete mode 100644 mava/systems/tf/dial/system.py delete mode 100644 mava/systems/tf/dial/training.py diff --git a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn.py b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn.py index 391410ce1..a6f8f5261 100644 --- a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn.py +++ 
b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn.py @@ -84,7 +84,6 @@ def main(_: Any) -> None: exploration_scheduler_fn=LinearExplorationScheduler( epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=5e-4 ), - importance_sampling_exponent=0.2, optimizer=snt.optimizers.Adam(learning_rate=1e-4), checkpoint_subpath=checkpoint_dir, ).build() diff --git a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_configurable_epsilon.py b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_configurable_epsilon.py index 349f4d7cf..90b7ad9d8 100644 --- a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_configurable_epsilon.py +++ b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_configurable_epsilon.py @@ -108,7 +108,6 @@ def main(_: Any) -> None: logger_factory=logger_factory, num_executors=2, exploration_scheduler_fn=exploration_scheduler_fn, - importance_sampling_exponent=0.2, optimizer=snt.optimizers.Adam(learning_rate=1e-4), checkpoint_subpath=checkpoint_dir, ).build() diff --git a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py index 165e5bf68..f337a2cbb 100644 --- a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py +++ b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py @@ -107,7 +107,6 @@ def main(_: Any) -> None: exploration_scheduler_fn=LinearExplorationScheduler( epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=5e-4 ), - importance_sampling_exponent=0.2, optimizer=snt.optimizers.Adam(learning_rate=lr_start), checkpoint_subpath=checkpoint_dir, learning_rate_scheduler_fn=learning_rate_scheduler_fn, diff --git a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_lr_schedule.py b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_lr_schedule.py index 6bbdeb1ce..9b77a54cd 100644 --- a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_lr_schedule.py +++ b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_lr_schedule.py @@ -97,7 +97,6 @@ def main(_: Any) -> None: exploration_scheduler_fn=LinearExplorationScheduler( epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=5e-4 ), - importance_sampling_exponent=0.2, optimizer=snt.optimizers.Adam(learning_rate=lr), checkpoint_subpath=checkpoint_dir, learning_rate_scheduler_fn=learning_rate_scheduler_fn, diff --git a/examples/debugging/simple_spread/feedforward/decentralised/run_qmix.py b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_scale_trainers.py similarity index 78% rename from examples/debugging/simple_spread/feedforward/decentralised/run_qmix.py rename to examples/debugging/simple_spread/feedforward/decentralised/run_madqn_scale_trainers.py index 78fcc1d52..f16200626 100644 --- a/examples/debugging/simple_spread/feedforward/decentralised/run_qmix.py +++ b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_scale_trainers.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Example running QMIX on debug MPE environments.""" + +"""Example running MADQN on debug MPE environments.""" import functools from datetime import datetime from typing import Any @@ -22,9 +23,9 @@ import sonnet as snt from absl import app, flags -from mava.components.tf.modules.exploration import LinearExplorationTimestepScheduler -from mava.systems.tf import qmix -from mava.utils import lp_utils +from mava.components.tf.modules.exploration import LinearExplorationScheduler +from mava.systems.tf import madqn +from mava.utils import enums, lp_utils from mava.utils.environments import debugging_utils from mava.utils.loggers import logger_utils @@ -49,18 +50,18 @@ def main(_: Any) -> None: + # Environment. environment_factory = functools.partial( debugging_utils.make_environment, env_name=FLAGS.env_name, action_space=FLAGS.action_space, - return_state_info=True, ) # Networks. - network_factory = lp_utils.partial_kwargs(qmix.make_default_networks) + network_factory = lp_utils.partial_kwargs(madqn.make_default_networks) - # Checkpointer appends "Checkpoints" to checkpoint_dir. + # Checkpointer appends "Checkpoints" to checkpoint_dir checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" # Log every [log_every] seconds. @@ -74,17 +75,19 @@ def main(_: Any) -> None: time_delta=log_every, ) - # Distributed program. - program = qmix.QMIX( + # distributed program + program = madqn.MADQN( environment_factory=environment_factory, network_factory=network_factory, logger_factory=logger_factory, num_executors=1, - exploration_scheduler_fn=LinearExplorationTimestepScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=20000 + exploration_scheduler_fn=LinearExplorationScheduler( + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=5e-4 ), - max_replay_size=1000000, - optimizer=snt.optimizers.RMSProp(learning_rate=1e-4), + shared_weights=False, + trainer_networks=enums.Trainer.one_trainer_per_network, + network_sampling_setup=enums.NetworkSampler.fixed_agent_networks, + optimizer=snt.optimizers.Adam(learning_rate=1e-4), checkpoint_subpath=checkpoint_dir, ).build() diff --git a/examples/debugging/simple_spread/feedforward/decentralised/run_vdn.py b/examples/debugging/simple_spread/feedforward/decentralised/run_vdn.py deleted file mode 100644 index ef19d3d3b..000000000 --- a/examples/debugging/simple_spread/feedforward/decentralised/run_vdn.py +++ /dev/null @@ -1,108 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Example running VDN on debug MPE environments.""" -import functools -from datetime import datetime -from typing import Any - -import launchpad as lp -import sonnet as snt -from absl import app, flags - -from mava.components.tf.modules.exploration import ( - ExponentialExplorationTimestepScheduler, -) -from mava.systems.tf import vdn -from mava.utils import lp_utils -from mava.utils.environments import debugging_utils -from mava.utils.loggers import logger_utils - -FLAGS = flags.FLAGS -flags.DEFINE_string( - "env_name", - "simple_spread", - "Debugging environment name (str).", -) -flags.DEFINE_string( - "action_space", - "discrete", - "Environment action space type (str).", -) - -flags.DEFINE_string( - "mava_id", - str(datetime.now()), - "Experiment identifier that can be used to continue experiments.", -) -flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") - - -def main(_: Any) -> None: - # Environment. - environment_factory = functools.partial( - debugging_utils.make_environment, - env_name=FLAGS.env_name, - action_space=FLAGS.action_space, - return_state_info=True, - ) - - # Networks. - network_factory = lp_utils.partial_kwargs(vdn.make_default_networks) - - # Checkpointer appends "Checkpoints" to checkpoint_dir. - checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" - - # Log every [log_every] seconds. - log_every = 10 - logger_factory = functools.partial( - logger_utils.make_logger, - directory=FLAGS.base_dir, - to_terminal=True, - to_tensorboard=True, - time_stamp=FLAGS.mava_id, - time_delta=log_every, - ) - - # Distributed program. - program = vdn.VDN( - environment_factory=environment_factory, - network_factory=network_factory, - logger_factory=logger_factory, - num_executors=1, - exploration_scheduler_fn=ExponentialExplorationTimestepScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=20000 - ), - max_replay_size=1000000, - optimizer=snt.optimizers.RMSProp(learning_rate=1e-4), - checkpoint_subpath=checkpoint_dir, - ).build() - - # Ensure only trainer runs on gpu, while other processes run on cpu. - local_resources = lp_utils.to_device( - program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] - ) - - # Launch. - lp.launch( - program, - lp.LaunchType.LOCAL_MULTI_PROCESSING, - terminal="current_terminal", - local_resources=local_resources, - ) - - -if __name__ == "__main__": - app.run(main) diff --git a/examples/debugging/simple_spread/recurrent/decentralised/run_dial.py b/examples/debugging/simple_spread/recurrent/centralised/run_vdn.py similarity index 73% rename from examples/debugging/simple_spread/recurrent/decentralised/run_dial.py rename to examples/debugging/simple_spread/recurrent/centralised/run_vdn.py index a8f342b38..d83c13232 100644 --- a/examples/debugging/simple_spread/recurrent/decentralised/run_dial.py +++ b/examples/debugging/simple_spread/recurrent/centralised/run_vdn.py @@ -12,8 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+"""Example running VDN on debugging environment.""" -"""Example running Dial on debug MPE environments.""" import functools from datetime import datetime from typing import Any @@ -22,16 +22,12 @@ import sonnet as snt from absl import app, flags -from mava.components.tf.modules.communication.broadcasted import ( - BroadcastedCommunication, -) from mava.components.tf.modules.exploration.exploration_scheduling import ( LinearExplorationScheduler, ) -from mava.systems.tf import dial +from mava.systems.tf import value_decomposition from mava.utils import lp_utils -from mava.utils.enums import ArchitectureType -from mava.utils.environments import debugging_utils +from mava.utils.environments.debugging_utils import make_environment from mava.utils.loggers import logger_utils FLAGS = flags.FLAGS @@ -57,17 +53,17 @@ def main(_: Any) -> None: # Environment. environment_factory = functools.partial( - debugging_utils.make_environment, + make_environment, env_name=FLAGS.env_name, action_space=FLAGS.action_space, ) # Networks. network_factory = lp_utils.partial_kwargs( - dial.make_default_networks, archecture_type=ArchitectureType.recurrent + value_decomposition.make_default_networks, ) - # Checkpointer appends "Checkpoints" to checkpoint_dir. + # Checkpointer appends "Checkpoints" to checkpoint_dir checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" # Log every [log_every] seconds. @@ -81,31 +77,32 @@ def main(_: Any) -> None: time_delta=log_every, ) - # Distributed program. - program = dial.DIAL( + # distributed program + program = value_decomposition.ValueDecomposition( environment_factory=environment_factory, network_factory=network_factory, + mixer="vdn", logger_factory=logger_factory, num_executors=1, - trainer_fn=dial.DIALSwitchTrainer, - executor_fn=dial.DIALSwitchExecutor, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=2.5e-4 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 + ), + optimizer=snt.optimizers.RMSProp( + learning_rate=0.0005, epsilon=0.00001, decay=0.99 ), - communication_module=BroadcastedCommunication, - sequence_length=6, - optimizer=snt.optimizers.RMSProp(learning_rate=1e-4, momentum=0.95), checkpoint_subpath=checkpoint_dir, - n_step=1, batch_size=32, + max_gradient_norm=20.0, + min_replay_size=32, + max_replay_size=10000, + samples_per_insert=16, + evaluator_interval={"executor_episodes": 2}, ).build() - # Ensure only trainer runs on gpu, while other processes run on cpu. + # launch local_resources = lp_utils.to_device( program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] ) - - # Launch. lp.launch( program, lp.LaunchType.LOCAL_MULTI_PROCESSING, diff --git a/examples/debugging/simple_spread/recurrent/decentralised/run_madqn.py b/examples/debugging/simple_spread/recurrent/decentralised/run_madqn.py index 889534147..542c7fe94 100644 --- a/examples/debugging/simple_spread/recurrent/decentralised/run_madqn.py +++ b/examples/debugging/simple_spread/recurrent/decentralised/run_madqn.py @@ -47,6 +47,7 @@ str(datetime.now()), "Experiment identifier that can be used to continue experiments.", ) + flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") @@ -57,7 +58,6 @@ def main(_: Any) -> None: debugging_utils.make_environment, env_name=FLAGS.env_name, action_space=FLAGS.action_space, - num_agents=10 ) # Networks. 
@@ -90,9 +90,10 @@ def main(_: Any) -> None: ), optimizer=snt.optimizers.Adam(learning_rate=1e-4), checkpoint_subpath=checkpoint_dir, - trainer_fn=madqn.training.MADQNRecurrentTrainer, - executor_fn=madqn.execution.MADQNRecurrentExecutor, + trainer_fn=madqn.MADQNRecurrentTrainer, + executor_fn=madqn.MADQNRecurrentExecutor, max_replay_size=5000, + min_replay_size=32, batch_size=32, ).build() diff --git a/examples/debugging/simple_spread/recurrent/decentralised/run_madqn_with_coms.py b/examples/debugging/simple_spread/recurrent/decentralised/run_madqn_scale_trainers.py similarity index 84% rename from examples/debugging/simple_spread/recurrent/decentralised/run_madqn_with_coms.py rename to examples/debugging/simple_spread/recurrent/decentralised/run_madqn_scale_trainers.py index 2575cd445..e11aef8c7 100644 --- a/examples/debugging/simple_spread/recurrent/decentralised/run_madqn_with_coms.py +++ b/examples/debugging/simple_spread/recurrent/decentralised/run_madqn_scale_trainers.py @@ -23,13 +23,10 @@ import sonnet as snt from absl import app, flags -from mava.components.tf.modules.communication.broadcasted import ( - BroadcastedCommunication, -) from mava.components.tf.modules.exploration import LinearExplorationScheduler from mava.systems.tf import madqn -from mava.utils import lp_utils -from mava.utils.enums import ArchitectureType, Network +from mava.utils import enums, lp_utils +from mava.utils.enums import ArchitectureType from mava.utils.environments import debugging_utils from mava.utils.loggers import logger_utils @@ -64,10 +61,7 @@ def main(_: Any) -> None: # Networks. network_factory = lp_utils.partial_kwargs( - madqn.make_default_networks, - archecture_type=ArchitectureType.recurrent, - message_size=10, - network_type=Network.coms_network, + madqn.make_default_networks, architecture_type=ArchitectureType.recurrent ) # Checkpointer appends "Checkpoints" to checkpoint_dir @@ -84,7 +78,7 @@ def main(_: Any) -> None: time_delta=log_every, ) - # Distributed program. + # distributed program program = madqn.MADQN( environment_factory=environment_factory, network_factory=network_factory, @@ -93,12 +87,16 @@ def main(_: Any) -> None: exploration_scheduler_fn=LinearExplorationScheduler( epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=5e-4 ), + shared_weights=False, + trainer_networks=enums.Trainer.one_trainer_per_network, + network_sampling_setup=enums.NetworkSampler.fixed_agent_networks, + trainer_fn=madqn.MADQNRecurrentTrainer, + executor_fn=madqn.MADQNRecurrentExecutor, + max_replay_size=5000, + min_replay_size=32, + batch_size=32, optimizer=snt.optimizers.Adam(learning_rate=1e-4), checkpoint_subpath=checkpoint_dir, - trainer_fn=madqn.training.MADQNRecurrentCommTrainer, - executor_fn=madqn.execution.MADQNRecurrentCommExecutor, - batch_size=32, - communication_module=BroadcastedCommunication, ).build() # Ensure only trainer runs on gpu, while other processes run on cpu. diff --git a/examples/flatland/feedforward/decentralised/run_madqn.py b/examples/flatland/feedforward/decentralised/run_madqn.py index b159f4721..6bc4cc710 100644 --- a/examples/flatland/feedforward/decentralised/run_madqn.py +++ b/examples/flatland/feedforward/decentralised/run_madqn.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +"""Example running feedforward MADQN on Flatland""" import functools from datetime import datetime @@ -27,7 +27,7 @@ ) from mava.systems.tf import madqn from mava.utils import lp_utils -from mava.utils.environments.flatland_utils import flatland_env_factory +from mava.utils.environments.flatland_utils import make_environment from mava.utils.loggers import logger_utils FLAGS = flags.FLAGS @@ -60,9 +60,7 @@ def main(_: Any) -> None: # Environment. - environment_factory = functools.partial( - flatland_env_factory, env_config=flatland_env_config, include_agent_info=False - ) + environment_factory = functools.partial(make_environment, **flatland_env_config) # Networks. network_factory = lp_utils.partial_kwargs(madqn.make_default_networks) @@ -88,9 +86,8 @@ def main(_: Any) -> None: logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-4 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 ), - importance_sampling_exponent=0.2, optimizer=snt.optimizers.Adam(learning_rate=1e-4), checkpoint_subpath=checkpoint_dir, ).build() diff --git a/examples/flatland/recurrent/centralised/run_vdn.py b/examples/flatland/recurrent/centralised/run_vdn.py new file mode 100644 index 000000000..685579f6d --- /dev/null +++ b/examples/flatland/recurrent/centralised/run_vdn.py @@ -0,0 +1,119 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from datetime import datetime +from typing import Any, Dict + +import launchpad as lp +import sonnet as snt +from absl import app, flags + +from mava.components.tf.modules.exploration.exploration_scheduling import ( + LinearExplorationScheduler, +) +from mava.systems.tf import value_decomposition +from mava.utils import lp_utils +from mava.utils.environments.flatland_utils import make_environment +from mava.utils.loggers import logger_utils + +"""Example running VDN on Flatland.""" + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "mava_id", + str(datetime.now()), + "Experiment identifier that can be used to continue experiments.", +) +flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") + +# flatland environment config +env_config: Dict = { + "n_agents": 3, + "x_dim": 30, + "y_dim": 30, + "n_cities": 2, + "max_rails_between_cities": 2, + "max_rails_in_city": 3, + "seed": 0, + "malfunction_rate": 1 / 200, + "malfunction_min_duration": 20, + "malfunction_max_duration": 50, + "observation_max_path_depth": 30, + "observation_tree_depth": 2, +} + + +def main(_: Any) -> None: + + # Environment. + environment_factory = functools.partial(make_environment, **env_config) + + # Networks. + network_factory = lp_utils.partial_kwargs( + value_decomposition.make_default_networks, + ) + + # Checkpointer appends "Checkpoints" to checkpoint_dir + checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" + + # Log every [log_every] seconds. 
+ log_every = 10 + logger_factory = functools.partial( + logger_utils.make_logger, + directory=FLAGS.base_dir, + to_terminal=True, + to_tensorboard=True, + time_stamp=FLAGS.mava_id, + time_delta=log_every, + ) + + # distributed program + program = value_decomposition.ValueDecomposition( + environment_factory=environment_factory, + network_factory=network_factory, + mixer="vdn", + logger_factory=logger_factory, + num_executors=1, + exploration_scheduler_fn=LinearExplorationScheduler( + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 + ), + optimizer=snt.optimizers.RMSProp( + learning_rate=0.0005, epsilon=0.00001, decay=0.99 + ), + checkpoint_subpath=checkpoint_dir, + batch_size=32, + max_gradient_norm=20.0, + min_replay_size=32, + max_replay_size=10000, + samples_per_insert=16, + evaluator_interval={"executor_episodes": 2}, + ).build() + + # launch + local_resources = lp_utils.to_device( + program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] + ) + lp.launch( + program, + lp.LaunchType.LOCAL_MULTI_PROCESSING, + terminal="current_terminal", + local_resources=local_resources, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/flatland/recurrent/decentralised/run_madqn.py b/examples/flatland/recurrent/decentralised/run_madqn.py index e30090327..d6f7de0c8 100644 --- a/examples/flatland/recurrent/decentralised/run_madqn.py +++ b/examples/flatland/recurrent/decentralised/run_madqn.py @@ -27,10 +27,9 @@ ) from mava.systems.tf import madqn from mava.utils import lp_utils +from mava.utils.enums import ArchitectureType from mava.utils.environments.flatland_utils import make_environment from mava.utils.loggers import logger_utils -from mava.utils.enums import ArchitectureType - FLAGS = flags.FLAGS @@ -62,14 +61,11 @@ def main(_: Any) -> None: # Environment. - environment_factory = functools.partial( - make_environment, **env_config - ) + environment_factory = functools.partial(make_environment, **env_config) # Networks. network_factory = lp_utils.partial_kwargs( - madqn.make_default_networks, - architecture_type=ArchitectureType.recurrent + madqn.make_default_networks, architecture_type=ArchitectureType.recurrent ) # Checkpointer appends "Checkpoints" to checkpoint_dir @@ -97,18 +93,14 @@ def main(_: Any) -> None: ), optimizer=snt.optimizers.Adam(learning_rate=1e-4), batch_size=32, - executor_variable_update_period=200, - target_update_period=200, + samples_per_insert=16, max_gradient_norm=20.0, - sequence_length=70, - period=70, min_replay_size=32, - max_replay_size=5000, - trainer_fn=madqn.training.MADQNRecurrentTrainer, - executor_fn=madqn.execution.MADQNRecurrentExecutor, + max_replay_size=10000, + trainer_fn=madqn.MADQNRecurrentTrainer, + executor_fn=madqn.MADQNRecurrentExecutor, checkpoint_subpath=checkpoint_dir, evaluator_interval={"executor_episodes": 2}, - termination_condition={"executor_steps": 3_000_000} ).build() # Ensure only trainer runs on gpu, while other processes run on cpu. 
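
Both Flatland examples above construct the environment with make_environment(**env_config) and drive it through the dm_env-style wrapper from earlier in the series, whose reset and step now return a (timestep, extras) tuple. The loop below is a minimal sketch shown purely for orientation: it assumes make_environment returns the wrapped environment directly when called outside a factory, and it samples uniform random actions in place of the trained Q-networks; env_config is copied from the VDN example above.

import numpy as np

from mava.utils.environments.flatland_utils import make_environment

env_config = dict(
    n_agents=3, x_dim=30, y_dim=30, n_cities=2,
    max_rails_between_cities=2, max_rails_in_city=3, seed=0,
    malfunction_rate=1 / 200, malfunction_min_duration=20,
    malfunction_max_duration=50, observation_max_path_depth=30,
    observation_tree_depth=2,
)

# Assumption: calling the factory directly yields the wrapped environment.
env = make_environment(**env_config)

timestep, extras = env.reset()
while not env.env_done():
    # Each Flatland train has 5 discrete actions; a real executor would select
    # actions from its Q-networks instead of sampling uniformly.
    actions = {agent: np.random.randint(env.num_actions) for agent in env.agents}
    timestep, extras = env.step(actions)
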
diff --git a/examples/petting_zoo/atari/pong/feedforward/decentralised/run_madqn.py b/examples/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py similarity index 85% rename from examples/petting_zoo/atari/pong/feedforward/decentralised/run_madqn.py rename to examples/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py index cc8f12dbb..a7ac7312a 100644 --- a/examples/petting_zoo/atari/pong/feedforward/decentralised/run_madqn.py +++ b/examples/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Example running MADQN on debug Atari Pong.""" +"""Example running MADQN on Atari Pong.""" + import functools from datetime import datetime from typing import Any @@ -24,8 +25,7 @@ from mava.components.tf.modules.exploration import LinearExplorationScheduler from mava.systems.tf import madqn -from mava.utils import lp_utils -from mava.utils.enums import Network +from mava.utils import enums, lp_utils from mava.utils.environments import pettingzoo_utils from mava.utils.loggers import logger_utils @@ -60,7 +60,9 @@ def main(_: Any) -> None: # Networks. network_factory = lp_utils.partial_kwargs( - madqn.make_default_networks, network_type=Network.mlp + madqn.make_default_networks, + architecture_type=enums.ArchitectureType.recurrent, + atari_torso_observation_network=True, ) # Checkpointer appends "Checkpoints" to checkpoint_dir @@ -84,10 +86,17 @@ def main(_: Any) -> None: logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-4 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-6 ), - importance_sampling_exponent=0.2, + shared_weights=False, + batch_size=32, + max_replay_size=10000, + samples_per_insert=16, + min_replay_size=32, optimizer=snt.optimizers.Adam(learning_rate=1e-4), + executor_fn=madqn.MADQNRecurrentExecutor, + trainer_fn=madqn.MADQNRecurrentTrainer, + evaluator_interval={"executor_episodes": 2}, checkpoint_subpath=checkpoint_dir, ).build() diff --git a/examples/smac/feedforward/decentralised/test.py b/examples/smac/feedforward/decentralised/test.py deleted file mode 100644 index f5cf84d45..000000000 --- a/examples/smac/feedforward/decentralised/test.py +++ /dev/null @@ -1,26 +0,0 @@ -from smac.env import StarCraft2Env -from mava.wrappers import SMACWrapper -from mava.wrappers.env_preprocess_wrappers import ConcatAgentIdToObservation -from mava.wrappers.env_preprocess_wrappers import ConcatPrevActionToObservation -import numpy as np - -env = StarCraft2Env(map_name="3m") - -env = SMACWrapper(env) - -env = ConcatAgentIdToObservation(env) - -env = ConcatPrevActionToObservation(env) - -spec = env.action_spec() -# for agent in spec: -# print(spec[agent].num_values) -spec = env.observation_spec() - -res = env.reset() - -actions = {"agent_0": 1, "agent_1": 2, "agent_2": 3} - -res = env.step(actions) - -print("Done") \ No newline at end of file diff --git a/examples/smac/recurrent/decentralised/run_qmix.py b/examples/smac/recurrent/centralised/run_qmix.py similarity index 90% rename from examples/smac/recurrent/decentralised/run_qmix.py rename to examples/smac/recurrent/centralised/run_qmix.py index 4293672ed..f0e42dbcb 100644 --- a/examples/smac/recurrent/decentralised/run_qmix.py +++ b/examples/smac/recurrent/centralised/run_qmix.py @@ -27,16 +27,13 @@ ) from mava.systems.tf import value_decomposition from mava.utils import lp_utils -from 
mava.utils.loggers import logger_utils from mava.utils.environments.smac_utils import make_environment - -SEQUENCE_LENGTH = 120 -MAP_NAME = "2s3z" +from mava.utils.loggers import logger_utils FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - MAP_NAME, + "3m", "Starcraft 2 micromanagement map name (str).", ) @@ -45,15 +42,14 @@ str(datetime.now()), "Experiment identifier that can be used to continue experiments.", ) + flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") def main(_: Any) -> None: """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" # environment - environment_factory = functools.partial( - make_environment, map_name=FLAGS.map_name - ) + environment_factory = functools.partial(make_environment, map_name=FLAGS.map_name) # Networks. network_factory = lp_utils.partial_kwargs( @@ -92,14 +88,10 @@ def main(_: Any) -> None: executor_variable_update_period=200, target_update_period=200, max_gradient_norm=20.0, - sequence_length=20, - period=10, min_replay_size=32, - max_replay_size=10_000, - samples_per_insert=32, + max_replay_size=10000, + samples_per_insert=16, evaluator_interval={"executor_episodes": 2}, - termination_condition={"executor_steps": 3_000_000} - ).build() # launch diff --git a/examples/smac/recurrent/decentralised/run_vdn.py b/examples/smac/recurrent/centralised/run_vdn.py similarity index 88% rename from examples/smac/recurrent/decentralised/run_vdn.py rename to examples/smac/recurrent/centralised/run_vdn.py index 20458240d..00afb0c03 100644 --- a/examples/smac/recurrent/decentralised/run_vdn.py +++ b/examples/smac/recurrent/centralised/run_vdn.py @@ -30,13 +30,10 @@ from mava.utils.environments.smac_utils import make_environment from mava.utils.loggers import logger_utils -SEQUENCE_LENGTH = 60 -MAP_NAME = "3m" - FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - MAP_NAME, + "3m", "Starcraft 2 micromanagement map name (str).", ) @@ -45,6 +42,7 @@ str(datetime.now()), "Experiment identifier that can be used to continue experiments.", ) + flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") @@ -80,22 +78,17 @@ def main(_: Any) -> None: logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=8e-6 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 ), checkpoint_subpath=checkpoint_dir, batch_size=32, - executor_variable_update_period=200, - target_update_period=100, max_gradient_norm=20.0, - sequence_length=SEQUENCE_LENGTH, - period=SEQUENCE_LENGTH, min_replay_size=32, - max_replay_size=5000, - samples_per_insert=1, - termination_condition={"executor_steps": 3_000_000}, + max_replay_size=10000, + samples_per_insert=16, evaluator_interval={"executor_episodes": 2}, ).build() diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index 3d55dcc56..44c5c1382 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -16,30 +16,20 @@ import functools from datetime import datetime -from typing import Any, Dict, Mapping, Sequence, Union +from typing import Any import launchpad as lp import sonnet as snt -import tensorflow as tf from absl import app, flags -from acme import types -from mava import specs as mava_specs -from mava.components.tf import networks from 
mava.components.tf.modules.exploration.exploration_scheduling import ( LinearExplorationScheduler, ) -from mava.components.tf.networks.epsilon_greedy import EpsilonGreedy from mava.systems.tf import madqn from mava.utils import lp_utils from mava.utils.enums import ArchitectureType -from mava.utils.environments import pettingzoo_utils from mava.utils.environments.smac_utils import make_environment from mava.utils.loggers import logger_utils -from mava.wrappers.env_preprocess_wrappers import ( - ConcatAgentIdToObservation, - ConcatPrevActionToObservation, -) SEQUENCE_LENGTH = 60 MAP_NAME = "3m" @@ -56,6 +46,7 @@ str(datetime.now()), "Experiment identifier that can be used to continue experiments.", ) + flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") @@ -83,14 +74,14 @@ def main(_: Any) -> None: time_delta=log_every, ) - # distributed program + # Distributed program program = madqn.MADQN( environment_factory=environment_factory, network_factory=network_factory, logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=8e-6 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 @@ -98,17 +89,14 @@ def main(_: Any) -> None: checkpoint_subpath=checkpoint_dir, batch_size=32, executor_variable_update_period=200, - target_update_period=100, + target_update_period=200, max_gradient_norm=20.0, - sequence_length=SEQUENCE_LENGTH, - period=SEQUENCE_LENGTH, min_replay_size=32, - max_replay_size=5000, - samples_per_insert=1, + max_replay_size=10000, + samples_per_insert=16, evaluator_interval={"executor_episodes": 2}, - termination_condition={"executor_steps": 3_000_000}, - trainer_fn=madqn.training.MADQNRecurrentTrainer, - executor_fn=madqn.execution.MADQNRecurrentExecutor, + trainer_fn=madqn.MADQNRecurrentTrainer, + executor_fn=madqn.MADQNRecurrentExecutor, ).build() # launch diff --git a/examples/debugging/switch/recurrent/decentralised/run_dial.py b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py similarity index 66% rename from examples/debugging/switch/recurrent/decentralised/run_dial.py rename to examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py index 1373831d3..e759c08c6 100644 --- a/examples/debugging/switch/recurrent/decentralised/run_dial.py +++ b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Example running Dial on debug Switch environments.""" + +"""Example running MADQN on SMAC with multiple trainers.""" import functools from datetime import datetime from typing import Any @@ -22,34 +23,26 @@ import sonnet as snt from absl import app, flags -from mava.components.tf.modules.communication.broadcasted import ( - BroadcastedCommunication, -) -from mava.components.tf.modules.exploration.exploration_scheduling import ( - LinearExplorationScheduler, -) -from mava.systems.tf import dial -from mava.utils import lp_utils +from mava.components.tf.modules.exploration import LinearExplorationScheduler +from mava.systems.tf import madqn +from mava.utils import enums, lp_utils from mava.utils.enums import ArchitectureType -from mava.utils.environments import debugging_utils +from mava.utils.environments import smac_utils from mava.utils.loggers import logger_utils FLAGS = flags.FLAGS flags.DEFINE_string( "env_name", - "switch", - "Debugging environment name (str).", -) -flags.DEFINE_string( - "action_space", - "discrete", - "Environment action space type (str).", + "3m", + "SMAC map name.", ) + flags.DEFINE_string( "mava_id", str(datetime.now()), "Experiment identifier that can be used to continue experiments.", ) + flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") @@ -57,17 +50,16 @@ def main(_: Any) -> None: # Environment. environment_factory = functools.partial( - debugging_utils.make_environment, - env_name=FLAGS.env_name, - action_space=FLAGS.action_space, + smac_utils.make_environment, + map_name=FLAGS.env_name, ) # Networks. network_factory = lp_utils.partial_kwargs( - dial.make_default_networks, archecture_type=ArchitectureType.recurrent + madqn.make_default_networks, architecture_type=ArchitectureType.recurrent ) - # Checkpointer appends "Checkpoints" to checkpoint_dir. + # Checkpointer appends "Checkpoints" to checkpoint_dir checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" # Log every [log_every] seconds. @@ -81,23 +73,28 @@ def main(_: Any) -> None: time_delta=log_every, ) - # Distributed program. - program = dial.DIAL( + # distributed program + program = madqn.MADQN( environment_factory=environment_factory, network_factory=network_factory, logger_factory=logger_factory, num_executors=1, - trainer_fn=dial.DIALSwitchTrainer, - executor_fn=dial.DIALSwitchExecutor, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=2.5e-4 + epsilon_start=1.0, + epsilon_min=0.05, + epsilon_decay=4e-5, ), - communication_module=BroadcastedCommunication, - sequence_length=6, - optimizer=snt.optimizers.RMSProp(learning_rate=1e-4, momentum=0.95), - checkpoint_subpath=checkpoint_dir, - n_step=1, + shared_weights=False, + trainer_networks=enums.Trainer.one_trainer_per_network, + network_sampling_setup=enums.NetworkSampler.fixed_agent_networks, + trainer_fn=madqn.MADQNRecurrentTrainer, + executor_fn=madqn.MADQNRecurrentExecutor, + max_replay_size=5000, + min_replay_size=32, batch_size=32, + evaluator_interval={"executor_episodes": 2}, + optimizer=snt.optimizers.Adam(learning_rate=1e-4), + checkpoint_subpath=checkpoint_dir, ).build() # Ensure only trainer runs on gpu, while other processes run on cpu. 
diff --git a/mava/components/tf/modules/mixing/mixers.py b/mava/components/tf/modules/mixing/mixers.py index 1e191859a..7039e3f7c 100644 --- a/mava/components/tf/modules/mixing/mixers.py +++ b/mava/components/tf/modules/mixing/mixers.py @@ -1,50 +1,51 @@ import sonnet as snt import tensorflow as tf + @snt.allow_empty_variables class BaseMixer(snt.Module): - """Base mixing class. - - Base mixer should take in agent q-values and environment global state tensors. + """Base mixing class. + + Base mixer should take in agent q-values and environment global state tensors. """ - def __init__(self): + + def __init__(self) -> None: super().__init__() - def __call__(self, agent_qs: tf.Tensor , states: tf.Tensor): + def __call__(self, agent_qs: tf.Tensor, states: tf.Tensor) -> tf.Tensor: return agent_qs + class VDN(BaseMixer): """VDN mixing network.""" - def __init__(self): + def __init__(self) -> None: super().__init__() - - def __call__(self, agent_qs: tf.Tensor, states: tf.Tensor): + + def __call__(self, agent_qs: tf.Tensor, states: tf.Tensor) -> tf.Tensor: return tf.reduce_sum(agent_qs, axis=-1, keepdims=True) """Initialize VDN class Args: agent_qs: Tensor containing the q-values of actions chosen by agents - states: Tensor containing global environment state. + states: Tensor containing global environment state. Returns: Tensor with total q-value. """ + class QMIX(BaseMixer): """QMIX mixing network.""" def __init__( - self, - num_agents: int, - embed_dim: int = 32, - hypernet_embed: int = 64 - ): + self, num_agents: int, embed_dim: int = 32, hypernet_embed: int = 64 + ) -> None: """Inialize QMIX mixing network - - Args: + + Args: num_agents: Number of agents in the enviroment state_dim: Dimensions of the global environment state - embed_dim: TODO (Ruan): Cluade please add + embed_dim: TODO (Ruan): Cluade please add hypernet_embed: TODO (Ruan): Claude Please add """ @@ -52,37 +53,26 @@ def __init__( self.num_agents = num_agents self.embed_dim = embed_dim self.hypernet_embed = hypernet_embed - self.hyper_w_1 = snt.Sequential( [ snt.Linear(self.hypernet_embed), tf.nn.relu, - snt.Linear(self.embed_dim * self.num_agents) + snt.Linear(self.embed_dim * self.num_agents), ] ) self.hyper_w_final = snt.Sequential( - [ - snt.Linear(self.hypernet_embed), - tf.nn.relu, - snt.Linear(self.embed_dim) - ] + [snt.Linear(self.hypernet_embed), tf.nn.relu, snt.Linear(self.embed_dim)] ) # State dependent bias for hidden layer self.hyper_b_1 = snt.Linear(self.embed_dim) # V(s) instead of a bias for the last layers - self.V = snt.Sequential( - [ - snt.Linear(self.embed_dim), - tf.nn.relu, - snt.Linear(1) - ] - ) + self.V = snt.Sequential([snt.Linear(self.embed_dim), tf.nn.relu, snt.Linear(1)]) - def __call__(self, agent_qs, states): + def __call__(self, agent_qs: tf.Tensor, states: tf.Tensor) -> tf.Tensor: bs = agent_qs.shape[1] state_dim = states.shape[-1] @@ -107,6 +97,6 @@ def __call__(self, agent_qs, states): y = tf.matmul(hidden, w_final) + v # Reshape and return - q_tot = tf.reshape(y, (-1, bs, 1)) # [T, B, 1] + q_tot = tf.reshape(y, (-1, bs, 1)) # [T, B, 1] - return q_tot \ No newline at end of file + return q_tot diff --git a/mava/systems/tf/dial/README.md b/mava/systems/tf/dial/README.md deleted file mode 100644 index 942ab4b54..000000000 --- a/mava/systems/tf/dial/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Differentiable inter-agent learning (DIAL) - -A common component of MARL systems is agent communication. Mava supports general purpose components specifically for implementing systems with communication. 
This system is an example of an implementation of differentiable inter-agent learning (DIAL) based on the work by [Foerster et al. (2016)][Foerster et al., 2016]. The trainer is implemented to work specifically for the switch game environment shown below. - -
- -[Foerster et al., 2016]: https://arxiv.org/abs/1605.06676 diff --git a/mava/systems/tf/dial/__init__.py b/mava/systems/tf/dial/__init__.py deleted file mode 100644 index b149c4b07..000000000 --- a/mava/systems/tf/dial/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from mava.systems.tf.dial.execution import DIALSwitchExecutor -from mava.systems.tf.dial.networks import make_default_networks -from mava.systems.tf.dial.system import DIAL -from mava.systems.tf.dial.training import DIALSwitchTrainer diff --git a/mava/systems/tf/dial/builder.py b/mava/systems/tf/dial/builder.py deleted file mode 100644 index 30c53779f..000000000 --- a/mava/systems/tf/dial/builder.py +++ /dev/null @@ -1,200 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DIAL system builder implementation.""" - -import dataclasses -from typing import Any, Dict, Iterator, Optional, Type, Union - -import reverb -import sonnet as snt -from acme.utils import counting - -from mava import adders, core, types -from mava.components.tf.modules.communication import BaseCommunicationModule -from mava.components.tf.modules.exploration.exploration_scheduling import ( - BaseExplorationScheduler, - BaseExplorationTimestepScheduler, - ConstantScheduler, -) -from mava.components.tf.modules.stabilising import FingerPrintStabalisation -from mava.systems.tf.madqn import execution, training -from mava.systems.tf.madqn.builder import MADQNBuilder, MADQNConfig - - -@dataclasses.dataclass -class DIALConfig(MADQNConfig): - """Configuration options for the DIAL system. - - Args: - environment_spec: description of the action and observation spaces etc. for - each agent in the system. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - target_update_period: number of learner steps to perform before updating - the target networks. - executor_variable_update_period: the rate at which executors sync their - paramters with the trainer. - max_gradient_norm: value to specify the maximum clipping value for the gradient - norm during optimization. - min_replay_size: minimum replay size before updating. - max_replay_size: maximum replay size. - samples_per_insert: number of samples to take from replay for every insert - that is made. - prefetch_size: size to prefetch from replay. - batch_size: batch size for updates. 
- n_step: number of steps to include prior to boostrapping. - sequence_length: recurrent sequence rollout length. - period: consecutive starting points for overlapping rollouts across a sequence. - discount: discount to use for TD updates. - checkpoint: boolean to indicate whether to checkpoint models. - optimizer: type of optimizer to use for updating the parameters of models. - replay_table_name: string indicating what name to give the replay table. - checkpoint_subpath: subdirectory specifying where to store checkpoints. - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. - """ - - -class DIALBuilder(MADQNBuilder): - """Builder for DIAL which constructs individual components of the system.""" - - def __init__( - self, - config: DIALConfig, - trainer_fn: Type[ - training.MADQNRecurrentCommTrainer - ] = training.MADQNRecurrentCommTrainer, - executor_fn: Type[core.Executor] = execution.MADQNRecurrentCommExecutor, - extra_specs: Dict[str, Any] = {}, - replay_stabilisation_fn: Optional[Type[FingerPrintStabalisation]] = None, - ): - """Initialise the system. - - Args: - config (DIALConfig): system configuration specifying hyperparameters and - additional information for constructing the system. - trainer_fn (Type[ training.MADQNRecurrentCommTrainer ], optional): - Trainer function, of a correpsonding type to work with the selected - system architecture. Defaults to training.MADQNRecurrentCommTrainer. - executor_fn (Type[core.Executor], optional): Executor function, of a - corresponding type to work with the selected system architecture. - Defaults to execution.MADQNRecurrentCommExecutor. - extra_specs (Dict[str, Any], optional): defines the specifications of extra - information used by the system. Defaults to {}. - replay_stabilisation_fn (Optional[Type[FingerPrintStabalisation]], - optional): optional function to stabilise experience replay. Defaults - to None. - """ - super(DIALBuilder, self).__init__( - config=config, - trainer_fn=trainer_fn, - executor_fn=executor_fn, - extra_specs=extra_specs, - replay_stabilisation_fn=replay_stabilisation_fn, - ) - - def make_executor( # type: ignore[override] - self, - q_networks: Dict[str, snt.Module], - exploration_schedules: Dict[ - str, - Union[ - BaseExplorationTimestepScheduler, - BaseExplorationScheduler, - ConstantScheduler, - ], - ], - action_selectors: Dict[str, Any], - communication_module: BaseCommunicationModule, - adder: Optional[adders.ParallelAdder] = None, - variable_source: Optional[core.VariableSource] = None, - trainer: Optional[training.MADQNRecurrentCommTrainer] = None, - evaluator: bool = False, - seed: Optional[int] = None, - ) -> core.Executor: - """Create an executor instance. - - Args: - q_networks (Dict[str, snt.Module]): q-value networks for each agent in the - system. - exploration_schedules : epsilon decay schedule per agent. - action_selectors (Dict[str, Any]): policy action selector method, e.g. - epsilon greedy. - communication_module (BaseCommunicationModule): module for enabling - communication protocols between agents. - adder (Optional[adders.ParallelAdder], optional): adder to send data to - a replay buffer. Defaults to None. - variable_source (Optional[core.VariableSource], optional): variables server. - Defaults to None. - trainer (Optional[training.MADQNRecurrentCommTrainer], optional): - system trainer. Defaults to None. - evaluator (bool, optional): boolean indicator if the executor is used for - for evaluation only. Defaults to False. 
- seed: seed for reproducible sampling. - - Returns: - core.Executor: system executor, a collection of agents making up the part - of the system generating data by interacting the environment. - """ - - return super().make_executor( - q_networks=q_networks, - exploration_schedules=exploration_schedules, - action_selectors=action_selectors, - communication_module=communication_module, - adder=adder, - variable_source=variable_source, - trainer=trainer, - evaluator=evaluator, - seed=seed, - ) - - def make_trainer( # type: ignore[override] - self, - networks: Dict[str, Dict[str, snt.Module]], - dataset: Iterator[reverb.ReplaySample], - communication_module: BaseCommunicationModule, # type: ignore - counter: Optional[counting.Counter] = None, - logger: Optional[types.NestedLogger] = None, - replay_client: Optional[reverb.TFClient] = None, - ) -> core.Trainer: - """Create a trainer instance. - - Args: - networks (Dict[str, Dict[str, snt.Module]]): system networks. - dataset (Iterator[reverb.ReplaySample]): dataset iterator to feed data to - the trainer networks. - communication_module (BaseCommunicationModule): module to enable - agent communication. - counter (Optional[counting.Counter], optional): a Counter which allows for - recording of counts, e.g. trainer steps. Defaults to None. - logger (Optional[types.NestedLogger], optional): Logger object for logging - metadata.. Defaults to None. - replay_client (reverb.TFClient): Used for importance sampling. - Not implemented yet. - - Returns: - core.Trainer: system trainer, that uses the collected data from the - executors to update the parameters of the agent networks in the system. - """ - return super().make_trainer( - networks=networks, - dataset=dataset, - communication_module=communication_module, - counter=counter, - logger=logger, - replay_client=replay_client, - ) diff --git a/mava/systems/tf/dial/execution.py b/mava/systems/tf/dial/execution.py deleted file mode 100644 index 5c1480d7e..000000000 --- a/mava/systems/tf/dial/execution.py +++ /dev/null @@ -1,130 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DIAL system executor implementation.""" - -from typing import Any, Dict, Optional - -import sonnet as snt -import tensorflow as tf -from acme import types -from acme.tf import variable_utils as tf2_variable_utils - -from mava import adders -from mava.components.tf.modules.communication import BaseCommunicationModule -from mava.systems.tf.madqn.execution import MADQNRecurrentCommExecutor -from mava.systems.tf.madqn.training import MADQNTrainer - - -class DIALSwitchExecutor(MADQNRecurrentCommExecutor): - """DIAL executor. - An executor based on a recurrent communicating policy for each agent in the system. - Note: this executor is specific to switch game env. 
- """ - - def __init__( - self, - q_networks: Dict[str, snt.Module], - action_selectors: Dict[str, snt.Module], - communication_module: BaseCommunicationModule, - agent_net_keys: Dict[str, str], - adder: Optional[adders.ParallelAdder] = None, - variable_client: Optional[tf2_variable_utils.VariableClient] = None, - store_recurrent_state: bool = True, - trainer: MADQNTrainer = None, - fingerprint: bool = False, - evaluator: bool = False, - interval: Optional[dict] = None, - ): - """Initialise the system executor - - Args: - q_networks (Dict[str, snt.Module]): q-value networks for each agent in the - system. - action_selectors (Dict[str, Any]): policy action selector method, e.g. - epsilon greedy. - communication_module (BaseCommunicationModule): module for enabling - communication protocols between agents. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - adder (Optional[adders.ParallelAdder], optional): adder which sends data - to a replay buffer. Defaults to None. - variable_client (Optional[tf2_variable_utils.VariableClient], optional): - client to copy weights from the trainer. Defaults to None. - store_recurrent_state (bool, optional): boolean to store the recurrent - network hidden state. Defaults to True. - trainer (MADQNTrainer, optional): system trainer. Defaults to None. - fingerprint (bool, optional): whether to use fingerprint stabilisation to - stabilise experience replay. Defaults to False. - evaluator (bool, optional): whether the executor will be used for - evaluation. Defaults to False. - interval: interval that evaluations are run at. - """ - - # Store these for later use. - self._adder = adder - self._variable_client = variable_client - self._q_networks = q_networks - self._policy_networks = q_networks - self._communication_module = communication_module - self._action_selectors = action_selectors - self._store_recurrent_state = store_recurrent_state - self._trainer = trainer - self._agent_net_keys = agent_net_keys - - self._states: Dict[str, Any] = {} - self._messages: Dict[str, Any] = {} - - self._evaluator = evaluator - self._interval = interval - - @tf.function - def _policy( - self, - agent: str, - observation: types.NestedTensor, - state: types.NestedTensor, - message: types.NestedTensor, - legal_actions: types.NestedTensor, - ) -> types.NestedTensor: - """Agent specific policy function - - Args: - agent (str): agent id - observation (types.NestedTensor): observation tensor received from the - environment. - state (types.NestedTensor): Recurrent network state. - message (types.NestedTensor): received agent messsage. - legal_actions (types.NestedTensor): actions allowed to be taken at the - current observation. - - Returns: - types.NestedTensor: action, message and new recurrent hidden state - """ - - (action, m_values), new_state = super()._policy( - agent, - observation, - state, - message, - legal_actions, - ) - - # Mask message if obs[0] == 1. - # Note: this is specific to switch env - if observation[0] == 0: - m_values = tf.zeros_like(m_values) - - return (action, m_values), new_state diff --git a/mava/systems/tf/dial/networks.py b/mava/systems/tf/dial/networks.py deleted file mode 100644 index 4f24e1c71..000000000 --- a/mava/systems/tf/dial/networks.py +++ /dev/null @@ -1,69 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Dict, Mapping, Optional - -from acme import types - -from mava import specs as mava_specs -from mava.systems.tf.madqn.networks import ( - make_default_networks as make_default_networks_madqn, -) -from mava.utils.enums import ArchitectureType, Network - - -def make_default_networks( - environment_spec: mava_specs.MAEnvironmentSpec, - agent_net_keys: Dict[str, str], - message_size: int = 1, - archecture_type: ArchitectureType = ArchitectureType.recurrent, - network_type: Network = Network.coms_network, - fingerprints: bool = False, - seed: Optional[int] = None, -) -> Mapping[str, types.TensorTransformation]: - """Default networks for dial. - - Args: - environment_spec (mava_specs.MAEnvironmentSpec): description of the action and - observation spaces etc. for each agent in the system. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - message_size (int, optional): size of message passed. Defaults to 1. - archecture_type (ArchitectureType, optional): archecture used for - agent networks. Can be feedforward or recurrent. - Defaults to ArchitectureType.recurrent. - network_type (Network, optional): Agent network type. - Can be mlp, atari_dqn_network or coms_network. - Defaults to Network.coms_network. - fingerprints (bool, optional): whether to apply replay stabilisation using - policy fingerprints. Defaults to False. - seed (int, optional): random seed for network initialization. - - Returns: - Mapping[str, types.TensorTransformation]: returned agent networks. - """ - - assert ( - archecture_type == ArchitectureType.recurrent - ), "Dial currently only supports recurrent architectures." - - return make_default_networks_madqn( - environment_spec=environment_spec, - agent_net_keys=agent_net_keys, - archecture_type=archecture_type, - network_type=network_type, - fingerprints=fingerprints, - message_size=message_size, - seed=seed, - ) diff --git a/mava/systems/tf/dial/system.py b/mava/systems/tf/dial/system.py deleted file mode 100644 index 84ccbe83e..000000000 --- a/mava/systems/tf/dial/system.py +++ /dev/null @@ -1,340 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""DIAL system implementation.""" - -from typing import Any, Callable, Dict, Optional, Type, Union - -import acme -import dm_env -import reverb -import sonnet as snt -from acme import specs as acme_specs -from acme.utils import counting - -import mava -from mava import core -from mava import specs as mava_specs -from mava.components.tf.architectures import DecentralisedValueActor -from mava.components.tf.modules.communication import ( - BaseCommunicationModule, - BroadcastedCommunication, -) -from mava.components.tf.modules.stabilising import FingerPrintStabalisation -from mava.environment_loop import ParallelEnvironmentLoop -from mava.systems.tf import executors -from mava.systems.tf.dial import builder -from mava.systems.tf.dial.execution import DIALSwitchExecutor -from mava.systems.tf.dial.training import DIALSwitchTrainer -from mava.systems.tf.madqn import training -from mava.systems.tf.madqn.system import MADQN -from mava.types import EpsilonScheduler -from mava.utils.loggers import MavaLogger - - -class DIAL(MADQN): - """DIAL system.""" - - def __init__( - self, - environment_factory: Callable[[bool], dm_env.Environment], - network_factory: Callable[[acme_specs.BoundedArray], Dict[str, snt.Module]], - exploration_scheduler_fn: Union[ - EpsilonScheduler, - Dict[str, EpsilonScheduler], - Dict[str, Dict[str, EpsilonScheduler]], - ], - logger_factory: Callable[[str], MavaLogger] = None, - architecture: Type[DecentralisedValueActor] = DecentralisedValueActor, - trainer_fn: Type[training.MADQNRecurrentCommTrainer] = DIALSwitchTrainer, - communication_module: Type[BaseCommunicationModule] = BroadcastedCommunication, - executor_fn: Type[core.Executor] = DIALSwitchExecutor, - replay_stabilisation_fn: Optional[Type[FingerPrintStabalisation]] = None, - num_executors: int = 1, - num_caches: int = 0, - environment_spec: mava_specs.MAEnvironmentSpec = None, - shared_weights: bool = True, - agent_net_keys: Dict[str, str] = {}, - batch_size: int = 256, - prefetch_size: int = 4, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - samples_per_insert: Optional[float] = 4.0, - n_step: int = 5, - sequence_length: int = 20, - period: int = 20, - importance_sampling_exponent: Optional[float] = None, - max_priority_weight: float = 0.9, - max_gradient_norm: float = None, - discount: float = 1, - optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]] = snt.optimizers.Adam( - learning_rate=1e-4 - ), - target_update_period: int = 100, - executor_variable_update_period: int = 1000, - max_executor_steps: int = None, - checkpoint: bool = False, - checkpoint_subpath: str = "~/mava/", - checkpoint_minute_interval: int = 5, - logger_config: Dict = {}, - train_loop_fn: Callable = ParallelEnvironmentLoop, - eval_loop_fn: Callable = ParallelEnvironmentLoop, - train_loop_fn_kwargs: Dict = {}, - eval_loop_fn_kwargs: Dict = {}, - evaluator_interval: Optional[dict] = None, - learning_rate_scheduler_fn: Optional[Callable[[int], None]] = None, - seed: Optional[int] = None, - ): - """Initialise the system - - Args: - environment_factory (Callable[[bool], dm_env.Environment]): function to - instantiate an environment. - network_factory (Callable[[acme_specs.BoundedArray], - Dict[str, snt.Module]]): function to instantiate system networks. - logger_factory (Callable[[str], MavaLogger], optional): function to - instantiate a system logger. Defaults to None. - architecture (Type[DecentralisedValueActor], optional): system architecture, - e.g. decentralised or centralised. Defaults to DecentralisedValueActor. 
- trainer_fn (Type[ training.MADQNRecurrentCommTrainer ], optional): - training type associated with executor and architecture, e.g. - centralised training. Defaults to training.MADQNRecurrentCommTrainer. - communication_module (Type[BaseCommunicationModule], optional): module for - enabling communication protocols between agents. Defaults to - BroadcastedCommunication. - executor_fn (Type[core.Executor], optional): executor type, e.g. - feedforward or recurrent. Defaults to - execution.MADQNFeedForwardExecutor. - exploration_scheduler_fn (Type[ LinearExplorationScheduler ], optional): - function specifying a decaying scheduler for epsilon exploration. - See mava/systems/tf/madqn/system.py for details. - replay_stabilisation_fn (Optional[Type[FingerPrintStabalisation]], - optional): replay buffer stabilisaiton function, e.g. fingerprints. - Defaults to None. - num_executors (int, optional): number of executor processes to run in - parallel. Defaults to 1. - num_caches (int, optional): number of trainer node caches. Defaults to 0. - environment_spec (mava_specs.MAEnvironmentSpec, optional): description of - the action, observation spaces etc. for each agent in the system. - Defaults to None. - shared_weights (bool, optional): whether agents should share weights or not. - When agent_net_keys are provided the value of shared_weights is ignored. - Defaults to True. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - batch_size (int, optional): sample batch size for updates. Defaults to 256. - prefetch_size (int, optional): size to prefetch from replay. Defaults to 4. - min_replay_size (int, optional): minimum replay size before updating. - Defaults to 1000. - max_replay_size (int, optional): maximum replay size. Defaults to 1000000. - samples_per_insert (Optional[float], optional): number of samples to take - from replay for every insert that is made. Defaults to 4.0. - n_step (int, optional): number of steps to include prior to boostrapping. - Defaults to 5. - sequence_length (int, optional): recurrent sequence rollout length. - Defaults to 6. - period (int, optional): consecutive starting points for overlapping - rollouts across a sequence. Defaults to 20. - max_gradient_norm (float, optional): maximum allowed norm for gradients - before clipping is applied. Defaults to None. - discount (float, optional): discount factor to use for TD updates. - Defaults to 1. - optimizer (Union[snt.Optimizer, Dict[str, snt.Optimizer]], optional): - type of optimizer to use to update network parameters. Defaults to - snt.optimizers.Adam( learning_rate=1e-4 ). - target_update_period (int, optional): number of steps before target - networks are updated. Defaults to 100. - executor_variable_update_period (int, optional): number of steps before - updating executor variables from the variable source. Defaults to 1000. - max_executor_steps (int, optional): maximum number of steps and executor - can in an episode. Defaults to None. - checkpoint (bool, optional): whether to checkpoint models. Defaults to - False. - checkpoint_subpath (str, optional): subdirectory specifying where to store - checkpoints. Defaults to "~/mava/". - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. - logger_config (Dict, optional): additional configuration settings for the - logger factory. Defaults to {}. - train_loop_fn (Callable, optional): function to instantiate a train loop. - Defaults to ParallelEnvironmentLoop. 
- eval_loop_fn (Callable, optional): function to instantiate an evaluation - loop. Defaults to ParallelEnvironmentLoop. - train_loop_fn_kwargs (Dict, optional): possible keyword arguments to send - to the training loop. Defaults to {}. - eval_loop_fn_kwargs (Dict, optional): possible keyword arguments to send to - the evaluation loop. Defaults to {}. - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. - seed: seed for reproducible sampling (for epsilon greedy action selection). - evaluator_interval: An optional condition that is used to evaluate/test - system performance after [evaluator_interval] condition has been met. - If None, evaluation will happen at every timestep. - E.g. to evaluate a system after every 100 executor episodes, - evaluator_interval = {"executor_episodes": 100}. - """ - - super(DIAL, self).__init__( - environment_factory=environment_factory, - network_factory=network_factory, - logger_factory=logger_factory, - architecture=architecture, - trainer_fn=trainer_fn, - communication_module=communication_module, - executor_fn=executor_fn, - replay_stabilisation_fn=replay_stabilisation_fn, - num_executors=num_executors, - num_caches=num_caches, - environment_spec=environment_spec, - agent_net_keys=agent_net_keys, - shared_weights=shared_weights, - batch_size=batch_size, - prefetch_size=prefetch_size, - min_replay_size=min_replay_size, - max_replay_size=max_replay_size, - samples_per_insert=samples_per_insert, - n_step=n_step, - sequence_length=sequence_length, - period=period, - discount=discount, - optimizer=optimizer, - target_update_period=target_update_period, - executor_variable_update_period=executor_variable_update_period, - max_executor_steps=max_executor_steps, - checkpoint=checkpoint, - checkpoint_subpath=checkpoint_subpath, - checkpoint_minute_interval=checkpoint_minute_interval, - logger_config=logger_config, - train_loop_fn=train_loop_fn, - eval_loop_fn=eval_loop_fn, - train_loop_fn_kwargs=train_loop_fn_kwargs, - eval_loop_fn_kwargs=eval_loop_fn_kwargs, - exploration_scheduler_fn=exploration_scheduler_fn, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, - seed=seed, - ) - - if issubclass(executor_fn, executors.RecurrentExecutor): - extra_specs = self._get_extra_specs() - else: - extra_specs = {} - self._checkpoint_minute_interval = checkpoint_minute_interval - self._builder = builder.DIALBuilder( - builder.DIALConfig( - environment_spec=self._environment_spec, - agent_net_keys=self._agent_net_keys, - discount=discount, - batch_size=batch_size, - prefetch_size=prefetch_size, - target_update_period=target_update_period, - executor_variable_update_period=executor_variable_update_period, - min_replay_size=min_replay_size, - max_replay_size=max_replay_size, - samples_per_insert=samples_per_insert, - n_step=n_step, - sequence_length=sequence_length, - period=period, - max_gradient_norm=max_gradient_norm, - checkpoint=checkpoint, - optimizer=optimizer, - checkpoint_subpath=checkpoint_subpath, - checkpoint_minute_interval=checkpoint_minute_interval, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, - importance_sampling_exponent=importance_sampling_exponent, - max_priority_weight=max_priority_weight, - evaluator_interval=evaluator_interval, - ), - trainer_fn=trainer_fn, - executor_fn=executor_fn, - extra_specs=extra_specs, - replay_stabilisation_fn=replay_stabilisation_fn, - ) - - def replay(self) -> Any: - """Replay data storage. 
- - Returns: - Any: replay data table built according the environment specification. - """ - return self._builder.make_replay_tables(self._environment_spec) - - def executor( # type: ignore[override] - self, - executor_id: str, - replay: reverb.Client, - variable_source: acme.VariableSource, - counter: counting.Counter, - trainer: Optional[training.MADQNRecurrentCommTrainer] = None, - ) -> mava.ParallelEnvironmentLoop: - """System executor - - Args: - executor_id (str): id to identify the executor process for logging purposes. - replay (reverb.Client): replay data table to push data to. - variable_source (acme.VariableSource): variable server for updating - network variables. - counter (counting.Counter): step counter object. - trainer (Optional[training.MADQNRecurrentCommTrainer], optional): - system trainer. Defaults to None. - - Returns: - mava.ParallelEnvironmentLoop: environment-executor loop instance. - """ - - return super().executor( - executor_id=executor_id, - replay=replay, - variable_source=variable_source, - counter=counter, - trainer=trainer, - ) - - def evaluator( # type: ignore[override] - self, - variable_source: acme.VariableSource, - counter: counting.Counter, - trainer: training.MADQNRecurrentCommTrainer, - ) -> Any: - """System evaluator (an executor process not connected to a dataset) - - Args: - variable_source (acme.VariableSource): variable server for updating - network variables. - counter (counting.Counter): step counter object. - trainer (Optional[training.MADQNRecurrentCommTrainer], optional): - system trainer. Defaults to None. - - Returns: - Any: environment-executor evaluation loop instance for evaluating the - performance of a system. - """ - - return super().evaluator( - variable_source=variable_source, - counter=counter, - trainer=trainer, - ) - - def build(self, name: str = "dial") -> Any: - """Build the distributed system as a graph program. - - Args: - name (str, optional): system name. Defaults to "dial". - - Returns: - Any: graph program for distributed system training. - """ - - return super().build(name=name) diff --git a/mava/systems/tf/dial/training.py b/mava/systems/tf/dial/training.py deleted file mode 100644 index 393e12848..000000000 --- a/mava/systems/tf/dial/training.py +++ /dev/null @@ -1,233 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -"""DIAL system trainer implementation.""" - -from typing import Any, Callable, Dict, List, Optional, Union - -import sonnet as snt -import tensorflow as tf -import tree -import trfl -from acme.tf import utils as tf2_utils -from acme.types import NestedArray -from acme.utils import counting, loggers - -from mava.components.tf.modules.communication import BaseCommunicationModule -from mava.systems.tf.madqn.training import MADQNRecurrentCommTrainer -from mava.utils import training_utils as train_utils - -train_utils.set_growing_gpu_memory() - - -class DIALSwitchTrainer(MADQNRecurrentCommTrainer): - """Recurrent Comm DIAL Switch trainer. - This is the trainer component of a DIAL system. IE it takes a dataset as input - and implements update functionality to learn from this dataset. - Note: this trainer is specific to switch game env. - """ - - def __init__( - self, - agents: List[str], - agent_types: List[str], - q_networks: Dict[str, snt.Module], - target_q_networks: Dict[str, snt.Module], - target_update_period: int, - dataset: tf.data.Dataset, - optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], - discount: float, - agent_net_keys: Dict[str, str], - checkpoint_minute_interval: int, - communication_module: BaseCommunicationModule, - max_gradient_norm: float = None, - fingerprint: bool = False, - counter: counting.Counter = None, - logger: loggers.Logger = None, - checkpoint: bool = True, - checkpoint_subpath: str = "~/mava/", - learning_rate_scheduler_fn: Optional[Callable[[int], None]] = None, - ): - """Initialise DIAL trainer for switch game - - Args: - agents (List[str]): agent ids, e.g. "agent_0". - agent_types (List[str]): agent types, e.g. "speaker" or "listener". - q_networks (Dict[str, snt.Module]): q-value networks. - target_q_networks (Dict[str, snt.Module]): target q-value networks. - target_update_period (int): number of steps before updating target networks. - dataset (tf.data.Dataset): training dataset. - optimizer (Union[snt.Optimizer, Dict[str, snt.Optimizer]]): type of - optimizer for updating the parameters of the networks. - discount (float): discount factor for TD updates. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. - communication_module (BaseCommunicationModule): module for communication - between agents. - max_gradient_norm (float, optional): maximum allowed norm for gradients - before clipping is applied. Defaults to None. - fingerprint (bool, optional): whether to apply replay stabilisation using - policy fingerprints. Defaults to False. - counter (counting.Counter, optional): step counter object. Defaults to None. - logger (loggers.Logger, optional): logger object for logging trainer - statistics. Defaults to None. - checkpoint (bool, optional): whether to checkpoint networks. Defaults to - True. - checkpoint_subpath (str, optional): subdirectory for storing checkpoints. - Defaults to "~/mava/". - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. 
- """ - - super().__init__( - agents=agents, - agent_types=agent_types, - q_networks=q_networks, - target_q_networks=target_q_networks, - target_update_period=target_update_period, - dataset=dataset, - optimizer=optimizer, - discount=discount, - agent_net_keys=agent_net_keys, - checkpoint_minute_interval=checkpoint_minute_interval, - max_gradient_norm=max_gradient_norm, - fingerprint=fingerprint, - counter=counter, - logger=logger, - checkpoint=checkpoint, - checkpoint_subpath=checkpoint_subpath, - communication_module=communication_module, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, - ) - - def _forward(self, inputs: Any) -> None: - data = tree.map_structure( - lambda v: tf.expand_dims(v, axis=0) if len(v.shape) <= 1 else v, inputs.data - ) - data = tf2_utils.batch_to_sequence(data) - - observations, actions, rewards, discounts, _, _ = ( - data.observations, - data.actions, - data.rewards, - data.discounts, - data.start_of_episode, - data.extras, - ) - - # Using extra directly from inputs due to shape. - core_state = tree.map_structure( - lambda s: s[:, 0, :], inputs.data.extras["core_states"] - ) - core_message = tree.map_structure( - lambda s: s[:, 0, :], inputs.data.extras["core_messages"] - ) - T = actions[self._agents[0]].shape[0] - - # Use fact that end of episode always has the reward to - # find episode lengths. This is used to mask loss. - ep_end = tf.argmax(tf.math.abs(rewards[self._agents[0]]), axis=0) - - with tf.GradientTape(persistent=True) as tape: - q_network_losses: Dict[str, NestedArray] = { - agent: {"policy_loss": tf.zeros(())} for agent in self._agents - } - - state = {agent: core_state[agent][0] for agent in self._agents} - target_state = {agent: core_state[agent][0] for agent in self._agents} - - message = {agent: core_message[agent][0] for agent in self._agents} - target_message = {agent: core_message[agent][0] for agent in self._agents} - - # _target_q_networks must be 1 step ahead - target_channel = self._communication_module.process_messages(target_message) - for agent in self._agents: - agent_key = self._agent_net_keys[agent] - (q_targ, m), s = self._target_q_networks[agent_key]( - observations[agent].observation[0], - target_state[agent], - target_channel[agent], - ) - target_state[agent] = s - target_message[agent] = m - - for t in range(1, T, 1): - channel = self._communication_module.process_messages(message) - target_channel = self._communication_module.process_messages( - target_message - ) - - for agent in self._agents: - agent_key = self._agent_net_keys[agent] - - # Cast the additional discount - # to match the environment discount dtype. 
- - discount = tf.cast(self._discount, dtype=discounts[agent][0].dtype) - - (q_targ, m), s = self._target_q_networks[agent_key]( - observations[agent].observation[t], - target_state[agent], - target_channel[agent], - ) - - target_state[agent] = s - target_message[agent] = tf.math.multiply( - m, observations[agent].observation[t][:, :1] - ) - - (q, m), s = self._q_networks[agent_key]( - observations[agent].observation[t - 1], - state[agent], - channel[agent], - ) - - state[agent] = s - message[agent] = tf.math.multiply( - m, observations[agent].observation[t - 1][:, :1] - ) - - # Mask target - q_targ = tf.concat( - [ - [q_targ[i]] - if t <= ep_end[i] - else [tf.zeros_like(q_targ[i])] - for i in range(q_targ.shape[0]) - ], - axis=0, - ) - - loss, _ = trfl.qlearning( - q, - actions[agent][t - 1], - rewards[agent][t - 1], - discount * discounts[agent][t], - q_targ, - ) - - # Index loss (mask ended episodes) - if not tf.reduce_any(t - 1 <= ep_end): - continue - - loss = tf.reduce_mean(loss[t - 1 <= ep_end]) - # loss = tf.reduce_mean(loss) - q_network_losses[agent]["policy_loss"] += loss - - self._q_network_losses = q_network_losses - self.tape = tape diff --git a/mava/systems/tf/madqn/builder.py b/mava/systems/tf/madqn/builder.py index dce3b3f33..251d7ab6f 100644 --- a/mava/systems/tf/madqn/builder.py +++ b/mava/systems/tf/madqn/builder.py @@ -29,18 +29,18 @@ from mava import adders, core, specs, types from mava.adders import reverb as reverb_adders -from mava.systems.tf import executors, variable_utils -from mava.systems.tf.madqn import training -from mava.systems.tf.madqn.execution import MADQNFeedForwardExecutor -from mava.systems.tf.variable_sources import VariableSource as MavaVariableSource -from mava.utils.sort_utils import sort_str_num -from mava.wrappers import NetworkStatisticsActorCritic, ScaledDetailedTrainerStatistics -from mava.utils.builder_utils import initialize_epsilon_schedulers from mava.components.tf.modules.exploration.exploration_scheduling import ( BaseExplorationScheduler, BaseExplorationTimestepScheduler, ConstantScheduler, ) +from mava.systems.tf import executors, variable_utils +from mava.systems.tf.madqn import training +from mava.systems.tf.madqn.execution import MADQNFeedForwardExecutor +from mava.systems.tf.variable_sources import VariableSource as MavaVariableSource +from mava.utils.builder_utils import initialize_epsilon_schedulers +from mava.utils.sort_utils import sort_str_num +from mava.wrappers import ScaledDetailedTrainerStatistics BoundedArray = dm_specs.BoundedArray DiscreteArray = dm_specs.DiscreteArray @@ -53,8 +53,7 @@ class MADQNConfig: Args: environment_spec: description of the action and observation spaces etc. for each agent in the system. - policy_optimizer: optimizer(s) for updating policy networks. - critic_optimizer: optimizer for updating critic networks. + optimizer: optimizer(s) for updating value networks. num_executors: number of parallel executors to use. agent_net_keys: specifies what network each agent uses. trainer_networks: networks each trainer trains on. 
@@ -169,7 +168,6 @@ def __init__( self._trainer_fn = trainer_fn self._executor_fn = executor_fn - def covert_specs(self, spec: Dict[str, Any], num_networks: int) -> Dict[str, Any]: if type(spec) is not dict: return spec @@ -319,8 +317,6 @@ def make_adder( for table_key in self._config.table_network_config.keys() } - print() - # Select adder if issubclass(self._executor_fn, executors.FeedForwardExecutor): adder = reverb_adders.ParallelNStepTransitionAdder( @@ -343,7 +339,6 @@ def make_adder( else: raise NotImplementedError("Unknown executor type: ", self._executor_fn) - print("################3", adder) return adder def create_counter_variables( @@ -519,7 +514,6 @@ def make_trainer( target_averaging = self._config.target_averaging target_update_rate = self._config.target_update_rate - print("4") # Create variable client variables = {} set_keys = [] @@ -536,7 +530,6 @@ def make_trainer( else: get_keys.append(f"{net_key}_{net_type_key}") - print("5") variables = self.create_counter_variables(variables) count_names = [ @@ -550,7 +543,6 @@ def make_trainer( get_keys.extend(count_names) counts = {name: variables[name] for name in count_names} - print("6") variable_client = variable_utils.VariableClient( client=variable_source, variables=variables, @@ -559,11 +551,9 @@ def make_trainer( update_period=10, ) - print("7") # Get all the initial variables variable_client.get_all_and_wait() - print("8") # Convert network keys for the trainer. trainer_agents = self._agents[: len(trainer_table_entry)] trainer_agent_net_keys = { diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index 4ed085149..8bbec9183 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -30,18 +30,18 @@ from dm_env import specs from mava import adders -from mava import core -from mava.systems.tf import executors -from mava.utils.sort_utils import sample_new_agent_keys, sort_str_num from mava.components.tf.modules.exploration.exploration_scheduling import ( BaseExplorationTimestepScheduler, ) +from mava.systems.tf import executors +from mava.utils.sort_utils import sample_new_agent_keys, sort_str_num Array = specs.Array BoundedArray = specs.BoundedArray DiscreteArray = specs.DiscreteArray tfd = tfp.distributions + class DQNExecutor: def __init__(self, action_selectors: Dict): self._action_selectors = action_selectors @@ -143,15 +143,16 @@ def __init__( self._observation_networks = observation_networks self._action_selectors = action_selectors self._value_networks = value_networks - self._agent_net_keys=agent_net_keys - self._adder=adder - self._variable_client=variable_client + self._agent_net_keys = agent_net_keys + self._adder = adder + self._variable_client = variable_client @tf.function def _policy( - self, agent: str, + self, + agent: str, observation: types.NestedTensor, - legal_actions: types.NestedTensor + legal_actions: types.NestedTensor, ) -> types.NestedTensor: """Agent specific policy function @@ -206,10 +207,8 @@ def select_action( return action - def select_actions( - self, observations: Dict[str, types.NestedArray] - ) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]: - """select the actions for all agents in the system + def select_actions(self, observations: Dict[str, types.NestedArray]) -> Dict: + """Select the actions for all agents in the system Args: observations: agent observations from the @@ -249,8 +248,7 @@ def observe_first( ) extras["network_int_keys"] = self._network_int_keys_extras - - + 
self._adder.add_first(timestep, extras) def observe( @@ -290,7 +288,7 @@ class MADQNRecurrentExecutor(executors.RecurrentExecutor, DQNExecutor): def __init__( self, - observation_networks :Dict[str, snt.Module], + observation_networks: Dict[str, snt.Module], action_selectors: Dict[str, snt.Module], value_networks: Dict[str, snt.Module], agent_specs: Dict[str, EnvironmentSpec], @@ -335,14 +333,13 @@ def __init__( self._evaluator = evaluator self._interval = interval self._value_networks = value_networks - self._agent_net_keys=agent_net_keys - self._adder=adder - self._variable_client=variable_client - self._store_recurrent_state=store_recurrent_state + self._agent_net_keys = agent_net_keys + self._adder = adder + self._variable_client = variable_client + self._store_recurrent_state = store_recurrent_state self._observation_networks = observation_networks self._action_selectors = action_selectors self._states: Dict[str, Any] = {} - @tf.function def _policy( @@ -351,7 +348,7 @@ def _policy( observation: types.NestedTensor, legal_actions: types.NestedTensor, state: types.NestedTensor, - ) -> Tuple[types.NestedTensor, types.NestedTensor, types.NestedTensor]: + ) -> Tuple: """Agent specific policy function Args: agent: agent id @@ -379,7 +376,7 @@ def _policy( # Pass action values through action selector action = self._action_selectors[agent](action_values, batched_legal_actions) - + return action, new_state def select_action( @@ -396,7 +393,10 @@ def select_action( # Step the recurrent policy forward given the current observation and state. action, new_state = self._policy( - agent, observation.observation, observation.legal_actions, self._states[agent] + agent, + observation.observation, + observation.legal_actions, + self._states[agent], ) # Bookkeeping of recurrent states for the observe method. @@ -407,9 +407,7 @@ def select_action( return action - def select_actions( - self, observations: Dict[str, types.NestedArray] - ) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]: + def select_actions(self, observations: Dict[str, types.NestedArray]) -> Any: """select the actions for all agents in the system Args: observations: agent observations from the @@ -460,10 +458,7 @@ def observe_first( } extras.update( - { - "core_states": numpy_states, - "zero_padding_mask": np.array(1) - } + {"core_states": numpy_states, "zero_padding_mask": np.array(1)} ) extras["network_int_keys"] = self._network_int_keys_extras @@ -493,12 +488,9 @@ def observe( } next_extras.update( - { - "core_states": numpy_states, - "zero_padding_mask": np.array(1) - } + {"core_states": numpy_states, "zero_padding_mask": np.array(1)} ) - + next_extras["network_int_keys"] = self._network_int_keys_extras self._adder.add(actions, next_timestep, next_extras) # type: ignore diff --git a/mava/systems/tf/madqn/networks.py b/mava/systems/tf/madqn/networks.py index f2922c397..bf2fca57e 100644 --- a/mava/systems/tf/madqn/networks.py +++ b/mava/systems/tf/madqn/networks.py @@ -14,17 +14,17 @@ # limitations under the License. 
from typing import Dict, Mapping, Optional, Sequence, Union -import numpy as np import sonnet as snt import tensorflow as tf from acme import types from acme.tf import utils as tf2_utils +from acme.tf.networks import AtariTorso from dm_env import specs from mava import specs as mava_specs from mava.components.tf import networks -from mava.utils.enums import ArchitectureType from mava.components.tf.networks.epsilon_greedy import EpsilonGreedy +from mava.utils.enums import ArchitectureType Array = specs.Array BoundedArray = specs.BoundedArray @@ -36,6 +36,7 @@ def make_default_networks( agent_net_keys: Dict[str, str], value_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = None, architecture_type: ArchitectureType = ArchitectureType.feedforward, + atari_torso_observation_network: bool = False, seed: Optional[int] = None, ) -> Mapping[str, types.TensorTransformation]: """Default networks for maddpg. @@ -85,7 +86,6 @@ def make_default_networks( # Create agent_type specs specs = {agent_net_keys[key]: specs[key] for key in specs.keys()} - if isinstance(value_networks_layer_sizes, Sequence): value_networks_layer_sizes = { key: value_networks_layer_sizes for key in specs.keys() @@ -102,7 +102,10 @@ def make_default_networks( num_actions = spec.actions.num_values # An optional network to process observations - observation_network = tf2_utils.to_sonnet_module(tf.identity) + if not atari_torso_observation_network: + observation_network = tf2_utils.to_sonnet_module(tf.identity) + else: + observation_network = AtariTorso() # Create the policy network. if architecture_type == ArchitectureType.feedforward: diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index 4f0d38ecf..3042f76a7 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -16,29 +16,25 @@ """MADQN system implementation.""" import functools -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, Mapping +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, Union -import numpy as np import acme import dm_env import launchpad as lp +import numpy as np import reverb import sonnet as snt from acme import specs as acme_specs from acme.tf import utils as tf2_utils -from acme.utils import loggers from dm_env import specs import mava from mava import core from mava import specs as mava_specs -from mava.components.tf.architectures import ( - DecentralisedValueActor, -) +from mava.components.tf.architectures import DecentralisedValueActor from mava.components.tf.modules.exploration.exploration_scheduling import ( ConstantScheduler, ) -from mava.types import EpsilonScheduler from mava.environment_loop import ParallelEnvironmentLoop from mava.systems.tf import executors from mava.systems.tf.madqn import builder, training @@ -47,6 +43,7 @@ sample_new_agent_keys, ) from mava.systems.tf.variable_sources import VariableSource as MavaVariableSource +from mava.types import EpsilonScheduler from mava.utils import enums from mava.utils.loggers import MavaLogger, logger_utils from mava.utils.sort_utils import sort_str_num @@ -69,9 +66,7 @@ def __init__( # noqa Mapping[str, Mapping[str, EpsilonScheduler]], ], logger_factory: Callable[[str], MavaLogger] = None, - architecture: Type[ - DecentralisedValueActor - ] = DecentralisedValueActor, + architecture: Type[DecentralisedValueActor] = DecentralisedValueActor, trainer_fn: Type[training.MADQNTrainer] = training.MADQNTrainer, executor_fn: Type[core.Executor] = MADQNFeedForwardExecutor, num_executors: 
int = 1, @@ -91,14 +86,14 @@ def __init__( # noqa target_update_rate: Optional[float] = None, executor_variable_update_period: int = 1000, min_replay_size: int = 1000, - max_replay_size: int = 1000000, + max_replay_size: int = 100000, samples_per_insert: Optional[float] = 32.0, - optimizer: Union[ - snt.Optimizer, Dict[str, snt.Optimizer] - ] = snt.optimizers.Adam(learning_rate=1e-4), + optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]] = snt.optimizers.Adam( + learning_rate=1e-4 + ), n_step: int = 5, sequence_length: int = 20, - period: int = 20, + period: int = 10, max_gradient_norm: float = None, checkpoint: bool = True, checkpoint_subpath: str = "~/mava/", @@ -113,10 +108,13 @@ def __init__( # noqa learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): """Initialise the system + Args: environment_factory: function to instantiate an environment. network_factory: function to instantiate system networks. + exploration_scheduler_fn: function to schedule + exploration. e.g. epsilon greedy logger_factory: function to instantiate a system logger. architecture: @@ -126,9 +124,7 @@ def __init__( # noqa executor_fn: executor type, e.g. feedforward or recurrent. num_executors: number of executor processes to run in - parallel.. - environment_spec: description of - the action, observation spaces etc. for each agent in the system. + parallel. trainer_networks: networks each trainer trains on. network_sampling_setup: List of networks that are randomly @@ -147,6 +143,8 @@ def __init__( # noqa shared_weights: whether agents should share weights or not. When network_sampling_setup are provided the value of shared_weights is ignored. + environment_spec: description of + the action, observation spaces etc. for each agent in the system. discount: discount factor to use for TD updates. batch_size: sample batch size for updates. prefetch_size: size to prefetch from replay. @@ -162,9 +160,7 @@ def __init__( # noqa max_replay_size: maximum replay size. samples_per_insert: number of samples to take from replay for every insert that is made. - policy_optimizer: optimizer(s) for updating policy networks. - critic_optimizer: optimizer for updating critic - networks. + optimizers: optimizer(s) for updating value networks. n_step: number of steps to include prior to boostrapping. sequence_length: recurrent sequence rollout length. period: Consecutive starting points for overlapping @@ -172,10 +168,10 @@ def __init__( # noqa max_gradient_norm: maximum allowed norm for gradients before clipping is applied. checkpoint: whether to checkpoint models. - checkpoint_minute_interval: The number of minutes to wait between - checkpoints. checkpoint_subpath: subdirectory specifying where to store checkpoints. + checkpoint_minute_interval: The number of minutes to wait between + checkpoints. logger_config: additional configuration settings for the logger factory. train_loop_fn: function to instantiate a train loop. @@ -191,6 +187,12 @@ def __init__( # noqa values for trainer_steps, trainer_walltime, evaluator_steps, evaluator_episodes, executor_episodes or executor_steps. E.g. termination_condition = {'trainer_steps': 100000}. + evaluator_interval: An optional condition that is used to + evaluate/test system performance after [evaluator_interval] + condition has been met. If None, evaluation will + happen at every timestep. + E.g. to evaluate a system after every 100 executor episodes, + evaluator_interval = {"executor_episodes": 100}. 
learning_rate_scheduler_fn: dict with two functions/classes (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate, @@ -198,12 +200,6 @@ def __init__( # noqa See examples/debugging/simple_spread/feedforward/decentralised/run_maddpg_lr_schedule.py for an example. - evaluator_interval: An optional condition that is used to - evaluate/test system performance after [evaluator_interval] - condition has been met. If None, evaluation will - happen at every timestep. - E.g. to evaluate a system after every 100 executor episodes, - evaluator_interval = {"executor_episodes": 100}. """ if not environment_spec: @@ -268,7 +264,6 @@ def __init__( # noqa self._network_sampling_setup, # type: ignore ) - # Check that the environment and agent_net_keys has the same amount of agents sample_length = len(self._network_sampling_setup[0]) # type: ignore assert len(environment_spec.get_agent_ids()) == len(self._agent_net_keys.keys()) @@ -386,7 +381,6 @@ def __init__( # noqa net_spec = {"network_keys": {agent: int_spec for agent in agents}} extra_specs.update(net_spec) - self._builder = builder.MADQNBuilder( builder.MADQNConfig( environment_spec=environment_spec, @@ -425,9 +419,10 @@ def __init__( # noqa ) def _get_extra_specs(self) -> Any: - """helper to establish specs for extra information + """Helper to establish specs for extra information + Returns: - dictionary containing extra specs + Dictionary containing extra specs """ agents = self._environment_spec.get_agent_ids() @@ -447,8 +442,10 @@ def _get_extra_specs(self) -> Any: def replay(self) -> Any: """Step counter + Args: checkpoint: whether to checkpoint the counter. + Returns: step counter object. """ @@ -464,7 +461,6 @@ def create_system( agent_net_keys=self._agent_net_keys, ) - # architecture args architecture_config = { "environment_spec": self._environment_spec, @@ -542,13 +538,14 @@ def executor( def evaluator( self, variable_source: acme.VariableSource, - logger: loggers.Logger = None, ) -> Any: """System evaluator (an executor process not connected to a dataset) + Args: variable_source: variable server for updating network variables. logger: logger object. + Returns: environment-executor evaluation loop instance for evaluating the performance of a system. @@ -598,11 +595,13 @@ def trainer( variable_source: MavaVariableSource, ) -> mava.core.Trainer: """System trainer + Args: trainer_id: Id of the trainer being created. replay: replay data table to pull data from. variable_source: variable server for updating network variables. + Returns: system trainer. """ @@ -628,10 +627,12 @@ def trainer( variable_source=variable_source, ) - def build(self, name: str = "maddpg") -> Any: + def build(self, name: str = "madqn") -> Any: """Build the distributed system as a graph program. + Args: name: system name. + Returns: graph program for distributed system training. 
""" diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index ea8b6b38c..5c2f9a831 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -17,7 +17,6 @@ """MADQN trainer implementation.""" import copy -import time from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np @@ -26,15 +25,11 @@ import tensorflow as tf import tree import trfl -from acme.tf import losses from acme.tf import utils as tf2_utils from acme.utils import loggers import mava from mava import types as mava_types -from mava.adders.reverb.base import Trajectory -from mava.components.tf.losses.sequence import recurrent_n_step_critic_loss -from mava.systems.tf.madqn.execution import MADQNFeedForwardExecutor from mava.systems.tf.variable_utils import VariableClient from mava.utils import training_utils as train_utils from mava.utils.sort_utils import sort_str_num @@ -69,20 +64,14 @@ def __init__( logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Initialise MADDPG trainer + """Initialise MADQN trainer Args: agents: agent ids, e.g. "agent_0". agent_types: agent types, e.g. "speaker" or "listener". - policy_networks: policy networks for each agent in + value_networks: value networks for each agents in the system. - critic_networks: critic network(s), shared or for - each agent in the system. - target_policy_networks: target policy networks. - target_critic_networks: target critic networks. - policy_optimizer: - optimizer(s) for updating policy networks. - critic_optimizer: - optimizer for updating critic networks. + target_value_networks: target value networks. + optimizer: optimizer(s) for updating policy networks. discount: discount factor for TD updates. target_averaging: whether to use polyak averaging for target network updates. @@ -182,8 +171,10 @@ def __init__( self._timestamp: Optional[float] = None def _update_target_networks(self) -> None: + """Update the target networks using either target averaging or by directy copying the weights of the online networks every few steps.""" + for key in self.unique_net_keys: # Update target network. online_variables = ( @@ -208,19 +199,22 @@ def _update_target_networks(self) -> None: self._num_steps.assign_add(1) def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray]]: + """Depricated""" + pass def _transform_observations( self, obs: Dict[str, mava_types.OLT], next_obs: Dict[str, mava_types.OLT] ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: - """Transform the observatations using the observation networks of each agent." + + """Transform the observations using the observation networks of each agent." Args: obs: observations at timestep t-1 next_obs: observations at timestep t Returns: - Transformed observatations + Transformed observations """ o_tm1 = {} o_t = {} @@ -249,22 +243,15 @@ def _step( # Draw a batch of data from replay. sample: reverb.ReplaySample = next(self._iterator) - losses = self._forward_backward(sample) - - # Log losses per agent - return train_utils.map_losses_per_agent_value(losses) - - @tf.function - def _forward_backward(self, inputs: Any) -> Dict[str, Dict[str, Any]]: - - self._forward(inputs) + self._forward(sample) self._backward() # Update the target networks self._update_target_networks() - return self.value_losses + # Log losses per agent + return train_utils.map_losses_per_agent_value(self.value_losses) # Forward pass that calculates loss. 
def _forward(self, inputs: reverb.ReplaySample) -> None: @@ -285,7 +272,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # o_t = dictionary of next observations or next observation sequences # e_t [Optional] = extra data for timestep t that the agents persist in replay. trans = mava_types.Transition(*inputs.data) - o_tm1, o_t, a_tm1, r_t, d_t, e_tm1, e_t = ( + o_tm1, o_t, a_tm1, r_t, d_t, _, _ = ( trans.observations, trans.next_observations, trans.actions, @@ -308,8 +295,10 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: q_t_value = self._target_value_networks[agent_key](o_t_trans[agent]) q_t_selector = self._value_networks[agent_key](o_t_trans[agent]) - # TODO Legal action masking - # q_t_selector = tf.where(o_t[agent].legal_actions, q_t_selector, -999999999) + # Legal action masking + q_t_selector = tf.where( + tf.cast(o_t[agent].legal_actions, "bool"), q_t_selector, -999999999 + ) # pcont discount = tf.cast(self._discount, dtype=d_t[agent].dtype) @@ -360,7 +349,7 @@ def _backward(self) -> None: train_utils.safe_del(self, "tape") def step(self) -> None: - """trainer step to update the parameters of the agents in the system""" + """Trainer step to update the parameters of the agents in the system""" raise NotImplementedError("A trainer statistics wrapper should overwrite this.") @@ -379,6 +368,7 @@ def after_trainer_step(self) -> None: def _decay_lr(self, trainer_step: int) -> None: """Decay lr. + Args: trainer_step : trainer step time t. """ @@ -390,8 +380,8 @@ def _decay_lr(self, trainer_step: int) -> None: class MADQNRecurrentTrainer: """Recurrent MADQN trainer. - This is the trainer component of a MADQN system. IE it takes a dataset as input - and implements update functionality to learn from this dataset. + This is the trainer component of a recurrent MADQN system. IE it takes a dataset + as input and implements update functionality to learn from this dataset. """ def __init__( @@ -415,20 +405,15 @@ def __init__( logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Initialise Recurrent MADDPG trainer + """Initialise Recurrent MADQN trainer + Args: agents: agent ids, e.g. "agent_0". agent_types: agent types, e.g. "speaker" or "listener". - policy_networks: policy networks for each agent in + value_networks: value networks for each agent in the system. - critic_networks: critic network(s), shared or for - each agent in the system. - target_policy_networks: target policy networks. - target_critic_networks: target critic networks. - policy_optimizer: - optimizer(s) for updating policy networks. - critic_optimizer: - optimizer for updating critic networks. + target_value_networks: target value networks. + optimizer: optimizer(s) for updating value networks. discount: discount factor for TD updates. target_averaging: whether to use polyak averaging for target network updates. 
@@ -527,20 +512,21 @@ def __init__( self._timestamp: Optional[float] = None def step(self) -> None: - """trainer step to update the parameters of the agents in the system""" + """Trainer step to update the parameters of the agents in the system""" raise NotImplementedError("A trainer statistics wrapper should overwrite this.") def _transform_observations( self, observations: Dict[str, mava_types.OLT] ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: - """apply the observation networks to the raw observations from the dataset + """Apply the observation networks to the raw observations from the dataset + Args: - obs: raw agent observations - next_obs: raw next observations + observations: raw agent observations + Returns: - transformed - observations (features) + obs_trans: transformed agent observation + obs_target_trans: transformed target network observations """ # Note (dries): We are assuming that only the policy network @@ -573,8 +559,11 @@ def _transform_observations( return obs_trans, obs_target_trans def _update_target_networks(self) -> None: - """Update the target networks using either target averaging or - by directy copying the weights of the online networks every few steps.""" + """Update the target networks. + + Using either target averaging or + by directy copying the weights of the online networks every few steps. + """ for key in self.unique_net_keys: # Update target network. online_variables = ( @@ -630,7 +619,8 @@ def _step( # Forward pass that calculates loss. def _forward(self, inputs: reverb.ReplaySample) -> None: - """Trainer forward pass + """Trainer forward pass. + Args: inputs: input data from the data table (transitions) """ @@ -669,7 +659,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # is recurrent and not the observation network. obs_trans, target_obs_trans = self._transform_observations(observations) - for agent in self._agents: + for agent in self._trainer_agent_list: agent_key = self._agent_net_keys[agent] # Double Q-learning @@ -694,10 +684,6 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: -999999999, ) - # Cast the additional discount to match - # the environment discount dtype. - discount = tf.cast(self._discount, dtype=discounts[agent].dtype) - # Flatten out time and batch dim q_tm1, _ = train_utils.combine_dim(q_tm1) q_t_selector, _ = train_utils.combine_dim(q_t_selector) @@ -712,15 +698,23 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: discounts[agent][:-1] # Chop off last timestep ) + # Cast the additional discount to match + # the environment discount dtype. 
+ discount = tf.cast(self._discount, dtype=discounts[agent].dtype) + # Value loss value_loss, _ = trfl.double_qlearning( q_tm1, a_tm1, r_t, discount * d_t, q_t_value, q_t_selector ) # Zero-padding mask - zero_padding_mask = tf.cast(extras["zero_padding_mask"], dtype=value_loss.dtype)[:-1] + zero_padding_mask, _ = train_utils.combine_dim( + tf.cast(extras["zero_padding_mask"], dtype=value_loss.dtype)[:-1] + ) masked_loss = value_loss * zero_padding_mask - self.value_losses[agent] = tf.reduce_sum(masked_loss) / tf.reduce_sum(zero_padding_mask) + self.value_losses[agent] = tf.reduce_sum(masked_loss) / tf.reduce_sum( + zero_padding_mask + ) self.tape = tape @@ -755,11 +749,6 @@ def _backward(self) -> None: train_utils.safe_del(self, "tape") - def step(self) -> None: - """trainer step to update the parameters of the agents in the system""" - - raise NotImplementedError("A trainer statistics wrapper should overwrite this.") - def after_trainer_step(self) -> None: """Optionally decay lr after every training step.""" if self._learning_rate_scheduler_fn: @@ -775,6 +764,7 @@ def after_trainer_step(self) -> None: def _decay_lr(self, trainer_step: int) -> None: """Decay lr. + Args: trainer_step : trainer step time t. """ diff --git a/mava/systems/tf/value_decomposition/networks.py b/mava/systems/tf/value_decomposition/networks.py index fb661d5fd..d2f11602f 100644 --- a/mava/systems/tf/value_decomposition/networks.py +++ b/mava/systems/tf/value_decomposition/networks.py @@ -14,7 +14,6 @@ # limitations under the License. from typing import Dict, Mapping, Optional, Sequence, Union -import numpy as np import sonnet as snt import tensorflow as tf from acme import types @@ -23,7 +22,6 @@ from mava import specs as mava_specs from mava.components.tf import networks -from mava.utils.enums import ArchitectureType from mava.components.tf.networks.epsilon_greedy import EpsilonGreedy Array = specs.Array @@ -37,29 +35,20 @@ def make_default_networks( value_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = None, seed: Optional[int] = None, ) -> Mapping[str, types.TensorTransformation]: - """Default networks for maddpg. + + """Default networks for Value Decomposition systems. Args: environment_spec: description of the action and observation spaces etc. for each agent in the system. agent_net_keys: specifies what network each agent uses. - vmin: hyperparameters for the distributional critic in mad4pg. - vmax: hyperparameters for the distributional critic in mad4pg. - net_spec_keys: specifies the specs of each network. - policy_networks_layer_sizes: size of policy networks. - critic_networks_layer_sizes: size of critic networks. - sigma: hyperparameters used to add Gaussian noise - for simple exploration. Defaults to 0.3. - archecture_type: archecture used - for agent networks. Can be feedforward or recurrent. - Defaults to ArchitectureType.feedforward. - - num_atoms: hyperparameters for the distributional critic in - mad4pg. + value_networks_layer_sizes: size of value networks seed: random seed for network initialization. Returns: - returned agent networks. 
+ Agents value networks + Agents action selectors + Agents observation networks """ if not value_networks_layer_sizes: @@ -75,7 +64,6 @@ def make_default_networks( # Create agent_type specs specs = {agent_net_keys[key]: specs[key] for key in specs.keys()} - if isinstance(value_networks_layer_sizes, Sequence): value_networks_layer_sizes = { key: value_networks_layer_sizes for key in specs.keys() diff --git a/mava/systems/tf/value_decomposition/system.py b/mava/systems/tf/value_decomposition/system.py index 2f79e483e..c8321f091 100644 --- a/mava/systems/tf/value_decomposition/system.py +++ b/mava/systems/tf/value_decomposition/system.py @@ -15,7 +15,7 @@ """Value Decomposition system implementation.""" -from typing import Callable, Dict, List, Optional, Type, Union, Mapping +from typing import Callable, Dict, Mapping, Optional, Type, Union import dm_env import reverb @@ -24,24 +24,23 @@ import mava from mava import specs as mava_specs -from mava.components.tf.architectures import ( - DecentralisedValueActor, -) -from mava.types import EpsilonScheduler +from mava.components.tf.architectures import DecentralisedValueActor +from mava.components.tf.modules.mixing.mixers import QMIX, VDN from mava.environment_loop import ParallelEnvironmentLoop +from mava.systems.tf.madqn import MADQN from mava.systems.tf.madqn.execution import MADQNRecurrentExecutor -from mava.systems.tf.value_decomposition.training import ValueDecompositionRecurrentTrainer +from mava.systems.tf.value_decomposition.training import ( + ValueDecompositionRecurrentTrainer, +) from mava.systems.tf.variable_sources import VariableSource as MavaVariableSource +from mava.types import EpsilonScheduler from mava.utils import enums from mava.utils.loggers import MavaLogger -from mava.systems.tf.madqn import MADQN -from mava.components.tf.modules.mixing.mixers import QMIX, VDN class ValueDecomposition(MADQN): """Value Decomposition systems.""" - def __init__( self, environment_factory: Callable[[bool], dm_env.Environment], @@ -53,10 +52,10 @@ def __init__( Mapping[str, Mapping[str, EpsilonScheduler]], ], logger_factory: Callable[[str], MavaLogger] = None, - architecture: Type[ - DecentralisedValueActor - ] = DecentralisedValueActor, - trainer_fn: Type[ValueDecompositionRecurrentTrainer] = ValueDecompositionRecurrentTrainer, + architecture: Type[DecentralisedValueActor] = DecentralisedValueActor, + trainer_fn: Type[ + ValueDecompositionRecurrentTrainer + ] = ValueDecompositionRecurrentTrainer, executor_fn: Type[MADQNRecurrentExecutor] = MADQNRecurrentExecutor, num_executors: int = 1, shared_weights: bool = True, @@ -71,9 +70,9 @@ def __init__( min_replay_size: int = 100, max_replay_size: int = 5000, samples_per_insert: Optional[float] = 2.0, - optimizer: Union[ - snt.Optimizer, Dict[str, snt.Optimizer] - ] = snt.optimizers.Adam(learning_rate=1e-4), + optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]] = snt.optimizers.Adam( + learning_rate=1e-4 + ), mixer_optimizer: snt.Optimizer = snt.optimizers.Adam(learning_rate=1e-4), sequence_length: int = 20, period: int = 10, @@ -95,36 +94,24 @@ def __init__( environment_factory: function to instantiate an environment. network_factory: function to instantiate system networks. + mixer: mixing network + exploration_scheduler_fn: function to schedule + exploration. e.g. epsilon greedy logger_factory: function to instantiate a system logger. - architecture: - system architecture, e.g. decentralised or centralised. + architecture: system architecture, + e.g. decentralised or centralised. 
trainer_fn: training type associated with executor and architecture, e.g. centralised training. executor_fn: executor type, e.g. feedforward or recurrent. num_executors: number of executor processes to run in - parallel.. - environment_spec: description of - the action, observation spaces etc. for each agent in the system. - trainer_networks: networks each - trainer trains on. - network_sampling_setup: List of networks that are randomly - sampled from by the executors at the start of an environment run. - enums.NetworkSampler settings: - fixed_agent_networks: Keeps the networks - used by each agent fixed throughout training. - random_agent_networks: Creates N network policies, where N is the - number of agents. Randomly select policies from this sets for each - agent at the start of a episode. This sampling is done with - replacement so the same policy can be selected for more than one - agent for a given episode. - Custom list: Alternatively one can specify a custom nested list, - with network keys in, that will be used by the executors at - the start of each episode to sample networks for each agent. + parallel. shared_weights: whether agents should share weights or not. When network_sampling_setup are provided the value of shared_weights is ignored. + environment_spec: description of + the action, observation spaces etc. for each agent in the system. discount: discount factor to use for TD updates. batch_size: sample batch size for updates. prefetch_size: size to prefetch from replay. @@ -140,20 +127,18 @@ def __init__( max_replay_size: maximum replay size. samples_per_insert: number of samples to take from replay for every insert that is made. - policy_optimizer: optimizer(s) for updating policy networks. - critic_optimizer: optimizer for updating critic - networks. - n_step: number of steps to include prior to boostrapping. + optimizer: optimizer(s) for updating value networks. + mixer_optimizer: optimizer for updating mixing networks. sequence_length: recurrent sequence rollout length. period: Consecutive starting points for overlapping rollouts across a sequence. max_gradient_norm: maximum allowed norm for gradients before clipping is applied. checkpoint: whether to checkpoint models. - checkpoint_minute_interval: The number of minutes to wait between - checkpoints. checkpoint_subpath: subdirectory specifying where to store checkpoints. + checkpoint_minute_interval: The number of minutes to wait between + checkpoints. logger_config: additional configuration settings for the logger factory. train_loop_fn: function to instantiate a train loop. @@ -169,6 +154,12 @@ def __init__( values for trainer_steps, trainer_walltime, evaluator_steps, evaluator_episodes, executor_episodes or executor_steps. E.g. termination_condition = {'trainer_steps': 100000}. + evaluator_interval: An optional condition that is used to + evaluate/test system performance after [evaluator_interval] + condition has been met. If None, evaluation will + happen at every timestep. + E.g. to evaluate a system after every 100 executor episodes, + evaluator_interval = {"executor_episodes": 100}. learning_rate_scheduler_fn: dict with two functions/classes (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate, @@ -176,12 +167,6 @@ def __init__( See examples/debugging/simple_spread/feedforward/decentralised/run_maddpg_lr_schedule.py for an example. 
- evaluator_interval: An optional condition that is used to - evaluate/test system performance after [evaluator_interval] - condition has been met. If None, evaluation will - happen at every timestep. - E.g. to evaluate a system after every 100 executor episodes, - evaluator_interval = {"executor_episodes": 100}. """ super().__init__( environment_factory=environment_factory, @@ -221,7 +206,7 @@ def __init__( termination_condition=termination_condition, evaluator_interval=evaluator_interval, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - ) + ) if isinstance(mixer, str): if mixer == "qmix": @@ -265,9 +250,10 @@ def trainer( # Create the system networks = self.create_system() + # Create the dataset dataset = self._builder.make_dataset_iterator(replay, trainer_id) - trainer = self._builder.make_trainer( + trainer: ValueDecompositionRecurrentTrainer = self._builder.make_trainer( networks=networks, trainer_networks=self._trainer_networks[trainer_id], trainer_table_entry=self._table_network_config[trainer_id], @@ -278,4 +264,4 @@ def trainer( trainer.setup_mixer(self._mixer, self._mixer_optimizer) - return trainer \ No newline at end of file + return trainer diff --git a/mava/systems/tf/value_decomposition/training.py b/mava/systems/tf/value_decomposition/training.py index 186911574..8a957ed6b 100644 --- a/mava/systems/tf/value_decomposition/training.py +++ b/mava/systems/tf/value_decomposition/training.py @@ -17,35 +17,28 @@ """Value Decomposition trainer implementation.""" import copy -import time -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union -import numpy as np import reverb import sonnet as snt import tensorflow as tf import tree import trfl -from acme.tf import losses from acme.tf import utils as tf2_utils from acme.utils import loggers -import mava -from mava import types as mava_types -from mava.adders.reverb.base import Trajectory -from mava.components.tf.losses.sequence import recurrent_n_step_critic_loss -from mava.systems.tf.madqn.execution import MADQNFeedForwardExecutor from mava.systems.tf.madqn.training import MADQNRecurrentTrainer from mava.systems.tf.variable_utils import VariableClient from mava.utils import training_utils as train_utils -from mava.utils.sort_utils import sort_str_num train_utils.set_growing_gpu_memory() class ValueDecompositionRecurrentTrainer(MADQNRecurrentTrainer): - """MADQN trainer. - This is the trainer component of a MADDPG system. IE it takes a dataset as input + """Value Decomposition Trainer. + + This is the trainer component of a Value Decomposition system. + IE it takes a dataset as input and implements update functionality to learn from this dataset. """ @@ -70,20 +63,14 @@ def __init__( logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Initialise MADDPG trainer + """Initialise Value Decompostion trainer Args: agents: agent ids, e.g. "agent_0". agent_types: agent types, e.g. "speaker" or "listener". - policy_networks: policy networks for each agent in + value_networks: value networks for each agent in the system. - critic_networks: critic network(s), shared or for - each agent in the system. - target_policy_networks: target policy networks. - target_critic_networks: target critic networks. - policy_optimizer: - optimizer(s) for updating policy networks. - critic_optimizer: - optimizer for updating critic networks. + target_value_networks: target value networks. 
+ optimizer: optimizer(s) for updating value networks. discount: discount factor for TD updates. target_averaging: whether to use polyak averaging for target network updates. @@ -132,7 +119,13 @@ def __init__( self._target_mixer = None self._mixer_optimizer = None - def setup_mixer(self, mixer: snt.Module, mixer_optimizer: snt.Module): + def setup_mixer(self, mixer: snt.Module, mixer_optimizer: snt.Module) -> None: + """Initialize the mixer network + + Args: + mixer: mixer network + mixer_optimizer: optimizer for updating mixing networks. + """ self._mixer = mixer self._target_mixer = copy.deepcopy(mixer) self._mixer_optimizer = mixer_optimizer @@ -145,21 +138,22 @@ def _update_target_networks(self) -> None: target_variables = [] for key in self.unique_net_keys: # Update target network. - online_variables += list(( - *self._observation_networks[key].variables, - *self._value_networks[key].variables, - )) - target_variables += list(( - *self._target_observation_networks[key].variables, - *self._target_value_networks[key].variables, - )) + online_variables += list( + ( + *self._observation_networks[key].variables, + *self._value_networks[key].variables, + ) + ) + target_variables += list( + ( + *self._target_observation_networks[key].variables, + *self._target_value_networks[key].variables, + ) + ) # Add mixer variables - online_variables += list(( - *self._mixer.variables, - )) - target_variables += list(( - *self._target_mixer.variables, - )) + if self._mixer is not None: + online_variables += list((*self._mixer.variables,)) + target_variables += list((*self._target_mixer.variables,)) if self._target_averaging: assert 0.0 < self._target_update_rate < 1.0 @@ -171,7 +165,7 @@ def _update_target_networks(self) -> None: if tf.math.mod(self._num_steps, self._target_update_period) == 0: for src, dest in zip(online_variables, target_variables): dest.assign(src) - + self._num_steps.assign_add(1) # Forward pass that calculates loss. @@ -205,11 +199,13 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # Get initial state for the LSTM from replay and # extract the first state in the sequence. core_state = tree.map_structure(lambda s: s[0, :, :], extras["core_states"]) - target_core_state = tree.map_structure(lambda s: s[0, :, :], extras["core_states"]) + target_core_state = tree.map_structure( + lambda s: s[0, :, :], extras["core_states"] + ) # Do forward passes through the networks and calculate the losses with tf.GradientTape(persistent=True) as tape: - # NOTE (Dries): We are assuming that only the valu network + # NOTE (Dries): We are assuming that only the value network # is recurrent and not the observation network. 
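# The block below follows the usual value decomposition recipe: each agent's
# chosen-action Q-value and its (double-Q) bootstrap value come from the agent's
# own recurrent value network, the per-agent values are stacked along a trailing
# agent dimension, and the mixer (a plain sum for VDN, or QMIX's monotonic mixing
# network conditioned on the global environment state) combines them into a
# single team value before one TD error is formed against the mean team reward.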
obs_trans, target_obs_trans = self._transform_observations(observations) @@ -223,33 +219,28 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # Double Q-learning q_tm1_values, _ = snt.static_unroll( - self._value_networks[agent_key], obs_trans[agent], core_state[agent][0] + self._value_networks[agent_key], + obs_trans[agent], + core_state[agent][0], ) # Q-value of the action taken by agent - chosen_action_q_value = trfl.batched_index( - q_tm1_values, actions[agent] - ) - + chosen_action_q_value = trfl.batched_index(q_tm1_values, actions[agent]) # Q-value of the next state - q_t_selector = tf.where( - tf.cast(observations[agent].legal_actions, 'bool'), - q_tm1_values, -999999999 + q_t_selector = tf.where( + tf.cast(observations[agent].legal_actions, "bool"), + q_tm1_values, + -999999999, ) q_t_values, _ = snt.static_unroll( - self._target_value_networks[agent_key], - target_obs_trans[agent], - target_core_state[agent][0] + self._target_value_networks[agent_key], + target_obs_trans[agent], + target_core_state[agent][0], ) max_action = tf.argmax(q_t_selector, axis=-1) - max_action_q_value = trfl.batched_index( - q_t_values, - max_action - ) - + max_action_q_value = trfl.batched_index(q_t_values, max_action) # Append agent values to lists - # NOTE (Claude) appending to a list does not work in tf.function chosen_action_q_value_all_agents.append(chosen_action_q_value) max_action_q_value_all_agents.append(max_action_q_value) reward_all_agents.append(rewards[agent]) @@ -258,22 +249,23 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # Stack list of tensors into tensor with trailing agent dim chosen_action_q_value_all_agents = tf.stack( chosen_action_q_value_all_agents, axis=-1 - ) # shape=(T,B, Num_Agents) + ) # shape=(T,B, Num_Agents) max_action_q_value_all_agents = tf.stack( max_action_q_value_all_agents, axis=-1 - ) # shape=(T,B, Num_Agents) + ) # shape=(T,B, Num_Agents) reward_all_agents = tf.stack(reward_all_agents, axis=-1) env_discount_all_agents = tf.stack(env_discount_all_agents, axis=-1) # Mixing - chosen_action_q_value_all_agents = self._mixer( - chosen_action_q_value_all_agents, - states=global_env_state, - ) - max_action_q_value_all_agents = self._target_mixer( - max_action_q_value_all_agents, - states=global_env_state - ) + if self._mixer is not None: + chosen_action_q_value_all_agents = self._mixer( + chosen_action_q_value_all_agents, + states=global_env_state, + ) + max_action_q_value_all_agents = self._target_mixer( + max_action_q_value_all_agents, states=global_env_state + ) + # NOTE Team reward is just the mean over agents indevidual rewards reward_all_agents = tf.reduce_mean( reward_all_agents, axis=-1, keepdims=True @@ -301,7 +293,9 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: value_loss = 0.5 * tf.square(td_error) # Zero-padding mask - zero_padding_mask = tf.cast(extras["zero_padding_mask"], dtype=value_loss.dtype)[:-1] + zero_padding_mask = tf.cast( + extras["zero_padding_mask"], dtype=value_loss.dtype + )[:-1] masked_loss = value_loss * tf.expand_dims(zero_padding_mask, axis=-1) masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(zero_padding_mask) @@ -331,15 +325,14 @@ def _backward(self) -> None: gradients = tape.gradient(value_losses[agent], variables) # Maybe clip gradients. - gradients = tf.clip_by_global_norm( - gradients, self._max_gradient_norm - )[0] + gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0] # Apply gradients. 
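# The clipped gradients are applied per network key, so each agent's value
# network is updated by its own optimizer here; the mixer's variables are
# updated separately below with the dedicated mixer optimizer.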
self._optimizers[agent_key].apply(gradients, variables) # Mixer - mixer_variables = self._mixer.trainable_variables + if self._mixer is not None: + mixer_variables: Sequence[tf.Variable] = self._mixer.trainable_variables gradients = tape.gradient(mixer_loss, mixer_variables) @@ -347,7 +340,7 @@ def _backward(self) -> None: gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0] # Apply gradients. - if mixer_variables: + if mixer_variables and self._mixer_optimizer is not None: self._mixer_optimizer.apply(gradients, mixer_variables) - train_utils.safe_del(self, "tape") \ No newline at end of file + train_utils.safe_del(self, "tape") diff --git a/mava/utils/environments/flatland_utils.py b/mava/utils/environments/flatland_utils.py index b66d77714..50ca59b5f 100644 --- a/mava/utils/environments/flatland_utils.py +++ b/mava/utils/environments/flatland_utils.py @@ -15,8 +15,11 @@ from typing import Optional +from mava.wrappers.env_preprocess_wrappers import ( + ConcatAgentIdToObservation, + ConcatPrevActionToObservation, +) from mava.wrappers.flatland import FlatlandEnvWrapper -from mava.wrappers.env_preprocess_wrappers import ConcatAgentIdToObservation, ConcatPrevActionToObservation try: from flatland.envs.line_generators import sparse_line_generator @@ -32,6 +35,7 @@ except ModuleNotFoundError: pass + def _create_rail_env_with_tree_obs( n_agents: int = 5, x_dim: int = 30, @@ -39,7 +43,7 @@ def _create_rail_env_with_tree_obs( n_cities: int = 2, max_rails_between_cities: int = 2, max_rails_in_city: int = 3, - seed: int = 0, + seed: Optional[int] = 0, malfunction_rate: float = 1 / 200, malfunction_min_duration: int = 20, malfunction_max_duration: int = 50, @@ -99,20 +103,20 @@ def _create_rail_env_with_tree_obs( def make_environment( - n_agents: int =10, + n_agents: int = 10, x_dim: int = 30, y_dim: int = 30, n_cities: int = 2, - max_rails_between_cities: int =2, - max_rails_in_city: int =3, - seed: int = 0, - malfunction_rate:float = 1/200, + max_rails_between_cities: int = 2, + max_rails_in_city: int = 3, + seed: int = 0, + malfunction_rate: float = 1 / 200, malfunction_min_duration: int = 20, malfunction_max_duration: int = 50, observation_max_path_depth: int = 30, observation_tree_depth: int = 2, - concat_prev_actions: bool = True, - concat_agent_id: bool = True, + concat_prev_actions: bool = True, + concat_agent_id: bool = False, evaluation: bool = False, random_seed: Optional[int] = None, ) -> FlatlandEnvWrapper: @@ -134,12 +138,12 @@ def make_environment( observation_max_path_depth=observation_max_path_depth, observation_tree_depth=observation_tree_depth, ) - + env = FlatlandEnvWrapper(env) if concat_prev_actions: env = ConcatPrevActionToObservation(env) - + if concat_agent_id: env = ConcatAgentIdToObservation(env) diff --git a/mava/utils/environments/smac_utils.py b/mava/utils/environments/smac_utils.py index 66b80b569..a1d6a4f31 100644 --- a/mava/utils/environments/smac_utils.py +++ b/mava/utils/environments/smac_utils.py @@ -12,19 +12,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
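# For context, the reworked make_environment below simply layers the new
# preprocessing wrappers over a SMACWrapper'd StarCraft2Env. A minimal
# interaction sketch (the map name and the random-action loop are purely
# illustrative):
#
#     import numpy as np
#
#     from mava.utils.environments import smac_utils
#
#     # Build a wrapped 3m environment with one-hot agent ids and previous
#     # actions concatenated to each agent's observation.
#     env = smac_utils.make_environment(map_name="3m", concat_agent_id=True)
#
#     # Mava's parallel wrappers return a (timestep, extras) pair from both
#     # reset() and step().
#     timestep, _ = env.reset()
#     for _ in range(5):
#         # Sample a random legal action per agent from the 0/1 legal mask.
#         actions = {
#             agent: int(np.random.choice(np.nonzero(olt.legal_actions)[0]))
#             for agent, olt in timestep.observation.items()
#         }
#         timestep, _ = env.step(actions)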
-from mava.wrappers.env_preprocess_wrappers import ConcatAgentIdToObservation, ConcatPrevActionToObservation +from typing import Any, Optional + from smac.env import StarCraft2Env + from mava.wrappers import SMACWrapper +from mava.wrappers.env_preprocess_wrappers import ( + ConcatAgentIdToObservation, + ConcatPrevActionToObservation, +) + -def make_environment(map_name="3m", concat_prev_actions=True, concat_agent_id=True, evaluation = False, random_seed=None): +def make_environment( + map_name: str = "3m", + concat_prev_actions: bool = True, + concat_agent_id: bool = True, + evaluation: bool = False, + random_seed: Optional[int] = None, +) -> Any: env = StarCraft2Env(map_name=map_name, seed=random_seed) - + env = SMACWrapper(env) if concat_prev_actions: env = ConcatPrevActionToObservation(env) - + if concat_agent_id: env = ConcatAgentIdToObservation(env) - return env \ No newline at end of file + return env diff --git a/mava/utils/training_utils.py b/mava/utils/training_utils.py index 4eb15bc33..d6c399419 100644 --- a/mava/utils/training_utils.py +++ b/mava/utils/training_utils.py @@ -35,7 +35,7 @@ def decay_lr_actor_critic( def decay_lr( - lr_schedule: Optional[Callable[[int], None]], optimizers: Dict, trainer_step: int + lr_schedule: Optional[Callable], optimizers: Dict, trainer_step: int ) -> None: """Funtion that decays lr of optim. @@ -125,9 +125,10 @@ def map_losses_per_agent_ac(critic_losses: Dict, policy_losses: Dict) -> Dict: return logged_losses + # Map value losses to dict, grouped by agent. def map_losses_per_agent_value(value_losses: Dict) -> Dict: - assert len(value_losses) > 0 , "Invalid System Checkpointer." + assert len(value_losses) > 0, "Invalid System Checkpointer." logged_losses: Dict[str, Dict[str, Any]] = {} for agent in value_losses.keys(): logged_losses[agent] = { diff --git a/mava/wrappers/env_preprocess_wrappers.py b/mava/wrappers/env_preprocess_wrappers.py index b2a34fe6a..631f2a32c 100644 --- a/mava/wrappers/env_preprocess_wrappers.py +++ b/mava/wrappers/env_preprocess_wrappers.py @@ -15,13 +15,13 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Union +import dm_env import gym import numpy as np from pettingzoo.utils import BaseParallelWraper from supersuit.utils.base_aec_wrapper import BaseWrapper -import dm_env -from mava.types import OLT -from mava.types import Action, Observation, Reward + +from mava.types import OLT, Action, Observation, Reward from mava.utils.wrapper_utils import RunningMeanStd from mava.wrappers.env_wrappers import ParallelEnvWrapper, SequentialEnvWrapper @@ -362,40 +362,45 @@ def _modify_observation(self, observation: Observation) -> Observation: def _modify_action(self, action: Action) -> Action: return action + class ConcatAgentIdToObservation: """Concat one-hot vector of agent ID to obs. - - We assume the environment has an ordered list + + We assume the environment has an ordered list self.possible_agents. 
""" - def __init__(self, environment): - self._environment = environment + def __init__(self, environment: Any) -> None: + self._environment = environment self._num_agents = len(environment.possible_agents) - def reset(self): + def reset(self) -> dm_env.TimeStep: timestep, extras = self._environment.reset() old_observations = timestep.observation - + new_observations = {} for agent_id, agent in enumerate(self._environment.possible_agents): agent_olt = old_observations[agent] - + agent_observation = agent_olt.observation - agent_one_hot = np.zeros(self._num_agents, dtype = agent_observation.dtype) + agent_one_hot = np.zeros(self._num_agents, dtype=agent_observation.dtype) agent_one_hot[agent_id] = 1 new_observations[agent] = OLT( - observation = np.concatenate([agent_one_hot, agent_observation]), - legal_actions= agent_olt.legal_actions, - terminal=agent_olt.terminal + observation=np.concatenate([agent_one_hot, agent_observation]), + legal_actions=agent_olt.legal_actions, + terminal=agent_olt.terminal, ) - return dm_env.TimeStep(timestep.step_type, timestep.reward, timestep.discount, new_observations), extras - + return ( + dm_env.TimeStep( + timestep.step_type, timestep.reward, timestep.discount, new_observations + ), + extras, + ) - def step(self, actions: Dict) -> Any: + def step(self, actions: Dict) -> dm_env.TimeStep: timestep, extras = self._environment.step(actions) old_observations = timestep.observation @@ -404,17 +409,21 @@ def step(self, actions: Dict) -> Any: agent_olt = old_observations[agent] agent_observation = agent_olt.observation - agent_one_hot = np.zeros(self._num_agents, dtype = agent_observation.dtype) + agent_one_hot = np.zeros(self._num_agents, dtype=agent_observation.dtype) agent_one_hot[agent_id] = 1 new_observations[agent] = OLT( - observation = np.concatenate([agent_one_hot, agent_observation]), + observation=np.concatenate([agent_one_hot, agent_observation]), legal_actions=agent_olt.legal_actions, - terminal=agent_olt.terminal + terminal=agent_olt.terminal, ) - - return dm_env.TimeStep(timestep.step_type, timestep.reward, timestep.discount, new_observations), extras + return ( + dm_env.TimeStep( + timestep.step_type, timestep.reward, timestep.discount, new_observations + ), + extras, + ) def observation_spec(self) -> Dict[str, OLT]: """Observation spec. @@ -422,7 +431,7 @@ def observation_spec(self) -> Dict[str, OLT]: Returns: types.Observation: spec for environment. """ - timestep, extras = self.reset() + timestep, extras = self.reset() observations = timestep.observation return observations @@ -443,61 +452,76 @@ def __getattr__(self, name: str) -> Any: class ConcatPrevActionToObservation: """Concat one-hot vector of agent prev_action to obs. - + We assume the environment has discreet actions. - TODO support continuous actions. + TODO (Claude) support continuous actions. 
""" + # Need to get the size of the action space of each agent - def __init__(self, environment): + def __init__(self, environment: Any): self._environment = environment - - def reset(self): + + def reset(self) -> dm_env.TimeStep: timestep, extras = self._environment.reset() old_observations = timestep.observation - action_spec = self._environment.action_spec() + action_spec = self._environment.action_spec() new_observations = {} - #TODO double check this, because possible agents could shrink + # TODO double check this, because possible agents could shrink for agent in self._environment.possible_agents: agent_olt = old_observations[agent] agent_observation = agent_olt.observation - agent_one_hot_action = np.zeros(action_spec[agent].num_values, dtype=np.float32) - + agent_one_hot_action = np.zeros( + action_spec[agent].num_values, dtype=np.float32 + ) + new_observations[agent] = OLT( - observation = np.concatenate([agent_one_hot_action, agent_observation]), - legal_actions= agent_olt.legal_actions, - terminal=agent_olt.terminal + observation=np.concatenate([agent_one_hot_action, agent_observation]), + legal_actions=agent_olt.legal_actions, + terminal=agent_olt.terminal, ) - return dm_env.TimeStep(timestep.step_type, timestep.reward, timestep.discount, new_observations), extras + return ( + dm_env.TimeStep( + timestep.step_type, timestep.reward, timestep.discount, new_observations + ), + extras, + ) - def step(self, actions: Dict) -> Any: + def step(self, actions: Dict) -> dm_env.TimeStep: timestep, extras = self._environment.step(actions) old_observations = timestep.observation - action_spec = self._environment.action_spec() + action_spec = self._environment.action_spec() new_observations = {} for agent in self._environment.possible_agents: agent_olt = old_observations[agent] agent_observation = agent_olt.observation - agent_one_hot_action = np.zeros(action_spec[agent].num_values, dtype=np.float32) + agent_one_hot_action = np.zeros( + action_spec[agent].num_values, dtype=np.float32 + ) agent_one_hot_action[actions[agent]] = 1 - + new_observations[agent] = OLT( - observation = np.concatenate([agent_one_hot_action, agent_observation]), - legal_actions= agent_olt.legal_actions, - terminal=agent_olt.terminal + observation=np.concatenate([agent_one_hot_action, agent_observation]), + legal_actions=agent_olt.legal_actions, + terminal=agent_olt.terminal, ) - return dm_env.TimeStep(timestep.step_type, timestep.reward, timestep.discount, new_observations), extras - + return ( + dm_env.TimeStep( + timestep.step_type, timestep.reward, timestep.discount, new_observations + ), + extras, + ) + def observation_spec(self) -> Dict[str, OLT]: """Observation spec. Returns: types.Observation: spec for environment. 
""" - timestep, extras = self.reset() + timestep, extras = self.reset() observations = timestep.observation return observations @@ -513,4 +537,4 @@ def __getattr__(self, name: str) -> Any: if hasattr(self.__class__, name): return self.__getattribute__(name) else: - return getattr(self._environment, name) \ No newline at end of file + return getattr(self._environment, name) diff --git a/mava/wrappers/flatland.py b/mava/wrappers/flatland.py index 5e32327ae..02f0032cc 100644 --- a/mava/wrappers/flatland.py +++ b/mava/wrappers/flatland.py @@ -27,8 +27,8 @@ try: from flatland.envs.observations import GlobalObsForRailEnv, Node, TreeObsForRailEnv from flatland.envs.rail_env import RailEnv - from flatland.utils.rendertools import AgentRenderVariant, RenderTool from flatland.envs.step_utils.states import TrainState + from flatland.utils.rendertools import AgentRenderVariant, RenderTool except ModuleNotFoundError: pass from gym.spaces import Discrete @@ -135,7 +135,7 @@ def possible_agents(self) -> List[str]: """Return list of all possible agents.""" return self._possible_agents - def _update_stats(self, info, rewards): + def _update_stats(self, info: Dict, rewards: Dict) -> None: episode_return = sum(list(rewards.values())) tasks_finished = sum( [1 if state == TrainState.DONE else 0 for state in info["state"].values()] @@ -148,7 +148,7 @@ def _update_stats(self, info, rewards): self._latest_score = normalized_score self._latest_completion = completion - def get_stats(self): + def get_stats(self) -> Dict: if self._latest_completion is not None and self._latest_score is not None: return { "score": self._latest_score, @@ -248,12 +248,15 @@ def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: self._step_type = dm_env.StepType.MID discounts = self._discounts # discount == 1 - return dm_env.TimeStep( - observation=observations, - reward=rewards, - discount=discounts, - step_type=self._step_type, - ), {} + return ( + dm_env.TimeStep( + observation=observations, + reward=rewards, + discount=discounts, + step_type=self._step_type, + ), + {}, + ) # Convert Flatland observation so it's dm_env compatible. Also, the list # of legal actions must be converted to a legal actions mask. 
@@ -340,6 +343,10 @@ def observation_spec(self) -> Dict[str, OLT]: """Return observation spec.""" observation_specs = {} for agent in self.agents: + # Legal actions + action_spec = _convert_to_spec(self.action_spaces[agent]) + legals = np.ones(shape=action_spec.num_values, dtype=action_spec.dtype) + observation_specs[agent] = OLT( observation=tuple( ( @@ -349,7 +356,7 @@ def observation_spec(self) -> Dict[str, OLT]: ) if self._include_agent_info else _convert_to_spec(self.observation_spaces[agent]), - legal_actions=_convert_to_spec(self.action_spaces[agent]), + legal_actions=legals, terminal=specs.Array((1,), np.float32), ) return observation_specs @@ -410,7 +417,12 @@ def infer_observation_space( ) -> Union[Box, tuple, dict]: """Infer a gym Observation space from a sample observation from flatland""" if isinstance(obs, np.ndarray): - return Box(-np.inf, np.inf, shape=obs.shape, dtype=obs.dtype,) + return Box( + -np.inf, + np.inf, + shape=obs.shape, + dtype=obs.dtype, + ) elif isinstance(obs, tuple): return tuple(infer_observation_space(o) for o in obs) elif isinstance(obs, dict): @@ -610,4 +622,4 @@ def normalize_observation( normalized_obs = np.array( np.concatenate((np.concatenate((data, distance)), agent_data)), dtype=np.float32 ) - return normalized_obs \ No newline at end of file + return normalized_obs diff --git a/mava/wrappers/smac.py b/mava/wrappers/smac.py index 1c904da99..3043ada15 100644 --- a/mava/wrappers/smac.py +++ b/mava/wrappers/smac.py @@ -19,16 +19,13 @@ import dm_env import numpy as np from acme import specs - from smac.env import StarCraft2Env from mava import types -from mava.utils.wrapper_utils import ( - convert_np_type, - parameterized_restart, -) +from mava.utils.wrapper_utils import convert_np_type, parameterized_restart from mava.wrappers.env_wrappers import ParallelEnvWrapper + class SMACWrapper(ParallelEnvWrapper): """Environment wrapper for PettingZoo MARL environments.""" @@ -68,7 +65,9 @@ def reset(self) -> dm_env.TimeStep: # Get observation from env observation = self.environment.get_obs() legal_actions = self._get_legal_actions() - observations = self._convert_observations(observation, legal_actions, self._done) + observations = self._convert_observations( + observation, legal_actions, self._done + ) # Set env discount to 1 for all agents discount_spec = self.discount_spec() @@ -93,7 +92,6 @@ def reset(self) -> dm_env.TimeStep: return parameterized_restart(rewards, self._discounts, observations), extras - def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: """Steps in env. 
@@ -108,15 +106,17 @@ def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: return self.reset() # Convert dict of actions to list for SMAC - actions = list(actions.values()) + smac_actions = list(actions.values()) # Step the SMAC environment - reward, self._done, self._info = self._environment.step(actions) + reward, self._done, self._info = self._environment.step(smac_actions) # Get the next observations next_observations = self._environment.get_obs() legal_actions = self._get_legal_actions() - next_observations = self._convert_observations(next_observations, legal_actions, self._done) + next_observations = self._convert_observations( + next_observations, legal_actions, self._done + ) # Convert team reward to agent-wise rewards rewards = self._convert_reward(reward) @@ -167,16 +167,14 @@ def _convert_reward(self, reward: float) -> Dict[str, float]: rewards_spec = self.reward_spec() rewards = {} for agent in self._agents: - rewards[agent] = convert_np_type( - rewards_spec[agent].dtype, reward - ) + rewards[agent] = convert_np_type(rewards_spec[agent].dtype, reward) return rewards - def _get_legal_actions(self): + def _get_legal_actions(self) -> np.ndarray: legal_actions = [] for i, _ in enumerate(self._agents): legal_actions.append( - np.array(self._environment.get_avail_agent_actions(i), dtype='int') + np.array(self._environment.get_avail_agent_actions(i), dtype="int") ) return legal_actions @@ -194,7 +192,7 @@ def _convert_observations( """ olt_observations = {} for i, agent in enumerate(self._agents): - + olt_observations[agent] = types.OLT( observation=observations[i], legal_actions=legal_actions[i], @@ -221,7 +219,7 @@ def observation_spec(self) -> Dict[str, types.OLT]: types.Observation: spec for environment. """ self._environment.reset() - + observations = self._environment.get_obs() legal_actions = self._get_legal_actions() @@ -280,7 +278,7 @@ def get_stats(self) -> Optional[Dict]: extra stats to be logged. """ return self._environment.get_stats() - + @property def agents(self) -> List: """Agents still alive in env (not done). @@ -323,8 +321,6 @@ def __getattr__(self, name: str) -> Any: return getattr(self._environment, name) - - env = StarCraft2Env(map_name="3m") -wrapped_env = SMACWrapper(env) \ No newline at end of file +wrapped_env = SMACWrapper(env) From 0c9de4858a8f1961ba8a01674f5d494bcca6f0c5 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 26 Jan 2022 15:54:49 +0200 Subject: [PATCH 15/56] Fix docstrings. 
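A note on rewards that the wrappers above rely on: SMAC emits a single scalar
team reward per step; SMACWrapper._convert_reward broadcasts it to every agent,
and the value decomposition trainer later averages those identical per-agent
rewards back into one team signal. A toy round-trip sketch (the numbers are
made up):

    import numpy as np

    agents = ["agent_0", "agent_1", "agent_2"]
    team_reward = 1.5  # scalar reward as returned by StarCraft2Env.step

    # SMACWrapper._convert_reward: broadcast the team reward to every agent.
    rewards = {agent: np.float32(team_reward) for agent in agents}

    # ValueDecompositionRecurrentTrainer._forward: stack the per-agent rewards
    # and reduce back to a single team reward by averaging over agents.
    reward_all_agents = np.stack([rewards[a] for a in agents], axis=-1)
    recovered_team_reward = reward_all_agents.mean(axis=-1, keepdims=True)  # 1.5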
--- mava/systems/tf/madqn/builder.py | 41 +++++++++++++++---- mava/systems/tf/madqn/system.py | 27 ++++++------ mava/systems/tf/madqn/training.py | 26 ++++++++---- mava/systems/tf/value_decomposition/system.py | 14 ++++--- .../tf/value_decomposition/training.py | 17 ++++---- 5 files changed, 81 insertions(+), 44 deletions(-) diff --git a/mava/systems/tf/madqn/builder.py b/mava/systems/tf/madqn/builder.py index 251d7ab6f..8c9269ede 100644 --- a/mava/systems/tf/madqn/builder.py +++ b/mava/systems/tf/madqn/builder.py @@ -27,7 +27,7 @@ from acme.utils import counting, loggers from dm_env import specs as dm_specs -from mava import adders, core, specs, types +from mava import Trainer, adders, core, specs, types from mava.adders import reverb as reverb_adders from mava.components.tf.modules.exploration.exploration_scheduling import ( BaseExplorationScheduler, @@ -136,20 +136,17 @@ class MADQNConfig: class MADQNBuilder: - """Builder for scaled MADDPG which constructs individual components of the - system.""" + """Builder for MADQN.""" def __init__( self, config: MADQNConfig, - trainer_fn: Union[ - Type[training.MADQNTrainer], - Type[training.MADQNRecurrentTrainer], - ] = training.MADQNTrainer, + trainer_fn: Type[Trainer] = training.MADQNTrainer, executor_fn: Type[core.Executor] = MADQNFeedForwardExecutor, extra_specs: Dict[str, Any] = {}, ): """Initialise the system. + Args: config: system configuration specifying hyperparameters and additional information for constructing the system. @@ -169,6 +166,15 @@ def __init__( self._executor_fn = executor_fn def covert_specs(self, spec: Dict[str, Any], num_networks: int) -> Dict[str, Any]: + """Convert specs. + + Args: + spec: [description] + num_networks: [description] + + Returns: + Dict[str, Any]: converted specs + """ if type(spec) is not dict: return spec @@ -187,12 +193,15 @@ def make_replay_tables( self, environment_spec: specs.MAEnvironmentSpec, ) -> List[reverb.Table]: - """ "Create tables to insert data into. + """Create tables to insert data into. + Args: environment_spec: description of the action and observation spaces etc. for each agent in the system. + Raises: NotImplementedError: unknown executor type. + Returns: a list of data tables for inserting data. """ @@ -273,11 +282,14 @@ def make_dataset_iterator( table_name: str, ) -> Iterator[reverb.ReplaySample]: """Create a dataset iterator to use for training/updating the system. + Args: replay_client: Reverb Client which points to the replay server. + Returns: [type]: dataset iterator. + Yields: data samples from the dataset. """ @@ -303,11 +315,14 @@ def make_adder( replay_client: reverb.Client, ) -> Optional[adders.ParallelAdder]: """Create an adder which records data generated by the executor/environment. + Args: replay_client: Reverb Client which points to the replay server. + Raises: NotImplementedError: unknown executor type. + Returns: adder which sends data to a replay buffer. """ @@ -345,9 +360,11 @@ def create_counter_variables( self, variables: Dict[str, tf.Variable] ) -> Dict[str, tf.Variable]: """Create counter variables. + Args: variables: dictionary with variable_source variables in. + Returns: variables: dictionary with variable_source variables in. @@ -365,9 +382,11 @@ def make_variable_server( networks: Dict[str, Dict[str, snt.Module]], ) -> MavaVariableSource: """Create the variable server. + Args: networks: dictionary with the system's networks in. + Returns: variable_source: A Mava variable source object. 
""" @@ -409,6 +428,7 @@ def make_executor( evaluator: bool = False, ) -> core.Executor: """Create an executor instance. + Args: networks: dictionary with the system's networks in. policy_networks: policy networks for each agent in @@ -419,6 +439,7 @@ def make_executor( Defaults to None. evaluator: boolean indicator if the executor is used for for evaluation only. + Returns: system executor, a collection of agents making up the part of the system generating data by interacting the environment. @@ -494,6 +515,7 @@ def make_trainer( logger: Optional[types.NestedLogger] = None, ) -> core.Trainer: """Create a trainer instance. + Args: networks: system networks. dataset: dataset iterator to feed data to @@ -502,6 +524,7 @@ def make_trainer( trainer_networks: Set of unique network keys to train on.. trainer_table_entry: List of networks per agent to train on. logger: Logger object for logging metadata. + Returns: system trainer, that uses the collected data from the executors to update the parameters of the agent networks in the system. @@ -581,7 +604,7 @@ def make_trainer( } # The learner updates the parameters (and initializes them). - trainer = self._trainer_fn(**trainer_config) + trainer = self._trainer_fn(**trainer_config) # type: ignore trainer = ScaledDetailedTrainerStatistics( # type: ignore trainer, metrics=["value_loss"] diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index 3042f76a7..92d52d6af 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -16,7 +16,7 @@ """MADQN system implementation.""" import functools -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union import acme import dm_env @@ -67,7 +67,7 @@ def __init__( # noqa ], logger_factory: Callable[[str], MavaLogger] = None, architecture: Type[DecentralisedValueActor] = DecentralisedValueActor, - trainer_fn: Type[training.MADQNTrainer] = training.MADQNTrainer, + trainer_fn: Type[mava.Trainer] = training.MADQNTrainer, executor_fn: Type[core.Executor] = MADQNFeedForwardExecutor, num_executors: int = 1, trainer_networks: Union[ @@ -117,16 +117,13 @@ def __init__( # noqa exploration. e.g. epsilon greedy logger_factory: function to instantiate a system logger. - architecture: - system architecture, e.g. decentralised or centralised. - trainer_fn: training type - associated with executor and architecture, e.g. centralised training. - executor_fn: executor type, e.g. - feedforward or recurrent. + architecture: system architecture, e.g. decentralised or centralised. + trainer_fn: training type associated with executor and architecture, + e.g. centralised training. + executor_fn: executor type, e.g. feedforward or recurrent. num_executors: number of executor processes to run in parallel. - trainer_networks: networks each - trainer trains on. + trainer_networks: networks each trainer trains on. network_sampling_setup: List of networks that are randomly sampled from by the executors at the start of an environment run. enums.NetworkSampler settings: @@ -160,7 +157,7 @@ def __init__( # noqa max_replay_size: maximum replay size. samples_per_insert: number of samples to take from replay for every insert that is made. - optimizers: optimizer(s) for updating value networks. + optimizer: optimizer(s) for updating value networks. n_step: number of steps to include prior to boostrapping. sequence_length: recurrent sequence rollout length. 
period: Consecutive starting points for overlapping @@ -453,7 +450,7 @@ def replay(self) -> Any: def create_system( self, - ) -> Tuple[Dict[str, Dict[str, snt.Module]], Dict[str, Dict[str, snt.Module]]]: + ) -> Dict: """Initialise the system variables from the network factory.""" # Create the networks to optimize (online) networks = self._network_factory( # type: ignore @@ -487,12 +484,14 @@ def executor( replay: reverb.Client, variable_source: acme.VariableSource, ) -> mava.ParallelEnvironmentLoop: - """System executor + """System executor. + Args: executor_id: id to identify the executor process for logging purposes. replay: replay data table to push data to. variable_source: variable server for updating network variables. + Returns: mava.ParallelEnvironmentLoop: environment-executor loop instance. """ @@ -539,7 +538,7 @@ def evaluator( self, variable_source: acme.VariableSource, ) -> Any: - """System evaluator (an executor process not connected to a dataset) + """System evaluator. Args: variable_source: variable server for updating diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index 5c2f9a831..1303b34b0 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -39,6 +39,7 @@ class MADQNTrainer(mava.Trainer): """MADQN trainer. + This is the trainer component of a MADDPG system. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ @@ -64,7 +65,8 @@ def __init__( logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Initialise MADQN trainer + """Initialise MADQN trainer. + Args: agents: agent ids, e.g. "agent_0". agent_types: agent types, e.g. "speaker" or "listener". @@ -171,9 +173,11 @@ def __init__( self._timestamp: Optional[float] = None def _update_target_networks(self) -> None: + """Update the target networks. - """Update the target networks using either target averaging or - by directy copying the weights of the online networks every few steps.""" + Using either target averaging or + by directy copying the weights of the online networks every few steps. + """ for key in self.unique_net_keys: # Update target network. @@ -199,7 +203,6 @@ def _update_target_networks(self) -> None: self._num_steps.assign_add(1) def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray]]: - """Depricated""" pass @@ -207,12 +210,12 @@ def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray] def _transform_observations( self, obs: Dict[str, mava_types.OLT], next_obs: Dict[str, mava_types.OLT] ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: - """Transform the observations using the observation networks of each agent." Args: obs: observations at timestep t-1 next_obs: observations at timestep t + Returns: Transformed observations """ @@ -255,7 +258,8 @@ def _step( # Forward pass that calculates loss. def _forward(self, inputs: reverb.ReplaySample) -> None: - """Trainer forward pass + """Trainer forward pass. + Args: inputs: input data from the data table (transitions) """ @@ -373,11 +377,13 @@ def _decay_lr(self, trainer_step: int) -> None: trainer_step : trainer step time t. """ train_utils.decay_lr( - self._learning_rate_scheduler_fn, self._optimizers, trainer_step + self._learning_rate_scheduler_fn, # type: ignore + self._optimizers, + trainer_step, ) -class MADQNRecurrentTrainer: +class MADQNRecurrentTrainer(mava.Trainer): """Recurrent MADQN trainer. 
This is the trainer component of a recurrent MADQN system. IE it takes a dataset @@ -769,5 +775,7 @@ def _decay_lr(self, trainer_step: int) -> None: trainer_step : trainer step time t. """ train_utils.decay_lr( - self._learning_rate_scheduler_fn, self._optimizers, trainer_step + self._learning_rate_scheduler_fn, # type: ignore + self._optimizers, + trainer_step, ) diff --git a/mava/systems/tf/value_decomposition/system.py b/mava/systems/tf/value_decomposition/system.py index c8321f091..c7cbc984e 100644 --- a/mava/systems/tf/value_decomposition/system.py +++ b/mava/systems/tf/value_decomposition/system.py @@ -89,7 +89,8 @@ def __init__( evaluator_interval: Optional[dict] = {"executor_episodes": 2}, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Initialise the system + """Initialise the system. + Args: environment_factory: function to instantiate an environment. @@ -210,7 +211,7 @@ def __init__( if isinstance(mixer, str): if mixer == "qmix": - env = environment_factory() + env = environment_factory() # type: ignore num_agents = len(env.possible_agents) mixer = QMIX(num_agents) del env @@ -230,12 +231,14 @@ def trainer( replay: reverb.Client, variable_source: MavaVariableSource, ) -> mava.core.Trainer: - """System trainer + """System trainer. + Args: trainer_id: Id of the trainer being created. replay: replay data table to pull data from. variable_source: variable server for updating network variables. + Returns: system trainer. """ @@ -253,7 +256,7 @@ def trainer( # Create the dataset dataset = self._builder.make_dataset_iterator(replay, trainer_id) - trainer: ValueDecompositionRecurrentTrainer = self._builder.make_trainer( + trainer = self._builder.make_trainer( networks=networks, trainer_networks=self._trainer_networks[trainer_id], trainer_table_entry=self._table_network_config[trainer_id], @@ -262,6 +265,7 @@ def trainer( variable_source=variable_source, ) - trainer.setup_mixer(self._mixer, self._mixer_optimizer) + if isinstance(trainer, ValueDecompositionRecurrentTrainer): + trainer.setup_mixer(self._mixer, self._mixer_optimizer) return trainer diff --git a/mava/systems/tf/value_decomposition/training.py b/mava/systems/tf/value_decomposition/training.py index 8a957ed6b..49472ad08 100644 --- a/mava/systems/tf/value_decomposition/training.py +++ b/mava/systems/tf/value_decomposition/training.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Value Decomposition trainer implementation.""" import copy @@ -63,7 +61,8 @@ def __init__( logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Initialise Value Decompostion trainer + """Initialise Value Decompostion trainer. + Args: agents: agent ids, e.g. "agent_0". agent_types: agent types, e.g. "speaker" or "listener". @@ -120,7 +119,7 @@ def __init__( self._mixer_optimizer = None def setup_mixer(self, mixer: snt.Module, mixer_optimizer: snt.Module) -> None: - """Initialize the mixer network + """Initialize the mixer network. 
Args: mixer: mixer network @@ -131,8 +130,11 @@ def setup_mixer(self, mixer: snt.Module, mixer_optimizer: snt.Module) -> None: self._mixer_optimizer = mixer_optimizer def _update_target_networks(self) -> None: - """Update the target networks using either target averaging or - by directy copying the weights of the online networks every few steps.""" + """Update the target networks. + + Using either target averaging or + by directy copying the weights of the online networks every few steps. + """ online_variables = [] target_variables = [] @@ -170,7 +172,8 @@ def _update_target_networks(self) -> None: # Forward pass that calculates loss. def _forward(self, inputs: reverb.ReplaySample) -> None: - """Trainer forward pass + """Trainer forward pass. + Args: inputs: input data from the data table (transitions) """ From 83ee69587e576e04caf4092dbcbd7f565225e7a8 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 26 Jan 2022 16:07:56 +0200 Subject: [PATCH 16/56] Fixed error caused by the typing fixes. --- mava/systems/tf/value_decomposition/system.py | 3 +-- mava/systems/tf/value_decomposition/training.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/mava/systems/tf/value_decomposition/system.py b/mava/systems/tf/value_decomposition/system.py index c7cbc984e..2ca3eafc6 100644 --- a/mava/systems/tf/value_decomposition/system.py +++ b/mava/systems/tf/value_decomposition/system.py @@ -265,7 +265,6 @@ def trainer( variable_source=variable_source, ) - if isinstance(trainer, ValueDecompositionRecurrentTrainer): - trainer.setup_mixer(self._mixer, self._mixer_optimizer) + trainer.setup_mixer(self._mixer, self._mixer_optimizer) # type: ignore return trainer diff --git a/mava/systems/tf/value_decomposition/training.py b/mava/systems/tf/value_decomposition/training.py index 49472ad08..ee7e15ede 100644 --- a/mava/systems/tf/value_decomposition/training.py +++ b/mava/systems/tf/value_decomposition/training.py @@ -15,7 +15,7 @@ """Value Decomposition trainer implementation.""" import copy -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Union import reverb import sonnet as snt @@ -334,8 +334,7 @@ def _backward(self) -> None: self._optimizers[agent_key].apply(gradients, variables) # Mixer - if self._mixer is not None: - mixer_variables: Sequence[tf.Variable] = self._mixer.trainable_variables + mixer_variables = self._mixer.trainable_variables # type: ignore gradients = tape.gradient(mixer_loss, mixer_variables) From dcbac08a3ae39d98d925f529ba1448ed9884a46a Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 26 Jan 2022 16:11:28 +0200 Subject: [PATCH 17/56] Small fix. --- examples/smac/recurrent/decentralised/run_madqn.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index 44c5c1382..825fa91ca 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -31,13 +31,10 @@ from mava.utils.environments.smac_utils import make_environment from mava.utils.loggers import logger_utils -SEQUENCE_LENGTH = 60 -MAP_NAME = "3m" - FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - MAP_NAME, + "3m", "Starcraft 2 micromanagement map name (str).", ) From de31b9c63a72a724afe1501b629c4743a05e5012 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 26 Jan 2022 16:40:35 +0200 Subject: [PATCH 18/56] Docstring coverage. 
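The `_update_target_networks` docstrings reworded in this series describe two update rules: exponential (Polyak) averaging and a periodic hard copy of the online weights. Below is a small standalone numpy sketch of both rules, with plain arrays standing in for Sonnet variables; the names `tau` and `period` are illustrative assumptions.

    import numpy as np

    def update_targets(online, target, num_steps, period=200, tau=None):
        """Update target parameters from online parameters.

        online / target: lists of numpy arrays standing in for network variables.
        tau: if given, use averaging target <- tau * online + (1 - tau) * target;
             otherwise copy the online weights every `period` steps.
        """
        if tau is not None:
            for o, t in zip(online, target):
                t[...] = tau * o + (1.0 - tau) * t
        elif num_steps % period == 0:
            for o, t in zip(online, target):
                t[...] = o
        return target

    online = [np.ones((4, 4)), np.zeros(4)]
    target = [np.zeros((4, 4)), np.ones(4)]
    update_targets(online, target, num_steps=0)            # periodic hard copy
    update_targets(online, target, num_steps=1, tau=0.01)  # Polyak averaging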
--- .../smac/recurrent/decentralised/run_madqn.py | 2 +- mava/systems/tf/madqn/execution.py | 37 ++++++++++++++----- mava/systems/tf/value_decomposition/system.py | 1 - mava/wrappers/flatland.py | 9 +++++ 4 files changed, 37 insertions(+), 12 deletions(-) diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index 825fa91ca..06b4e67d8 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +"""Run example MADQN.""" import functools from datetime import datetime diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index f7762e6b4..4f8f598b7 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """MADQN system executor implementation.""" from typing import Any, Dict, List, Optional, Tuple, Union @@ -43,7 +42,10 @@ class DQNExecutor: + """DQN executor.""" + def __init__(self, action_selectors: Dict) -> None: + """Initialise DQN executor.""" self._action_selectors = action_selectors def _get_epsilon(self) -> Union[float, np.ndarray]: @@ -88,6 +90,7 @@ def after_action_selection(self, time_t: int) -> None: def get_stats(self) -> Dict: """Return extra stats to log. + Returns: epsilon information. """ @@ -99,6 +102,7 @@ def get_stats(self) -> Dict: class MADQNFeedForwardExecutor(executors.FeedForwardExecutor, DQNExecutor): """A feed-forward executor for discrete actions. + An executor based on a feed-forward policy for each agent in the system. """ @@ -117,8 +121,8 @@ def __init__( variable_client: Optional[tf2_variable_utils.VariableClient] = None, interval: Optional[dict] = None, ): - """Initialise the system executor + Args: policy_networks: policy networks for each agent in the system. @@ -195,7 +199,7 @@ def _policy( def select_action( self, agent: str, observation: types.NestedArray ) -> Tuple[types.NestedArray, types.NestedArray]: - """select an action for a single agent in the system + """Select an action for a single agent in the system Args: agent: agent id. @@ -235,7 +239,8 @@ def observe_first( timestep: dm_env.TimeStep, extras: Dict[str, types.NestedArray] = {}, ) -> None: - """Record first observed timestep from the environment + """Record first observed timestep from the environment. + Args: timestep: data emitted by an environment at first step of interaction. @@ -265,7 +270,8 @@ def observe( next_timestep: dm_env.TimeStep, next_extras: Dict[str, types.NestedArray] = {}, ) -> None: - """record observed timestep from the environment + """Record observed timestep from the environment + Args: actions: system agents' actions. next_timestep: data emitted by an environment during @@ -308,9 +314,14 @@ def __init__( store_recurrent_state: bool = True, interval: Optional[dict] = None, ): - """Initialise the system executor + """Initialise the system executor. + Args: - policy_networks: policy networks for each agent in + action_selectors: epsilon greedy action selection + value_networks: agents value networks. 
+ variable_client: client for managing + network variable distribution + observation_networks: observation networks for each agent in the system. agent_specs: agent observation and action space specifications. @@ -355,7 +366,8 @@ def _policy( legal_actions: types.NestedTensor, state: types.NestedTensor, ) -> Tuple: - """Agent specific policy function + """Agent specific policy function. + Args: agent: agent id observation: observation tensor received from the @@ -388,11 +400,13 @@ def _policy( def select_action( self, agent: str, observation: types.NestedArray ) -> types.NestedArray: - """select an action for a single agent in the system + """select an action for a single agent in the system. + Args: agent: agent id observation: observation tensor received from the environment. + Returns: action and policy. """ @@ -415,9 +429,11 @@ def select_action( def select_actions(self, observations: Dict[str, types.NestedArray]) -> Any: """select the actions for all agents in the system + Args: observations: agent observations from the environment. + Returns: actions and policies for all agents in the system. """ @@ -476,7 +492,8 @@ def observe( next_timestep: dm_env.TimeStep, next_extras: Dict[str, types.NestedArray] = {}, ) -> None: - """record observed timestep from the environment + """Record observed timestep from the environment. + Args: actions: system agents' actions. next_timestep: data emitted by an environment during diff --git a/mava/systems/tf/value_decomposition/system.py b/mava/systems/tf/value_decomposition/system.py index 2ca3eafc6..feb102669 100644 --- a/mava/systems/tf/value_decomposition/system.py +++ b/mava/systems/tf/value_decomposition/system.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Value Decomposition system implementation.""" from typing import Callable, Dict, Mapping, Optional, Type, Union diff --git a/mava/wrappers/flatland.py b/mava/wrappers/flatland.py index 02f0032cc..2f3bdd069 100644 --- a/mava/wrappers/flatland.py +++ b/mava/wrappers/flatland.py @@ -78,6 +78,7 @@ def __init__( agent_info: bool = False, ): """Wrap Flatland environment. + Args: environment: underlying RailEnv preprocessor: optional preprocessor. Defaults to None. 
@@ -136,6 +137,7 @@ def possible_agents(self) -> List[str]: return self._possible_agents def _update_stats(self, info: Dict, rewards: Dict) -> None: + """Update flatland stats.""" episode_return = sum(list(rewards.values())) tasks_finished = sum( [1 if state == TrainState.DONE else 0 for state in info["state"].values()] @@ -149,6 +151,7 @@ def _update_stats(self, info: Dict, rewards: Dict) -> None: self._latest_completion = completion def get_stats(self) -> Dict: + """Get flatland specific stats.""" if self._latest_completion is not None and self._latest_score is not None: return { "score": self._latest_score, @@ -263,6 +266,7 @@ def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: def _convert_observations( self, observes: Dict[str, Tuple[np.array, np.ndarray]], dones: Dict[str, bool] ) -> Observation: + """Convert observation""" return convert_dm_compatible_observations( observes, dones, @@ -276,6 +280,7 @@ def _convert_observations( def _collate_obs_and_info( self, observes: Dict[int, np.ndarray], info: Dict[str, Dict[int, Any]] ) -> Dict[str, Tuple[np.array, np.ndarray]]: + """Combine observation and info.""" observations: Dict[str, Tuple[np.array, np.ndarray]] = {} observes = self.preprocessor(observes) for agent, obs in observes.items(): @@ -304,6 +309,7 @@ def _obtain_preprocessor( self, preprocessor: Any ) -> Callable[[Dict[int, Any]], Dict[int, np.ndarray]]: """Obtains the actual preprocessor. + Obtains the actual preprocessor to be used based on the supplied preprocessor and the env's obs_builder object """ @@ -326,6 +332,7 @@ def _preprocessor( return x def returned_preprocessor(obs: Dict[int, Any]) -> Dict[int, np.ndarray]: + """Return preprocessor.""" temp_obs = {} for agent_id, ob in obs.items(): temp_obs[agent_id] = _preprocessor(ob) @@ -336,6 +343,7 @@ def returned_preprocessor(obs: Dict[int, Any]) -> Dict[int, np.ndarray]: # set all parameters that should be available before an environment step # if no available agent, then environment is done and should be reset def _pre_step(self) -> None: + """Pre-step.""" if not self.agents: self._step_type = dm_env.StepType.LAST @@ -511,6 +519,7 @@ def norm_obs_clip( normalize_to_range: bool = False, ) -> np.ndarray: """Normalize observation. + This function returns the difference between min and max value of an observation :param obs: Observation that should be normalized :param clip_min: min value where observation will be clipped From 805d0a7c96f68e3455786588167c504b905be7c5 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 26 Jan 2022 16:41:12 +0200 Subject: [PATCH 19/56] Small change. --- examples/smac/recurrent/decentralised/run_madqn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index 06b4e67d8..215c179cd 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -32,6 +32,7 @@ from mava.utils.loggers import logger_utils FLAGS = flags.FLAGS + flags.DEFINE_string( "map_name", "3m", From fa70c50e278aa9344d320b34df30afed8989e8ec Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 26 Jan 2022 17:08:22 +0200 Subject: [PATCH 20/56] Remove DIAL tests. 
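The executors documented above act epsilon-greedily, with epsilon coming from a scheduler such as `LinearExplorationTimestepScheduler(epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=50000)` used in the SMAC examples. A minimal standalone sketch of that linear decay rule (an illustration, not the scheduler class itself):

    def linear_epsilon(step: int,
                       epsilon_start: float = 1.0,
                       epsilon_min: float = 0.05,
                       epsilon_decay_steps: int = 50000) -> float:
        """Linearly decay epsilon from epsilon_start down to epsilon_min."""
        fraction = min(step / epsilon_decay_steps, 1.0)
        return epsilon_start + fraction * (epsilon_min - epsilon_start)

    assert linear_epsilon(0) == 1.0
    assert abs(linear_epsilon(25000) - 0.525) < 1e-9
    assert abs(linear_epsilon(50000) - 0.05) < 1e-9
    assert abs(linear_epsilon(100000) - 0.05) < 1e-9  # clamped at the minimum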
--- tests/networks/networks_test.py | 10 ++-- tests/systems/dial_system_test.py | 86 ------------------------------- tests/systems/qmix_system_test.py | 83 ----------------------------- tests/systems/vdn_system_test.py | 83 ----------------------------- 4 files changed, 5 insertions(+), 257 deletions(-) delete mode 100644 tests/systems/dial_system_test.py delete mode 100644 tests/systems/qmix_system_test.py delete mode 100644 tests/systems/vdn_system_test.py diff --git a/tests/networks/networks_test.py b/tests/networks/networks_test.py index 86821f264..a6f7e08be 100755 --- a/tests/networks/networks_test.py +++ b/tests/networks/networks_test.py @@ -30,7 +30,7 @@ DecentralisedValueActorCritic, ) from mava.components.tf.networks.continuous import LayerNormAndResidualMLP, LayerNormMLP -from mava.systems.tf import dial, mad4pg, maddpg, madqn, mappo, qmix, vdn +from mava.systems.tf import mad4pg, maddpg, madqn, mappo from mava.utils.environments import debugging_utils FLAGS = flags.FLAGS @@ -50,7 +50,7 @@ @pytest.mark.parametrize( "system", - [maddpg, mad4pg, madqn, mappo, qmix, vdn, dial], + [maddpg, mad4pg, madqn, mappo], ) class TestNetworkAgentKeys: """Test that we get the correct agent networks from make network functions.""" @@ -310,7 +310,7 @@ def test_network_reproducibility_0_same_seed(self, network: Any) -> None: """Test with same seed, networks are the same. Args: - system (Any): network. + network: network. """ test_seed = 42 @@ -337,7 +337,7 @@ def test_network_reproducibility_1_no_seed(self, network: Any) -> None: """Test with no seed, networks are different. Args: - system (Any): network. + network (Any): network. """ network, network_params = network @@ -360,7 +360,7 @@ def test_network_reproducibility_2_diff_seed(self, network: Any) -> None: """Test with diff seeds, networks are different. Args: - system (Any): network. + network (Any): network. """ network, network_params = network test_seed = 42 diff --git a/tests/systems/dial_system_test.py b/tests/systems/dial_system_test.py deleted file mode 100644 index 3032a6f03..000000000 --- a/tests/systems/dial_system_test.py +++ /dev/null @@ -1,86 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for Dial.""" - -import functools - -import launchpad as lp -import sonnet as snt - -import mava -from mava.components.tf.modules.exploration.exploration_scheduling import ( - LinearExplorationTimestepScheduler, -) -from mava.systems.tf import dial -from mava.utils import lp_utils -from mava.utils.enums import ArchitectureType -from mava.utils.environments import debugging_utils - - -class TestDial: - """Simple integration/smoke test for dial.""" - - def test_recurrent_dial_on_debugging_env(self) -> None: - """Test recurrent dial.""" - # environment - environment_factory = functools.partial( - debugging_utils.make_environment, - env_name="simple_spread", - action_space="discrete", - ) - - # networks - network_factory = lp_utils.partial_kwargs( - dial.make_default_networks, - archecture_type=ArchitectureType.recurrent, - ) - - # system - system = dial.DIAL( - environment_factory=environment_factory, - network_factory=network_factory, - num_executors=1, - min_replay_size=16, - max_replay_size=1000, - batch_size=16, - optimizer=snt.optimizers.Adam(learning_rate=1e-3), - checkpoint=False, - sequence_length=3, - period=3, - exploration_scheduler_fn=LinearExplorationTimestepScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=500 - ), - ) - - program = system.build() - - (trainer_node,) = program.groups["trainer"] - trainer_node.disable_run() - - # Launch gpu config - don't use gpu - local_resources = lp_utils.to_device( - program_nodes=program.groups.keys(), nodes_on_gpu=[] - ) - lp.launch( - program, - launch_type="test_mt", - local_resources=local_resources, - ) - - trainer: mava.Trainer = trainer_node.create_handle().dereference() - - for _ in range(2): - trainer.step() diff --git a/tests/systems/qmix_system_test.py b/tests/systems/qmix_system_test.py deleted file mode 100644 index 5a16af7cb..000000000 --- a/tests/systems/qmix_system_test.py +++ /dev/null @@ -1,83 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for QMIX.""" - -import functools - -import launchpad as lp -import sonnet as snt - -import mava -from mava.components.tf.modules.exploration.exploration_scheduling import ( - LinearExplorationTimestepScheduler, -) -from mava.systems.tf import qmix -from mava.utils import lp_utils -from mava.utils.environments import debugging_utils - - -class TestQmix: - """Simple integration/smoke test for qmix.""" - - def test_qmix_on_debugging_env(self) -> None: - """Test feedforward qmix.""" - # environment - environment_factory = functools.partial( - debugging_utils.make_environment, - env_name="simple_spread", - action_space="discrete", - return_state_info=True, - ) - - # networks - network_factory = lp_utils.partial_kwargs( - qmix.make_default_networks, policy_networks_layer_sizes=(64, 64) - ) - - # system - system = qmix.QMIX( - environment_factory=environment_factory, - network_factory=network_factory, - num_executors=1, - batch_size=32, - min_replay_size=32, - max_replay_size=1000, - optimizer=snt.optimizers.Adam(learning_rate=1e-3), - checkpoint=False, - exploration_scheduler_fn=LinearExplorationTimestepScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=500 - ), - ) - - program = system.build() - - (trainer_node,) = program.groups["trainer"] - trainer_node.disable_run() - - # Launch gpu config - don't use gpu - local_resources = lp_utils.to_device( - program_nodes=program.groups.keys(), nodes_on_gpu=[] - ) - lp.launch( - program, - launch_type="test_mt", - local_resources=local_resources, - ) - - trainer: mava.Trainer = trainer_node.create_handle().dereference() - - for _ in range(2): - trainer.step() diff --git a/tests/systems/vdn_system_test.py b/tests/systems/vdn_system_test.py deleted file mode 100644 index 4639139e6..000000000 --- a/tests/systems/vdn_system_test.py +++ /dev/null @@ -1,83 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for VDN.""" - -import functools - -import launchpad as lp -import sonnet as snt - -import mava -from mava.components.tf.modules.exploration.exploration_scheduling import ( - LinearExplorationTimestepScheduler, -) -from mava.systems.tf import vdn -from mava.utils import lp_utils -from mava.utils.environments import debugging_utils - - -class TestVdn: - """Simple integration/smoke test for Vdn.""" - - def test_vdn_on_debugging_env(self) -> None: - """Test feedforward vdn.""" - # environment - environment_factory = functools.partial( - debugging_utils.make_environment, - env_name="simple_spread", - action_space="discrete", - return_state_info=True, - ) - - # networks - network_factory = lp_utils.partial_kwargs( - vdn.make_default_networks, policy_networks_layer_sizes=(64, 64) - ) - - # system - system = vdn.VDN( - environment_factory=environment_factory, - network_factory=network_factory, - num_executors=1, - batch_size=32, - min_replay_size=32, - max_replay_size=1000, - optimizer=snt.optimizers.Adam(learning_rate=1e-3), - checkpoint=False, - exploration_scheduler_fn=LinearExplorationTimestepScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=500 - ), - ) - - program = system.build() - - (trainer_node,) = program.groups["trainer"] - trainer_node.disable_run() - - # Launch gpu config - don't use gpu - local_resources = lp_utils.to_device( - program_nodes=program.groups.keys(), nodes_on_gpu=[] - ) - lp.launch( - program, - launch_type="test_mt", - local_resources=local_resources, - ) - - trainer: mava.Trainer = trainer_node.create_handle().dereference() - - for _ in range(2): - trainer.step() From 38ace5287d2aab24d9975f013dad9504512fb4d3 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Thu, 27 Jan 2022 12:10:28 +0200 Subject: [PATCH 21/56] Fix test. --- tests/networks/networks_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/networks/networks_test.py b/tests/networks/networks_test.py index a6f7e08be..7ec24541d 100755 --- a/tests/networks/networks_test.py +++ b/tests/networks/networks_test.py @@ -183,7 +183,9 @@ def test_network_seed_is_passed(self, system: Any) -> None: [ dict( network_mapping={ - "value_networks": "q_networks", + "value_networks": "values", + "observation_networks": "observations", + "action_selectors": "action_selectors", }, architecture=DecentralisedValueActor, system=madqn, From 4f609930aac8f4c99d2164f89bfe095b22acb83b Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Thu, 27 Jan 2022 12:57:32 +0200 Subject: [PATCH 22/56] Formatting fixes. --- .github/ISSUE_TEMPLATE/bug_report.md | 8 ++++---- docs/images/focus_fire.html | 2 +- docs/images/runaway.html | 2 +- mava/components/tf/architectures/decentralised.py | 14 ++++++++++---- mava/components/tf/modules/mixing/__init__.py | 2 +- mava/components/tf/modules/stabilising/__init__.py | 2 +- .../tf/modules/stabilising/fingerprints.py | 2 +- mava/components/tf/networks/__init__.py | 2 +- mava/systems/tf/madqn/__init__.py | 5 +---- mava/wrappers/debugging_envs.py | 12 +++++++----- mava/wrappers/meltingpot.py | 1 - 11 files changed, 28 insertions(+), 24 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 3073db92c..72ed0566b 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -12,10 +12,10 @@ A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: -1. -2. -3. -4. +1. +2. +3. +4. 
**Expected behavior** A clear and concise description of what you expected to happen. diff --git a/docs/images/focus_fire.html b/docs/images/focus_fire.html index b84a48bac..028d72abe 100644 --- a/docs/images/focus_fire.html +++ b/docs/images/focus_fire.html @@ -428,4 +428,4 @@ MTAw "> Your browser does not support the video tag. - \ No newline at end of file + diff --git a/docs/images/runaway.html b/docs/images/runaway.html index f52c7de2a..d92bb9189 100644 --- a/docs/images/runaway.html +++ b/docs/images/runaway.html @@ -1272,4 +1272,4 @@ dAAAACWpdG9vAAAAHWRhdGEAAAABAAAAAExhdmY1OC4yOS4xMDA= "> Your browser does not support the video tag. - \ No newline at end of file + diff --git a/mava/components/tf/architectures/decentralised.py b/mava/components/tf/architectures/decentralised.py index 020ebd0a8..8138e2309 100644 --- a/mava/components/tf/architectures/decentralised.py +++ b/mava/components/tf/architectures/decentralised.py @@ -77,7 +77,7 @@ def create_actor_variables(self) -> Dict[str, Dict[str, snt.Module]]: "target_values": {}, "observations": {}, "target_observations": {}, - "selectors": {} + "selectors": {}, } # get actor specs @@ -88,12 +88,18 @@ def create_actor_variables(self) -> Dict[str, Dict[str, snt.Module]]: agent_net_key = self._agent_net_keys[agent_key] obs_spec = actor_obs_specs[agent_key] # Create variables for observation and value networks. - embed = tf2_utils.create_variables(self._observation_networks[agent_net_key], [obs_spec]) + embed = tf2_utils.create_variables( + self._observation_networks[agent_net_key], [obs_spec] + ) tf2_utils.create_variables(self._value_networks[agent_net_key], [embed]) # Create target value and observation network variables - embed = tf2_utils.create_variables(self._target_observation_networks[agent_net_key], [obs_spec]) - tf2_utils.create_variables(self._target_value_networks[agent_net_key], [embed]) + embed = tf2_utils.create_variables( + self._target_observation_networks[agent_net_key], [obs_spec] + ) + tf2_utils.create_variables( + self._target_value_networks[agent_net_key], [embed] + ) actor_networks["values"] = self._value_networks actor_networks["target_values"] = self._target_value_networks diff --git a/mava/components/tf/modules/mixing/__init__.py b/mava/components/tf/modules/mixing/__init__.py index 425ad888d..324178853 100644 --- a/mava/components/tf/modules/mixing/__init__.py +++ b/mava/components/tf/modules/mixing/__init__.py @@ -14,4 +14,4 @@ # limitations under the License. """Value decomposition mixing modules.""" -from mava.components.tf.modules.mixing.mixers import BaseMixer, VDN, QMIX +from mava.components.tf.modules.mixing.mixers import QMIX, VDN, BaseMixer diff --git a/mava/components/tf/modules/stabilising/__init__.py b/mava/components/tf/modules/stabilising/__init__.py index 6ccea8f93..b71bd102d 100644 --- a/mava/components/tf/modules/stabilising/__init__.py +++ b/mava/components/tf/modules/stabilising/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""MARL experience replay stabilising modules.""" \ No newline at end of file +"""MARL experience replay stabilising modules.""" diff --git a/mava/components/tf/modules/stabilising/fingerprints.py b/mava/components/tf/modules/stabilising/fingerprints.py index 33d599bb2..1aedc6bc2 100644 --- a/mava/components/tf/modules/stabilising/fingerprints.py +++ b/mava/components/tf/modules/stabilising/fingerprints.py @@ -14,4 +14,4 @@ # limitations under the License. """Experience replay stabilisation with fingerprints""" -# TODO (Claude) implement fingerprints for new MADQN system. \ No newline at end of file +# TODO (Claude) implement fingerprints for new MADQN system. diff --git a/mava/components/tf/networks/__init__.py b/mava/components/tf/networks/__init__.py index 304bac090..45367cda8 100644 --- a/mava/components/tf/networks/__init__.py +++ b/mava/components/tf/networks/__init__.py @@ -30,4 +30,4 @@ from mava.components.tf.networks.mad4pg import ( DiscreteValuedDistribution, DiscreteValuedHead, -) \ No newline at end of file +) diff --git a/mava/systems/tf/madqn/__init__.py b/mava/systems/tf/madqn/__init__.py index f156e6be4..a3a9159a7 100644 --- a/mava/systems/tf/madqn/__init__.py +++ b/mava/systems/tf/madqn/__init__.py @@ -21,7 +21,4 @@ ) from mava.systems.tf.madqn.networks import make_default_networks from mava.systems.tf.madqn.system import MADQN -from mava.systems.tf.madqn.training import ( - MADQNRecurrentTrainer, - MADQNTrainer, -) +from mava.systems.tf.madqn.training import MADQNRecurrentTrainer, MADQNTrainer diff --git a/mava/wrappers/debugging_envs.py b/mava/wrappers/debugging_envs.py index 16ce51673..27363aa67 100644 --- a/mava/wrappers/debugging_envs.py +++ b/mava/wrappers/debugging_envs.py @@ -125,7 +125,9 @@ def _convert_observations( # accordingly if isinstance(self._environment.action_spaces[agent], spaces.Discrete): legals = np.ones( - _convert_to_spec(self._environment.action_spaces[agent]).num_values, + _convert_to_spec( + self._environment.action_spaces[agent] + ).num_values, dtype=self._environment.action_spaces[agent].dtype, ) else: @@ -149,10 +151,10 @@ def observation_spec(self) -> Dict[str, OLT]: # Legals spec if isinstance(self._environment.action_spaces[agent], spaces.Discrete): - legals = np.ones( - _convert_to_spec(self._environment.action_spaces[agent]).num_values, - dtype=self._environment.action_spaces[agent].dtype, - ) + legals = np.ones( + _convert_to_spec(self._environment.action_spaces[agent]).num_values, + dtype=self._environment.action_spaces[agent].dtype, + ) else: legals = np.ones( _convert_to_spec(self._environment.action_spaces[agent]).shape, diff --git a/mava/wrappers/meltingpot.py b/mava/wrappers/meltingpot.py index 1f7d9c9ae..17338969f 100644 --- a/mava/wrappers/meltingpot.py +++ b/mava/wrappers/meltingpot.py @@ -27,7 +27,6 @@ try: import pygame # type: ignore - from meltingpot.python.scenario import Scenario # type: ignore from meltingpot.python.substrate import Substrate # type: ignore except ModuleNotFoundError: From 9adb2047a4ee074eb6ae28ba2a171a291bddcdc1 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Thu, 27 Jan 2022 12:58:41 +0200 Subject: [PATCH 23/56] Small docstring fix. 
--- mava/systems/tf/madqn/networks.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/mava/systems/tf/madqn/networks.py b/mava/systems/tf/madqn/networks.py index bf2fca57e..35c8698e9 100644 --- a/mava/systems/tf/madqn/networks.py +++ b/mava/systems/tf/madqn/networks.py @@ -39,25 +39,17 @@ def make_default_networks( atari_torso_observation_network: bool = False, seed: Optional[int] = None, ) -> Mapping[str, types.TensorTransformation]: - """Default networks for maddpg. + """Default networks for madqn. Args: environment_spec: description of the action and observation spaces etc. for each agent in the system. agent_net_keys: specifies what network each agent uses. - vmin: hyperparameters for the distributional critic in mad4pg. - vmax: hyperparameters for the distributional critic in mad4pg. net_spec_keys: specifies the specs of each network. - policy_networks_layer_sizes: size of policy networks. - critic_networks_layer_sizes: size of critic networks. - sigma: hyperparameters used to add Gaussian noise - for simple exploration. Defaults to 0.3. + value_networks_layer_sizes: size of value networks. archecture_type: archecture used for agent networks. Can be feedforward or recurrent. Defaults to ArchitectureType.feedforward. - - num_atoms: hyperparameters for the distributional critic in - mad4pg. seed: random seed for network initialization. Returns: From 10485273226118caf9cc2ce106368b9b8fce016e Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Thu, 27 Jan 2022 13:48:11 +0200 Subject: [PATCH 24/56] Typing errors. --- .../run_madqn_custom_lr_schedule.py | 2 +- examples/meltingpot/test_on_scenarios.py | 1 - examples/meltingpot/train_on_substrates.py | 1 - mava/environment_loop.py | 3 +- tests/utils/environment_utils_test.py | 44 ++++++++----------- 5 files changed, 21 insertions(+), 30 deletions(-) diff --git a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py index f337a2cbb..c9d4f94d7 100644 --- a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py +++ b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py @@ -109,7 +109,7 @@ def main(_: Any) -> None: ), optimizer=snt.optimizers.Adam(learning_rate=lr_start), checkpoint_subpath=checkpoint_dir, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, + learning_rate_scheduler_fn=learning_rate_scheduler_fn, #typing: ignore ).build() # Ensure only trainer runs on gpu, while other processes run on cpu. 
diff --git a/examples/meltingpot/test_on_scenarios.py b/examples/meltingpot/test_on_scenarios.py index 3183fa1fa..43b91588b 100644 --- a/examples/meltingpot/test_on_scenarios.py +++ b/examples/meltingpot/test_on_scenarios.py @@ -77,7 +77,6 @@ def main(_: Any) -> None: exploration_scheduler_fn=LinearExplorationScheduler( epsilon_min=0.05, epsilon_decay=1e-4 ), - importance_sampling_exponent=0.2, optimizer=snt.optimizers.Adam(learning_rate=1e-4), checkpoint_subpath=checkpoint_dir, ).build() diff --git a/examples/meltingpot/train_on_substrates.py b/examples/meltingpot/train_on_substrates.py index cd0092480..9b8f8ce19 100644 --- a/examples/meltingpot/train_on_substrates.py +++ b/examples/meltingpot/train_on_substrates.py @@ -76,7 +76,6 @@ def main(_: Any) -> None: exploration_scheduler_fn=LinearExplorationScheduler( epsilon_min=0.05, epsilon_decay=1e-4 ), - importance_sampling_exponent=0.2, optimizer=snt.optimizers.Adam(learning_rate=1e-4), checkpoint_subpath=checkpoint_dir, ).build() diff --git a/mava/environment_loop.py b/mava/environment_loop.py index ca22a5817..34690112a 100644 --- a/mava/environment_loop.py +++ b/mava/environment_loop.py @@ -545,7 +545,8 @@ def should_run_loop(eval_condtion: Tuple) -> bool: # Log the given results. self._logger.write(result) else: - # Note: We assume that the evaluator will be running less than once per second. + # Note: We assume that the evaluator will be running less + # than once per second. time.sleep(1) # We need to get the latest counts if we are using eval intervals. if environment_loop_schedule: diff --git a/tests/utils/environment_utils_test.py b/tests/utils/environment_utils_test.py index d780ecf7a..bde316a8e 100644 --- a/tests/utils/environment_utils_test.py +++ b/tests/utils/environment_utils_test.py @@ -20,25 +20,13 @@ from mava.utils.environments import debugging_utils, pettingzoo_utils try: - from flatland.envs.observations import TreeObsForRailEnv - from flatland.envs.rail_generators import sparse_rail_generator - from flatland.envs.schedule_generators import sparse_schedule_generator - - from mava.utils.environments.flatland_utils import flatland_env_factory - + import flatland + from mava.utils.environments import flatland_utils _has_flatland = True except (ModuleNotFoundError, ImportError): _has_flatland = False pass -if _has_flatland: - rail_gen_cfg: Dict = { - "max_num_cities": 4, - "max_rails_between_cities": 2, - "max_rails_in_city": 3, - "grid_mode": True, - } - @pytest.mark.parametrize( "env", @@ -47,17 +35,21 @@ (debugging_utils.make_environment, {}), (pettingzoo_utils.make_environment, {}), ( - flatland_env_factory, + flatland_utils.make_environment, { - "env_config": { - "number_of_agents": 2, - "width": 25, - "height": 25, - "rail_generator": sparse_rail_generator(**rail_gen_cfg), - "schedule_generator": sparse_schedule_generator(), - "obs_builder_object": TreeObsForRailEnv(max_depth=2), - } - }, + "n_agents": 3, + "x_dim": 30, + "y_dim": 30, + "n_cities": 2, + "max_rails_between_cities": 2, + "max_rails_in_city": 3, + "seed": 0, + "malfunction_rate": 1 / 200, + "malfunction_min_duration": 20, + "malfunction_max_duration": 50, + "observation_max_path_depth": 30, + "observation_tree_depth": 2, + } ) if _has_flatland else None, @@ -111,7 +103,7 @@ def test_env_reproducibility_1_no_seed_different_observation( # This test doesn't work with flatland and SC2, since FL uses # a default seed (1) and SC2 (5), even when a seed is not provided. 
- if _has_flatland and env_factory == flatland_env_factory: + if _has_flatland and env_factory == flatland_utils.make_environment: pytest.skip("Skipping no seed test for flatland and SC2.") wrapped_env = env_factory(**env_params) @@ -150,7 +142,7 @@ def test_env_reproducibility_1_different_seed_different_observation( # This test doesn't work with flatland, since FL seeds # at ini for SparseRailGen . - if _has_flatland and env_factory == flatland_env_factory: + if _has_flatland and env_factory == flatland_utils.make_environment: pytest.skip("Skipping diff seed test for flatland.") wrapped_env = env_factory(random_seed=test_seed1, **env_params) From 57aba3aeeda427584d7d34e5e275d2bf6dd79b4a Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Thu, 27 Jan 2022 13:54:48 +0200 Subject: [PATCH 25/56] More typing errors. --- .../run_madqn_custom_lr_schedule.py | 2 +- mava/environment_loop.py | 2 +- tests/conftest.py | 35 ++++++++----------- tests/utils/environment_utils_test.py | 6 ++-- 4 files changed, 19 insertions(+), 26 deletions(-) diff --git a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py index c9d4f94d7..44457324d 100644 --- a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py +++ b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py @@ -109,7 +109,7 @@ def main(_: Any) -> None: ), optimizer=snt.optimizers.Adam(learning_rate=lr_start), checkpoint_subpath=checkpoint_dir, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, #typing: ignore + learning_rate_scheduler_fn=learning_rate_scheduler_fn, # typing: ignore ).build() # Ensure only trainer runs on gpu, while other processes run on cpu. diff --git a/mava/environment_loop.py b/mava/environment_loop.py index 34690112a..875890d50 100644 --- a/mava/environment_loop.py +++ b/mava/environment_loop.py @@ -545,7 +545,7 @@ def should_run_loop(eval_condtion: Tuple) -> bool: # Log the given results. self._logger.write(result) else: - # Note: We assume that the evaluator will be running less + # Note: We assume that the evaluator will be running less # than once per second. time.sleep(1) # We need to get the latest counts if we are using eval intervals. 
diff --git a/tests/conftest.py b/tests/conftest.py index 0ff3e4845..fde5faea4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,11 +24,7 @@ import pytest try: - from flatland.envs.observations import TreeObsForRailEnv - from flatland.envs.rail_generators import sparse_rail_generator - from flatland.envs.schedule_generators import sparse_schedule_generator - - from mava.utils.environments.flatland_utils import load_flatland_env + from mava.utils.environments import flatland_utils from mava.wrappers.flatland import FlatlandEnvWrapper _has_flatland = True @@ -63,23 +59,20 @@ if _has_flatland: # flatland environment config - rail_gen_cfg: Dict = { - "max_num_cities": 4, + flatland_env_config = { + "n_agents": 3, + "x_dim": 30, + "y_dim": 30, + "n_cities": 2, "max_rails_between_cities": 2, "max_rails_in_city": 3, - "grid_mode": True, - "seed": 42, - } - - flatland_env_config: Dict = { - "number_of_agents": 2, - "width": 25, - "height": 25, - "rail_generator": sparse_rail_generator(**rail_gen_cfg), - "schedule_generator": sparse_schedule_generator(), - "obs_builder_object": TreeObsForRailEnv(max_depth=2), - } - + "seed": 0, + "malfunction_rate": 1 / 200, + "malfunction_min_duration": 20, + "malfunction_max_duration": 50, + "observation_max_path_depth": 30, + "observation_tree_depth": 2, + }, """ Helpers contains re-usable test functions. @@ -122,7 +115,7 @@ def get_env(env_spec: EnvSpec) -> Union[AECEnv, ParallelEnv]: elif env_spec.env_type == EnvType.Sequential: env = mod.env() # type:ignore elif env_spec.env_source == EnvSource.Flatland: - env = load_flatland_env(flatland_env_config) + env = flatland_utils.make_environment(**flatland_env_config) elif env_spec.env_source == EnvSource.OpenSpiel: env = load_open_spiel_env(env_spec.env_name) else: diff --git a/tests/utils/environment_utils_test.py b/tests/utils/environment_utils_test.py index bde316a8e..defa7bca4 100644 --- a/tests/utils/environment_utils_test.py +++ b/tests/utils/environment_utils_test.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import numpy as np import pytest @@ -20,8 +20,8 @@ from mava.utils.environments import debugging_utils, pettingzoo_utils try: - import flatland from mava.utils.environments import flatland_utils + _has_flatland = True except (ModuleNotFoundError, ImportError): _has_flatland = False @@ -49,7 +49,7 @@ "malfunction_max_duration": 50, "observation_max_path_depth": 30, "observation_tree_depth": 2, - } + }, ) if _has_flatland else None, From 94c2c838e7656d0f6cd06cb20618beeecbb6d8bf Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Thu, 27 Jan 2022 13:57:30 +0200 Subject: [PATCH 26/56] Typo. 
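Several of the typing fixes above touch the custom learning-rate schedule example, where a `learning_rate_scheduler_fn` is passed to the system and forwarded by the trainers to `train_utils.decay_lr` together with the optimizers and the current trainer step. The exact callable signature Mava expects is not shown in these hunks, but the underlying idea is a schedule over trainer steps; a generic standalone sketch, where the constants and the helper name are assumptions:

    def exponential_lr(trainer_step: int,
                       lr_start: float = 1e-3,
                       decay_rate: float = 0.99,
                       decay_every: int = 1000) -> float:
        """Learning rate to use at a given trainer step."""
        return lr_start * decay_rate ** (trainer_step // decay_every)

    # If the optimizer stores its learning rate as a tf.Variable, a schedule
    # like this can be applied in place before each update, e.g.
    #   optimizer.learning_rate.assign(exponential_lr(step))
    assert exponential_lr(0) == 1e-3
    assert exponential_lr(2000) < exponential_lr(1000) < exponential_lr(0)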
--- .../run_madqn_custom_lr_schedule.py | 2 +- tests/conftest.py | 26 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py index 44457324d..eeac75457 100644 --- a/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py +++ b/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py @@ -109,7 +109,7 @@ def main(_: Any) -> None: ), optimizer=snt.optimizers.Adam(learning_rate=lr_start), checkpoint_subpath=checkpoint_dir, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, # typing: ignore + learning_rate_scheduler_fn=learning_rate_scheduler_fn, # type: ignore ).build() # Ensure only trainer runs on gpu, while other processes run on cpu. diff --git a/tests/conftest.py b/tests/conftest.py index fde5faea4..ad51f76da 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -60,19 +60,19 @@ if _has_flatland: # flatland environment config flatland_env_config = { - "n_agents": 3, - "x_dim": 30, - "y_dim": 30, - "n_cities": 2, - "max_rails_between_cities": 2, - "max_rails_in_city": 3, - "seed": 0, - "malfunction_rate": 1 / 200, - "malfunction_min_duration": 20, - "malfunction_max_duration": 50, - "observation_max_path_depth": 30, - "observation_tree_depth": 2, - }, + "n_agents": 3, + "x_dim": 30, + "y_dim": 30, + "n_cities": 2, + "max_rails_between_cities": 2, + "max_rails_in_city": 3, + "seed": 0, + "malfunction_rate": 1 / 200, + "malfunction_min_duration": 20, + "malfunction_max_duration": 50, + "observation_max_path_depth": 30, + "observation_tree_depth": 2, + } """ Helpers contains re-usable test functions. From 0de21ffbf778cab35b1ef3fb747dbac99309bce5 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Thu, 27 Jan 2022 14:00:07 +0200 Subject: [PATCH 27/56] Type ignore in test. --- tests/conftest.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index ad51f76da..ab2bb05de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -60,19 +60,19 @@ if _has_flatland: # flatland environment config flatland_env_config = { - "n_agents": 3, - "x_dim": 30, - "y_dim": 30, - "n_cities": 2, - "max_rails_between_cities": 2, - "max_rails_in_city": 3, - "seed": 0, - "malfunction_rate": 1 / 200, - "malfunction_min_duration": 20, - "malfunction_max_duration": 50, - "observation_max_path_depth": 30, - "observation_tree_depth": 2, - } + "n_agents": 3, + "x_dim": 30, + "y_dim": 30, + "n_cities": 2, + "max_rails_between_cities": 2, + "max_rails_in_city": 3, + "seed": 0, + "malfunction_rate": 1 / 200, + "malfunction_min_duration": 20, + "malfunction_max_duration": 50, + "observation_max_path_depth": 30, + "observation_tree_depth": 2, + } """ Helpers contains re-usable test functions. 
@@ -115,7 +115,7 @@ def get_env(env_spec: EnvSpec) -> Union[AECEnv, ParallelEnv]: elif env_spec.env_type == EnvType.Sequential: env = mod.env() # type:ignore elif env_spec.env_source == EnvSource.Flatland: - env = flatland_utils.make_environment(**flatland_env_config) + env = flatland_utils.make_environment(**flatland_env_config) #type:ignore elif env_spec.env_source == EnvSource.OpenSpiel: env = load_open_spiel_env(env_spec.env_name) else: From edad908ed07e905aa1044f9dde91d20a72d12a1e Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Thu, 27 Jan 2022 14:00:51 +0200 Subject: [PATCH 28/56] Docstrings. --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index ab2bb05de..459e6957c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -115,7 +115,7 @@ def get_env(env_spec: EnvSpec) -> Union[AECEnv, ParallelEnv]: elif env_spec.env_type == EnvType.Sequential: env = mod.env() # type:ignore elif env_spec.env_source == EnvSource.Flatland: - env = flatland_utils.make_environment(**flatland_env_config) #type:ignore + env = flatland_utils.make_environment(**flatland_env_config) # type:ignore elif env_spec.env_source == EnvSource.OpenSpiel: env = load_open_spiel_env(env_spec.env_name) else: From 30661d1c9606382f9202501d964a35dc706bebfc Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Thu, 27 Jan 2022 14:43:17 +0200 Subject: [PATCH 29/56] Flatland wrapper import error. --- mava/wrappers/flatland.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mava/wrappers/flatland.py b/mava/wrappers/flatland.py index 2f3bdd069..fbf807500 100644 --- a/mava/wrappers/flatland.py +++ b/mava/wrappers/flatland.py @@ -23,14 +23,10 @@ import numpy as np from acme import specs from acme.wrappers.gym_wrapper import _convert_to_spec - -try: - from flatland.envs.observations import GlobalObsForRailEnv, Node, TreeObsForRailEnv - from flatland.envs.rail_env import RailEnv - from flatland.envs.step_utils.states import TrainState - from flatland.utils.rendertools import AgentRenderVariant, RenderTool -except ModuleNotFoundError: - pass +from flatland.envs.observations import GlobalObsForRailEnv, Node, TreeObsForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.step_utils.states import TrainState +from flatland.utils.rendertools import AgentRenderVariant, RenderTool from gym.spaces import Discrete from gym.spaces.box import Box @@ -46,6 +42,7 @@ class FlatlandEnvWrapper(ParallelEnvWrapper): """Environment wrapper for Flatland environments. + All environments would require an observation preprocessor, except for 'GlobalObsForRailEnv'. This is because flatland gives users the flexibility of designing custom observation builders. 'TreeObsForRailEnv' @@ -460,6 +457,7 @@ def get_agent_handle(id: str) -> int: def decorate_step_method(env: RailEnv) -> None: """Step method decorator. + Enable the step method of the env to take action dictionaries where agent keys are the agent ids. Flatland uses the agent handles as keys instead. This function decorates the step method so that it accepts an action dict where the keys are the @@ -485,6 +483,7 @@ def _step( def max_lt(seq: Sequence, val: Any) -> Any: """Get max in sequence. + Return greatest item in seq for which item < val applies. None is returned if seq was empty or all items in seq were >= val. """ @@ -499,6 +498,7 @@ def max_lt(seq: Sequence, val: Any) -> Any: def min_gt(seq: Sequence, val: Any) -> Any: """Gets min in a sequence. 
+ Return smallest item in seq for which item > val applies. None is returned if seq was empty or all items in seq were >= val. """ From 476b1b5f07e9543c949490117256610a05b6e160 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Thu, 27 Jan 2022 15:50:11 +0200 Subject: [PATCH 30/56] Fix mypy issues. --- mava/utils/wrapper_utils.py | 2 +- mava/wrappers/flatland.py | 31 ++++++++++++++++++------------- mava/wrappers/smac.py | 4 ++-- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/mava/utils/wrapper_utils.py b/mava/utils/wrapper_utils.py index e336ba67a..2fbb1bf7c 100644 --- a/mava/utils/wrapper_utils.py +++ b/mava/utils/wrapper_utils.py @@ -25,7 +25,7 @@ def convert_dm_compatible_observations( - observes: Dict[str, np.ndarray], + observes: Dict, dones: Dict[str, bool], observation_spec: Dict[str, types.OLT], env_done: bool, diff --git a/mava/wrappers/flatland.py b/mava/wrappers/flatland.py index fbf807500..31a19a0fa 100644 --- a/mava/wrappers/flatland.py +++ b/mava/wrappers/flatland.py @@ -17,7 +17,7 @@ import types as tp from functools import partial -from typing import Any, Callable, Dict, List, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union import dm_env import numpy as np @@ -157,7 +157,7 @@ def get_stats(self) -> Dict: else: return {} - def render(self, mode: str = "human") -> np.array: + def render(self, mode: str = "human") -> np.ndarray: """Renders the environment.""" if mode == "human": show = True @@ -194,11 +194,11 @@ def reset(self) -> dm_env.TimeStep: } discount_spec = self.discount_spec() - self._discounts = { + discounts = { agent: convert_np_type(discount_spec[agent].dtype, 1) for agent in self.possible_agents } - return parameterized_restart(rewards, self._discounts, observations), {} + return parameterized_restart(rewards, discounts, observations), {} def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: """Steps the environment.""" @@ -246,7 +246,12 @@ def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: # TODO (Claude) zero discount! else: self._step_type = dm_env.StepType.MID - discounts = self._discounts # discount == 1 + discounts = { + agent: convert_np_type( + self.discount_spec()[agent].dtype, 1 + ) # discount = 1 + for agent in self.possible_agents + } return ( dm_env.TimeStep( @@ -261,7 +266,7 @@ def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: # Convert Flatland observation so it's dm_env compatible. Also, the list # of legal actions must be converted to a legal actions mask. 
def _convert_observations( - self, observes: Dict[str, Tuple[np.array, np.ndarray]], dones: Dict[str, bool] + self, observes: Dict[str, Tuple[np.ndarray, np.ndarray]], dones: Dict[str, bool] ) -> Observation: """Convert observation""" return convert_dm_compatible_observations( @@ -276,17 +281,17 @@ def _convert_observations( # be a tuple of the observation from the env and the agent info def _collate_obs_and_info( self, observes: Dict[int, np.ndarray], info: Dict[str, Dict[int, Any]] - ) -> Dict[str, Tuple[np.array, np.ndarray]]: + ) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: """Combine observation and info.""" - observations: Dict[str, Tuple[np.array, np.ndarray]] = {} + observations: Dict = {} observes = self.preprocessor(observes) for agent, obs in observes.items(): agent_id = get_agent_id(agent) agent_info = np.array( [info[k][agent] for k in sort_str_num(info.keys())], dtype=np.float32 ) - obs = (obs, agent_info) if self._include_agent_info else obs - observations[agent_id] = obs + new_obs = (obs, agent_info) if self._include_agent_info else obs + observations[agent_id] = new_obs return observations @@ -481,7 +486,7 @@ def _step( # serve as the default preprocessor for the Tree obs builder. -def max_lt(seq: Sequence, val: Any) -> Any: +def max_lt(seq: np.ndarray, val: Any) -> Any: """Get max in sequence. Return greatest item in seq for which item < val applies. @@ -496,7 +501,7 @@ def max_lt(seq: Sequence, val: Any) -> Any: return max -def min_gt(seq: Sequence, val: Any) -> Any: +def min_gt(seq: np.ndarray, val: Any) -> Any: """Gets min in a sequence. Return smallest item in seq for which item > val applies. @@ -569,7 +574,7 @@ def _split_node_into_feature_groups( def _split_subtree_into_feature_groups( node: Node, current_tree_depth: int, max_tree_depth: int -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: +) -> Tuple: """Split subtree.""" if node == -np.inf: remaining_depth = max_tree_depth - current_tree_depth diff --git a/mava/wrappers/smac.py b/mava/wrappers/smac.py index 3043ada15..46429d64f 100644 --- a/mava/wrappers/smac.py +++ b/mava/wrappers/smac.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Wraps a PettingZoo MARL environment to be used as a dm_env environment.""" +"""Wraper for SMAC.""" from typing import Any, Dict, List, Optional, Union import dm_env @@ -170,7 +170,7 @@ def _convert_reward(self, reward: float) -> Dict[str, float]: rewards[agent] = convert_np_type(rewards_spec[agent].dtype, reward) return rewards - def _get_legal_actions(self) -> np.ndarray: + def _get_legal_actions(self) -> List: legal_actions = [] for i, _ in enumerate(self._agents): legal_actions.append( From 97bae04e8f13fdce196e698ed8544fa136b56b3a Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Fri, 28 Jan 2022 10:54:06 +0200 Subject: [PATCH 31/56] Fixed tf.function bug. Big system speed-up. 
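[Editor's note] The speed-up in this commit comes mainly from the madqn/execution.py hunks below: the @tf.function decorator moves off the per-agent _policy call and onto a _select_actions method that loops over every agent, and the numpy conversion (tree.map_structure(tf2_utils.to_numpy_squeeze, actions)) happens once outside the compiled function. A rough sketch of the pattern, using toy Keras layers in place of Mava's value networks (names here are illustrative, not the system's API):

    import tensorflow as tf

    # Toy stand-ins for the per-agent value networks (illustrative only).
    nets = {f"agent_{i}": tf.keras.layers.Dense(5) for i in range(3)}

    @tf.function
    def select_actions(observations):
        # One traced graph runs the whole loop over agents per environment step.
        return {
            agent: tf.argmax(nets[agent](obs), axis=-1)
            for agent, obs in observations.items()
        }

    observations = {agent: tf.random.normal([1, 8]) for agent in nets}
    actions = select_actions(observations)  # later calls with the same structure reuse the trace

One environment step then costs a single graph call instead of one Python-level network call per agent, which is consistent with the "big system speed-up" noted in the subject line.
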
--- .../recurrent/decentralised/run_madqn.py | 3 +- .../smac/recurrent/centralised/run_qmix.py | 2 +- .../smac/recurrent/decentralised/run_madqn.py | 2 +- .../decentralised/run_madqn_scale_trainers.py | 1 + mava/systems/tf/mad4pg/training.py | 27 ++++ mava/systems/tf/madqn/execution.py | 116 ++++++++---------- mava/systems/tf/madqn/training.py | 2 +- .../environments/RoboCup_env/__init__.py | 1 + .../robocup_utils/player_world_model.py | 8 +- .../robocup_utils/trainer_world_model.py | 16 ++- mava/wrappers/env_wrappers.py | 8 +- 11 files changed, 101 insertions(+), 85 deletions(-) diff --git a/examples/flatland/recurrent/decentralised/run_madqn.py b/examples/flatland/recurrent/decentralised/run_madqn.py index d6f7de0c8..b3285721b 100644 --- a/examples/flatland/recurrent/decentralised/run_madqn.py +++ b/examples/flatland/recurrent/decentralised/run_madqn.py @@ -43,7 +43,7 @@ # flatland environment config env_config: Dict = { - "n_agents": 3, + "n_agents": 10, "x_dim": 30, "y_dim": 30, "n_cities": 2, @@ -59,6 +59,7 @@ def main(_: Any) -> None: + """Run example""" # Environment. environment_factory = functools.partial(make_environment, **env_config) diff --git a/examples/smac/recurrent/centralised/run_qmix.py b/examples/smac/recurrent/centralised/run_qmix.py index f0e42dbcb..624a39cd0 100644 --- a/examples/smac/recurrent/centralised/run_qmix.py +++ b/examples/smac/recurrent/centralised/run_qmix.py @@ -33,7 +33,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - "3m", + "8m", "Starcraft 2 micromanagement map name (str).", ) diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index 215c179cd..b8d1030c5 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -35,7 +35,7 @@ flags.DEFINE_string( "map_name", - "3m", + "10m_vs_11m", "Starcraft 2 micromanagement map name (str).", ) diff --git a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py index e759c08c6..c4b52a385 100644 --- a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py +++ b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py @@ -47,6 +47,7 @@ def main(_: Any) -> None: + """Main function.""" # Environment. environment_factory = functools.partial( diff --git a/mava/systems/tf/mad4pg/training.py b/mava/systems/tf/mad4pg/training.py index d23b45fb0..31735e95f 100644 --- a/mava/systems/tf/mad4pg/training.py +++ b/mava/systems/tf/mad4pg/training.py @@ -707,6 +707,7 @@ def __init__( bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): + """Init trainer.""" super().__init__( agents=agents, @@ -764,6 +765,7 @@ def __init__( bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): + """Init trainer.""" super().__init__( agents=agents, @@ -820,6 +822,31 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, ): + """Initialise trainer. 
+ + Args: + agents: [description] + agent_types: [description] + policy_networks: [description] + critic_networks: [description] + target_policy_networks: [description] + target_critic_networks: [description] + policy_optimizer: [description] + critic_optimizer: [description] + discount: [description] + target_averaging: [description] + target_update_period: [description] + target_update_rate: [description] + dataset: [description] + observation_networks: [description] + target_observation_networks: [description] + variable_client: [description] + counts: [description] + agent_net_keys: [description] + max_gradient_norm: [description]. Defaults to None. + logger: [description]. Defaults to None. + bootstrap_n: [description]. Defaults to 10. + """ super().__init__( agents=agents, diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index 4f8f598b7..979c29832 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -20,6 +20,7 @@ import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp +import tree from acme import types from acme.specs import EnvironmentSpec @@ -124,8 +125,11 @@ def __init__( """Initialise the system executor Args: - policy_networks: policy networks for each agent in + value_networks: value networks for each agent in the system. + observation_networks: value networks for each agent in + the system. + action_selectors: eg. epsilon_greedy action selector agent_specs: agent observation and action space specifications. agent_net_keys: specifies what network each agent uses. @@ -135,8 +139,7 @@ def __init__( adder: adder which sends data to a replay buffer. Defaults to None. counts: Count values used to record excutor episode and steps. - variable_client: - client to copy weights from the trainer. Defaults to None. + variable_client: client to copy weights from the trainer. Defaults to None. evaluator: whether the executor will be used for evaluation. interval: interval that evaluations are run at. @@ -157,7 +160,6 @@ def __init__( self._adder = adder self._variable_client = variable_client - @tf.function def _policy( self, agent: str, @@ -196,28 +198,20 @@ def _policy( return action - def select_action( - self, agent: str, observation: types.NestedArray - ) -> Tuple[types.NestedArray, types.NestedArray]: - """Select an action for a single agent in the system - - Args: - agent: agent id. - observation: observation tensor received - from the environment. - - Returns: - agent action and policy. - """ - # Get the action from the policy, conditioned on the observation - action = self._policy(agent, observation.observation, observation.legal_actions) - - # Return a numpy array with squeezed out batch dimension. - action = tf2_utils.to_numpy_squeeze(action) - - return action + @tf.function + def _select_actions( + self, observations: Dict[str, types.NestedArray] + ) -> types.NestedArray: + actions = {} + for agent, observation in observations.items(): + actions[agent] = self._policy( + agent, observation.observation, observation.legal_actions + ) + return actions - def select_actions(self, observations: Dict[str, types.NestedArray]) -> Dict: + def select_actions( + self, observations: Dict[str, types.NestedArray] + ) -> types.NestedArray: """Select the actions for all agents in the system Args: @@ -228,9 +222,9 @@ def select_actions(self, observations: Dict[str, types.NestedArray]) -> Dict: actions and policies for all agents in the system. 
""" - actions = {} - for agent, observation in observations.items(): - actions[agent] = self.select_action(agent, observation) + actions = self._select_actions(observations) + + actions = tree.map_structure(tf2_utils.to_numpy_squeeze, actions) return actions @@ -250,7 +244,7 @@ def observe_first( if not self._adder: return - "Select new networks from the sampler at the start of each episode." + # Select new networks from the sampler at the start of each episode. agents = sort_str_num(list(self._agent_net_keys.keys())) self._network_int_keys_extras, self._agent_net_keys = sample_new_agent_keys( agents, @@ -358,7 +352,6 @@ def __init__( self._action_selectors = action_selectors self._states: Dict[str, Any] = {} - @tf.function def _policy( self, agent: str, @@ -373,8 +366,10 @@ def _policy( observation: observation tensor received from the environment. state: recurrent network state. + Raises: NotImplementedError: unknown action space + Returns: action, policy and new recurrent hidden state """ @@ -397,49 +392,44 @@ def _policy( return action, new_state - def select_action( - self, agent: str, observation: types.NestedArray + @tf.function + def _select_actions( + self, observations: Dict[str, types.NestedArray] + ) -> types.NestedArray: + actions: Dict = {} + new_states: Dict = {} + for agent, observation in observations.items(): + actions[agent], + new_states[agent] = self._policy( + agent, + observation.observation, + observation.legal_actions, + self._states[agent], + ) + return actions, new_states + + def select_actions( + self, observations: Dict[str, types.NestedArray] ) -> types.NestedArray: - """select an action for a single agent in the system. + """Select the actions for all agents in the system Args: - agent: agent id - observation: observation tensor received from the + observations: agent observations from the environment. Returns: - action and policy. + actions and policies for all agents in the system. """ - # Step the recurrent policy forward given the current observation and state. - action, new_state = self._policy( - agent, - observation.observation, - observation.legal_actions, - self._states[agent], - ) - - # Bookkeeping of recurrent states for the observe method. - self._update_state(agent, new_state) - - # Return a numpy array with squeezed out batch dimension. - action = tf2_utils.to_numpy_squeeze(action) - - return action + actions, new_states = self._select_actions(observations) - def select_actions(self, observations: Dict[str, types.NestedArray]) -> Any: - """select the actions for all agents in the system + # Convert actions to numpy arrays + actions = tree.map_structure(tf2_utils.to_numpy_squeeze, actions) - Args: - observations: agent observations from the - environment. + # Update agent core state + for agent, state in new_states.items(): + self._update_state(agent, state) - Returns: - actions and policies for all agents in the system. 
- """ - actions = {} - for agent, observation in observations.items(): - actions[agent] = self.select_action(agent, observation) return actions def observe_first( @@ -447,7 +437,7 @@ def observe_first( timestep: dm_env.TimeStep, extras: Dict[str, types.NestedArray] = {}, ) -> None: - """record first observed timestep from the environment + """Record first observed timestep from the environment Args: timestep: data emitted by an environment at first step of diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index 1303b34b0..24cb00008 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -597,10 +597,10 @@ def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray] """Depricated""" pass - # @tf.function # NOTE (Claude) The recurrent trainer does not start with tf.function # It does start on SMAC 3m and debug env but not on any other SMAC maps. # TODO (Claude) get tf.function to work. + @tf.function def _step( self, ) -> Dict[str, Dict[str, Any]]: diff --git a/mava/utils/environments/RoboCup_env/__init__.py b/mava/utils/environments/RoboCup_env/__init__.py index e61d07c59..f1a7820aa 100644 --- a/mava/utils/environments/RoboCup_env/__init__.py +++ b/mava/utils/environments/RoboCup_env/__init__.py @@ -12,3 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Robocup env.""" diff --git a/mava/utils/environments/RoboCup_env/robocup_utils/player_world_model.py b/mava/utils/environments/RoboCup_env/robocup_utils/player_world_model.py index 92057e173..e1a7b62d1 100755 --- a/mava/utils/environments/RoboCup_env/robocup_utils/player_world_model.py +++ b/mava/utils/environments/RoboCup_env/robocup_utils/player_world_model.py @@ -362,14 +362,10 @@ def turn_body_to_object(self, obj): class ServerParameters: - """ - A storage container for all the settings of the soccer server. - """ + """A storage container for all the settings of the soccer server.""" def __init__(self): - """ - Initialize default parameters for a server. - """ + """Initialize default parameters for a server.""" self.audio_cut_dist = 50 self.auto_mode = 0 diff --git a/mava/utils/environments/RoboCup_env/robocup_utils/trainer_world_model.py b/mava/utils/environments/RoboCup_env/robocup_utils/trainer_world_model.py index 75cb46630..b6aa9ee2a 100755 --- a/mava/utils/environments/RoboCup_env/robocup_utils/trainer_world_model.py +++ b/mava/utils/environments/RoboCup_env/robocup_utils/trainer_world_model.py @@ -15,8 +15,9 @@ # type: ignore class WorldModel: - """ - Holds and updates the model of the world as known from current and past + """Holds and updates the model of the world + + As known from current and past data. """ @@ -50,8 +51,9 @@ def __init__(self): ) class RefereeMessages: - """ - Static class containing possible non-mode messages sent by a referee. + """Static class containing possible non-mode + + messages sent by a referee. """ # these are referee messages, not play modes @@ -108,8 +110,9 @@ def __init__(self, action_handler): self.server_parameters = ServerParameters() def process_new_info(self, ball, goals, players): - """ - Update any internal variables based on the currently available + """Update any internal variables. + + Based on the currently available information. This also calculates information not available directly from server-reported messages, such as player coordinates. 
""" @@ -149,6 +152,7 @@ def process_new_info(self, ball, goals, players): self.players = players def get_state(self): + """Get state.""" return {"ball": self.ball, "players": self.players} def is_playon(self): diff --git a/mava/wrappers/env_wrappers.py b/mava/wrappers/env_wrappers.py index 8138e87ae..6ef23d2ff 100644 --- a/mava/wrappers/env_wrappers.py +++ b/mava/wrappers/env_wrappers.py @@ -20,15 +20,11 @@ class ParallelEnvWrapper(dm_env.Environment): - """ - Abstract class for parallel environment wrappers. - """ + """Abstract class for parallel environment wrappers""" @abstractmethod def env_done(self) -> bool: - """ - Returns a bool indicating if all agents in env are done. - """ + """Returns a bool indicating if env is done""" @property @abstractmethod From 018665a11d00c05ecb6cf942956aeec261be1759 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Fri, 28 Jan 2022 11:05:45 +0200 Subject: [PATCH 32/56] Doc strings. --- mava/systems/tf/mad4pg/builder.py | 1 - mava/systems/tf/mad4pg/execution.py | 5 ++- mava/systems/tf/mad4pg/system.py | 5 ++- mava/systems/tf/mad4pg/training.py | 34 +++++++++++++++++-- mava/systems/tf/maddpg/training.py | 51 ++++++++++++++++++++++------- 5 files changed, 77 insertions(+), 19 deletions(-) diff --git a/mava/systems/tf/mad4pg/builder.py b/mava/systems/tf/mad4pg/builder.py index 092b1f63d..baca8c809 100644 --- a/mava/systems/tf/mad4pg/builder.py +++ b/mava/systems/tf/mad4pg/builder.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """MAD4PG system builder implementation.""" from typing import Any, Dict, Type, Union diff --git a/mava/systems/tf/mad4pg/execution.py b/mava/systems/tf/mad4pg/execution.py index e820327fd..e27057029 100644 --- a/mava/systems/tf/mad4pg/execution.py +++ b/mava/systems/tf/mad4pg/execution.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """MAD4PG system executor implementation.""" from typing import Any, Dict, List, Optional @@ -30,6 +29,7 @@ class MAD4PGFeedForwardExecutor(MADDPGFeedForwardExecutor): """A feed-forward executor for MAD4PG. + An executor based on a feed-forward policy for each agent in the system. """ @@ -48,6 +48,7 @@ def __init__( ): """Initialise the system executor + Args: policy_networks: policy networks for each agent in the system. @@ -83,6 +84,7 @@ def __init__( class MAD4PGRecurrentExecutor(MADDPGRecurrentExecutor): """A recurrent executor for MAD4PG. + An executor based on a recurrent policy for each agent in the system. """ @@ -100,6 +102,7 @@ def __init__( interval: Optional[dict] = None, ): """Initialise the system executor + Args: policy_networks: policy networks for each agent in the system. diff --git a/mava/systems/tf/mad4pg/system.py b/mava/systems/tf/mad4pg/system.py index b0554fe96..906acf5a7 100644 --- a/mava/systems/tf/mad4pg/system.py +++ b/mava/systems/tf/mad4pg/system.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """MAD4PG system implementation.""" from typing import Callable, Dict, List, Optional, Type, Union @@ -35,8 +34,8 @@ class MAD4PG(MADDPG): """MAD4PG system.""" - """TODO: Implement faster adders to speed up training times when - using multiple trainers with non-shared weights.""" + # TODO: Implement faster adders to speed up training times when + # using multiple trainers with non-shared weights. def __init__( self, diff --git a/mava/systems/tf/mad4pg/training.py b/mava/systems/tf/mad4pg/training.py index 31735e95f..aeed32a86 100644 --- a/mava/systems/tf/mad4pg/training.py +++ b/mava/systems/tf/mad4pg/training.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """MAD4PG system trainer implementation.""" from typing import Any, Callable, Dict, List, Optional, Union @@ -48,6 +46,7 @@ class MAD4PGBaseTrainer(MADDPGBaseTrainer): """MAD4PG trainer. + This is the trainer component of a MAD4PG system. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ @@ -113,7 +112,6 @@ def __init__( returns the current learning rate. """ - """Initialise the decentralised MADDPG trainer.""" super().__init__( agents=agents, agent_types=agent_types, @@ -255,6 +253,7 @@ def __init__( learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): """Initialise the decentralised MAD4PG trainer.""" + super().__init__( agents=agents, agent_types=agent_types, @@ -308,6 +307,7 @@ def __init__( learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): """Initialise the centralised MAD4PG trainer.""" + super().__init__( agents=agents, agent_types=agent_types, @@ -361,6 +361,7 @@ def __init__( learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): """Initialise the state-based MAD4PG trainer.""" + super().__init__( agents=agents, agent_types=agent_types, @@ -388,6 +389,7 @@ def __init__( class MAD4PGBaseRecurrentTrainer(MADDPGBaseRecurrentTrainer): """Recurrent MAD4PG trainer. + This is the trainer component of a MADDPG system. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ @@ -650,6 +652,32 @@ def __init__( bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): + """Init trainer. + + Args: + agents: [description] + agent_types: [description] + policy_networks: [description] + critic_networks: [description] + target_policy_networks: [description] + target_critic_networks: [description] + policy_optimizer: [description] + critic_optimizer: [description] + discount: [description] + target_averaging: [description] + target_update_period: [description] + target_update_rate: [description] + dataset: [description] + observation_networks: [description] + target_observation_networks: [description] + variable_client: [description] + counts: [description] + agent_net_keys: [description] + max_gradient_norm: [description]. Defaults to None. + logger: [description]. Defaults to None. + bootstrap_n: [description]. Defaults to 10. + learning_rate_scheduler_fn: [description]. Defaults to None. 
+ """ super().__init__( agents=agents, diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 00fcfd5ba..8fb792fa6 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -43,6 +43,7 @@ class MADDPGBaseTrainer(mava.Trainer): """MADDPG trainer. + This is the trainer component of a MADDPG system. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ @@ -72,6 +73,7 @@ def __init__( learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): """Initialise MADDPG trainer + Args: agents: agent ids, e.g. "agent_0". agent_types: agent types, e.g. "speaker" or "listener". @@ -198,6 +200,7 @@ def __init__( def _update_target_networks(self) -> None: """Update the target networks using either target averaging or + by directy copying the weights of the online networks every few steps.""" for key in self.unique_net_keys: # Update target network. @@ -236,6 +239,7 @@ def _transform_observations( Args: obs: observations at timestep t-1 next_obs: observations at timestep t + Returns: Transformed observatations """ @@ -264,7 +268,7 @@ def _get_critic_feed( e_t: Dict[str, np.ndarray], agent: str, ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: - """get data to feed to the agent critic network(s) + """Get data to feed to the agent critic network(s) Args: o_tm1_trans: transformed (e.g. using observation @@ -293,7 +297,7 @@ def _get_dpg_feed( dpg_a_t: np.ndarray, agent: str, ) -> tf.Tensor: - """get data to feed to the agent networks + """Get data to feed to the agent networks Args: a_t: action at timestep t @@ -308,7 +312,7 @@ def _get_dpg_feed( return dpg_a_t_feed def _target_policy_actions(self, next_obs: Dict[str, np.ndarray]) -> Any: - """select actions using target policy networks + """Select actions using target policy networks Args: next_obs: next agent observations. @@ -351,6 +355,7 @@ def _step( # Forward pass that calculates loss. def _forward(self, inputs: reverb.ReplaySample) -> None: """Trainer forward pass + Args: inputs: input data from the data table (transitions) """ @@ -484,7 +489,7 @@ def _backward(self) -> None: train_utils.safe_del(self, "tape") def step(self) -> None: - """trainer step to update the parameters of the agents in the system""" + """Trainer step to update the parameters of the agents in the system""" # Run the learning step. 
fetches = self._step() @@ -515,6 +520,7 @@ def step(self) -> None: def after_trainer_step(self) -> None: """Optionally decay lr after every training step.""" + if self._learning_rate_scheduler_fn: self._decay_lr(self._num_steps) info: Dict[str, Dict[str, float]] = {} @@ -624,6 +630,7 @@ def __init__( learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): """Initialise the centralised MADDPG trainer.""" + super().__init__( agents=agents, agent_types=agent_types, @@ -658,6 +665,7 @@ def _get_critic_feed( e_t: Dict[str, np.ndarray], agent: str, ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: + """Get critic feed.""" # Centralised based o_tm1_feed = tf.stack([o_tm1_trans[agent] for agent in self._agents], 1) @@ -673,6 +681,7 @@ def _get_dpg_feed( dpg_a_t: np.ndarray, agent: str, ) -> tf.Tensor: + """Get DPG feed.""" # Centralised and StateBased DPG # Note (dries): Copy has to be made because the input @@ -717,6 +726,7 @@ def __init__( learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): """Initialise the networked MADDPG trainer.""" + super().__init__( agents=agents, agent_types=agent_types, @@ -752,6 +762,7 @@ def _get_critic_feed( e_t: Dict[str, np.ndarray], agent: str, ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: + """Get critic feed.""" # Networked based connections = self._connection_spec[agent] @@ -781,6 +792,7 @@ def _get_dpg_feed( dpg_a_t: np.ndarray, agent: str, ) -> tf.Tensor: + """Get DPG feed.""" # Networked based tree.map_structure(tf.stop_gradient, a_t) @@ -861,6 +873,7 @@ def _get_critic_feed( e_t: Dict[str, np.ndarray], agent: str, ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: + """Get critic feed.""" # State based o_tm1_feed = e_tm1["s_t"] @@ -876,6 +889,7 @@ def _get_dpg_feed( dpg_a_t: np.ndarray, agent: str, ) -> tf.Tensor: + """Get DPG feed.""" # Centralised and StateBased DPG # Note (dries): Copy has to be made because the input @@ -893,6 +907,7 @@ def _get_dpg_feed( class MADDPGBaseRecurrentTrainer(mava.Trainer): """Recurrent MADDPG trainer. + This is the trainer component of a MADDPG system. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ @@ -923,6 +938,7 @@ def __init__( learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): """Initialise Recurrent MADDPG trainer + Args: agents: agent ids, e.g. "agent_0". agent_types: agent types, e.g. "speaker" or "listener". @@ -1050,6 +1066,7 @@ def __init__( def _update_target_networks(self) -> None: """Sync the target parameters with the latest online + parameters for all networks""" for key in self.unique_net_keys: @@ -1080,7 +1097,8 @@ def _update_target_networks(self) -> None: def _transform_observations( self, observations: Dict[str, mava_types.OLT] ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: - """apply the observation networks to the raw observations from the dataset + """Apply the observation networks to the raw observations from the dataset + Args: obs: raw agent observations next_obs: raw next observations @@ -1127,7 +1145,8 @@ def _get_critic_feed( extras: Dict[str, np.ndarray], agent: str, ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: - """get data to feed to the agent critic network(s) + """Get data to feed to the agent critic network(s) + Args: o_tm1_trans: transformed (e.g. 
using observation network) observation at timestep t-1 @@ -1137,6 +1156,7 @@ def _get_critic_feed( e_tm1: extras at timestep t-1 e_t: extras at timestep t agent: agent id + Returns: agent critic network feeds @@ -1155,11 +1175,13 @@ def _get_dpg_feed( dpg_actions: np.ndarray, agent: str, ) -> tf.Tensor: - """get data to feed to the agent networks + """Get data to feed to the agent networks + Args: a_t: action at timestep t dpg_a_t: predicted action at timestep t agent: agent id + Returns: tf.Tensor: agent policy network feed """ @@ -1173,11 +1195,13 @@ def _target_policy_actions( target_obs_trans: Dict[str, np.ndarray], target_core_state: Dict[str, np.ndarray], ) -> Any: - """select actions using target policy networks + """Select actions using target policy networks + Args: target_obs_trans: agent transformed target observations. target_core_state: target recurrent network state + Returns: Any: agent target actions """ @@ -1206,7 +1230,8 @@ def _target_policy_actions( def _step( self, ) -> Dict[str, Dict[str, Any]]: - """Trainer forward and backward passes. + """Trainer forward and backward passes.' + Returns: losses """ @@ -1229,6 +1254,7 @@ def _step( # Forward pass that calculates loss. def _forward(self, inputs: reverb.ReplaySample) -> None: """Trainer forward pass + Args: inputs: input data from the data table (transitions) """ @@ -1413,7 +1439,7 @@ def _backward(self) -> None: train_utils.safe_del(self, "tape") def step(self) -> None: - """trainer step to update the parameters of the agents in the system""" + """Trainer step to update the parameters of the agents in the system""" # Run the learning step. fetches = self._step() @@ -1443,7 +1469,8 @@ def step(self) -> None: self._logger.write(fetches) def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray]]: - """get network variables + """Get network variables + Args: names: network names Returns: @@ -1492,6 +1519,7 @@ def _decay_lr(self, trainer_step: int) -> None: class MADDPGDecentralisedRecurrentTrainer(MADDPGBaseRecurrentTrainer): """Recurrent MADDPG trainer for a decentralised architecture. + This is the trainer component of a MADDPG system. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ @@ -1550,6 +1578,7 @@ def __init__( class MADDPGCentralisedRecurrentTrainer(MADDPGBaseRecurrentTrainer): """Recurrent MADDPG trainer for a centralised architecture. + This is the trainer component of a MADDPG system. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ From 3216e41b57c3d9e374b4a0444f84633392ad8d5b Mon Sep 17 00:00:00 2001 From: RuanJohn Date: Fri, 28 Jan 2022 11:06:37 +0200 Subject: [PATCH 33/56] Added tests for VDN and QMIX. Fixed MADQN test. 
--- tests/systems/madqn_system_test.py | 9 +-- tests/systems/qmix_system_test.py | 93 ++++++++++++++++++++++++++++++ tests/systems/vdn_system_test.py | 90 +++++++++++++++++++++++++++++ 3 files changed, 188 insertions(+), 4 deletions(-) create mode 100644 tests/systems/qmix_system_test.py create mode 100644 tests/systems/vdn_system_test.py diff --git a/tests/systems/madqn_system_test.py b/tests/systems/madqn_system_test.py index a2c58a798..0b4bfce5e 100644 --- a/tests/systems/madqn_system_test.py +++ b/tests/systems/madqn_system_test.py @@ -44,7 +44,7 @@ def test_madqn_on_debugging_env(self) -> None: # networks network_factory = lp_utils.partial_kwargs( - madqn.make_default_networks, policy_networks_layer_sizes=(64, 64) + madqn.make_default_networks, value_networks_layer_sizes=(64, 64) ) # system @@ -83,7 +83,8 @@ def test_madqn_on_debugging_env(self) -> None: trainer.step() def test_recurrent_madqn_on_debugging_env(self) -> None: - """Test recurrent maddpg.""" + """Test recurrent madqn.""" + # environment environment_factory = functools.partial( debugging_utils.make_environment, @@ -94,8 +95,8 @@ def test_recurrent_madqn_on_debugging_env(self) -> None: # networks network_factory = lp_utils.partial_kwargs( madqn.make_default_networks, - archecture_type=ArchitectureType.recurrent, - policy_networks_layer_sizes=(32, 32), + architecture_type=ArchitectureType.recurrent, + value_networks_layer_sizes=(32, 32), ) # system diff --git a/tests/systems/qmix_system_test.py b/tests/systems/qmix_system_test.py new file mode 100644 index 000000000..3949ed094 --- /dev/null +++ b/tests/systems/qmix_system_test.py @@ -0,0 +1,93 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools + +import launchpad as lp +import sonnet as snt + +import mava +from mava.components.tf.modules.exploration.exploration_scheduling import ( + LinearExplorationScheduler, +) +from mava.systems.tf import value_decomposition +from mava.utils import lp_utils +from mava.utils.environments import debugging_utils + +"""Test for QMIX System""" + + +class TestQMIX: + """Simple integration test for QMIX on Simple Spread enviromnent""" + + def test_qmix_on_debug_simple_spread(self) -> None: + """Test recurrent QMIX.""" + + # environment + environment_factory = functools.partial( + debugging_utils.make_environment, + env_name="simple_spread", + action_space="discrete", + return_state_info=True, + ) + + # Networks. 
+ network_factory = lp_utils.partial_kwargs( + value_decomposition.make_default_networks, + ) + + # system + system = value_decomposition.ValueDecomposition( + environment_factory=environment_factory, + network_factory=network_factory, + mixer="qmix", + num_executors=1, + exploration_scheduler_fn=LinearExplorationScheduler( + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 + ), + optimizer=snt.optimizers.RMSProp( + learning_rate=0.0005, epsilon=0.00001, decay=0.99 + ), + batch_size=1, + executor_variable_update_period=200, + target_update_period=200, + max_gradient_norm=20.0, + min_replay_size=1, + max_replay_size=10000, + samples_per_insert=None, + evaluator_interval={"executor_episodes": 2}, + checkpoint=False, + ) + + program = system.build() + + (trainer_node,) = program.groups["trainer"] + trainer_node.disable_run() + + # Launch gpu config - don't use gpu + local_resources = lp_utils.to_device( + program_nodes=program.groups.keys(), nodes_on_gpu=[] + ) + lp.launch( + program, + launch_type="test_mt", + local_resources=local_resources, + ) + + trainer: mava.Trainer = trainer_node.create_handle().dereference() + + for _ in range(2): + trainer.step() diff --git a/tests/systems/vdn_system_test.py b/tests/systems/vdn_system_test.py new file mode 100644 index 000000000..5adf5373d --- /dev/null +++ b/tests/systems/vdn_system_test.py @@ -0,0 +1,90 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools + +import launchpad as lp +import sonnet as snt + +import mava +from mava.components.tf.modules.exploration.exploration_scheduling import ( + LinearExplorationScheduler, +) +from mava.systems.tf import value_decomposition +from mava.utils import lp_utils +from mava.utils.environments import debugging_utils + +"""Test for VDN System""" + + +class TestVDN: + """Simple integration test for VDN on the debug enviroment""" + + def test_vdn_on_debug_simple_spread(self) -> None: + """Test vdn on simple spread environment.""" + + # environment + environment_factory = functools.partial( + debugging_utils.make_environment, + env_name="simple_spread", + action_space="discrete", + ) + + # Networks. 
+ network_factory = lp_utils.partial_kwargs( + value_decomposition.make_default_networks, + ) + + # system + system = value_decomposition.ValueDecomposition( + environment_factory=environment_factory, + network_factory=network_factory, + mixer="vdn", + num_executors=1, + exploration_scheduler_fn=LinearExplorationScheduler( + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 + ), + optimizer=snt.optimizers.RMSProp( + learning_rate=0.0005, epsilon=0.00001, decay=0.99 + ), + batch_size=1, + max_gradient_norm=20.0, + min_replay_size=1, + max_replay_size=10000, + samples_per_insert=None, + evaluator_interval={"executor_episodes": 2}, + checkpoint=False, + ) + + program = system.build() + + (trainer_node,) = program.groups["trainer"] + trainer_node.disable_run() + + # Launch gpu config - don't use gpu + local_resources = lp_utils.to_device( + program_nodes=program.groups.keys(), nodes_on_gpu=[] + ) + lp.launch( + program, + launch_type="test_mt", + local_resources=local_resources, + ) + + trainer: mava.Trainer = trainer_node.create_handle().dereference() + + for _ in range(2): + trainer.step() From 48f902a561328db32082069d7f129481846a3a9f Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Fri, 28 Jan 2022 13:40:11 +0200 Subject: [PATCH 34/56] Fixes. --- examples/smac/recurrent/centralised/run_qmix.py | 8 +++++--- examples/smac/recurrent/centralised/run_vdn.py | 6 ++++-- examples/smac/recurrent/decentralised/run_madqn.py | 8 +++++--- mava/systems/tf/madqn/execution.py | 9 ++++----- 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/examples/smac/recurrent/centralised/run_qmix.py b/examples/smac/recurrent/centralised/run_qmix.py index 624a39cd0..767741034 100644 --- a/examples/smac/recurrent/centralised/run_qmix.py +++ b/examples/smac/recurrent/centralised/run_qmix.py @@ -33,7 +33,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - "8m", + "1c3s5z", "Starcraft 2 micromanagement map name (str).", ) @@ -78,7 +78,7 @@ def main(_: Any) -> None: logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=5e-6 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 @@ -89,8 +89,10 @@ def main(_: Any) -> None: target_update_period=200, max_gradient_norm=20.0, min_replay_size=32, - max_replay_size=10000, + max_replay_size=5000, samples_per_insert=16, + sequence_length=200, + period=200, evaluator_interval={"executor_episodes": 2}, ).build() diff --git a/examples/smac/recurrent/centralised/run_vdn.py b/examples/smac/recurrent/centralised/run_vdn.py index 00afb0c03..ae0df4e31 100644 --- a/examples/smac/recurrent/centralised/run_vdn.py +++ b/examples/smac/recurrent/centralised/run_vdn.py @@ -33,7 +33,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - "3m", + "1c3s5z", "Starcraft 2 micromanagement map name (str).", ) @@ -78,7 +78,7 @@ def main(_: Any) -> None: logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=5e-6 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 @@ -89,6 +89,8 @@ def main(_: Any) -> None: min_replay_size=32, max_replay_size=10000, samples_per_insert=16, + sequence_length=200, + period=200, evaluator_interval={"executor_episodes": 2}, ).build() diff --git 
a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index b8d1030c5..b907dafe9 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -35,7 +35,7 @@ flags.DEFINE_string( "map_name", - "10m_vs_11m", + "1c3s5z", "Starcraft 2 micromanagement map name (str).", ) @@ -79,7 +79,7 @@ def main(_: Any) -> None: logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=5e-6 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 @@ -90,8 +90,10 @@ def main(_: Any) -> None: target_update_period=200, max_gradient_norm=20.0, min_replay_size=32, - max_replay_size=10000, + max_replay_size=5000, samples_per_insert=16, + sequence_length=200, + period=200, evaluator_interval={"executor_episodes": 2}, trainer_fn=madqn.MADQNRecurrentTrainer, executor_fn=madqn.MADQNRecurrentExecutor, diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index 979c29832..484b09d16 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -394,17 +394,16 @@ def _policy( @tf.function def _select_actions( - self, observations: Dict[str, types.NestedArray] + self, observations: Dict[str, types.NestedArray], states ) -> types.NestedArray: actions: Dict = {} new_states: Dict = {} for agent, observation in observations.items(): - actions[agent], - new_states[agent] = self._policy( + actions[agent], new_states[agent] = self._policy( agent, observation.observation, observation.legal_actions, - self._states[agent], + states[agent], ) return actions, new_states @@ -421,7 +420,7 @@ def select_actions( actions and policies for all agents in the system. """ - actions, new_states = self._select_actions(observations) + actions, new_states = self._select_actions(observations, self._states) # Convert actions to numpy arrays actions = tree.map_structure(tf2_utils.to_numpy_squeeze, actions) From fe3ddaae27f9694fbe90f0f7d4025a3c4b05b07e Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Mon, 31 Jan 2022 11:02:38 +0200 Subject: [PATCH 35/56] Fix old comments. 
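[Editor's note] The signature cleaned up below is the recurrent executor's _select_actions, which the previous commit changed to take the agents' core states as an explicit argument (self._select_actions(observations, self._states)) instead of reading self._states[agent] inside the tf.function. One plausible reason, hedged and not stated in the patch: tensors read from Python attributes inside a tf.function are captured when the function is first traced, so later updates to the attribute may not be seen, whereas an explicit argument is a genuine graph input. A minimal sketch of that pitfall (my reading of tf.function capture semantics, not code from this patch):

    import tensorflow as tf

    class Holder:
        """Minimal sketch, not the executor: stale closure captures in tf.function."""

        def __init__(self) -> None:
            self.state = tf.constant(0.0)

        @tf.function
        def stale(self):
            # self.state is captured when this function is first traced.
            return self.state + 1.0

        @tf.function
        def fresh(self, state):
            # state is an ordinary input to the traced graph.
            return state + 1.0

    h = Holder()
    h.stale()               # traced while self.state == 0.0
    h.state = tf.constant(5.0)
    h.stale()               # may still compute with the old captured value
    h.fresh(h.state)        # always uses the value that is passed in
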
--- mava/systems/tf/madqn/execution.py | 3 ++- mava/systems/tf/madqn/system.py | 3 --- mava/systems/tf/madqn/training.py | 17 ++++------------- 3 files changed, 6 insertions(+), 17 deletions(-) diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index 484b09d16..946ebd629 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -394,7 +394,8 @@ def _policy( @tf.function def _select_actions( - self, observations: Dict[str, types.NestedArray], states + self, observations: Dict[str, types.NestedArray], + states: Dict[str, types.NestedArray] ) -> types.NestedArray: actions: Dict = {} new_states: Dict = {} diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index 92d52d6af..49a587a17 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -53,9 +53,6 @@ class MADQN: """MADQN system.""" - """TODO: Implement faster adders to speed up training times when - using multiple trainers with non-shared weights.""" - def __init__( # noqa self, environment_factory: Callable[[bool], dm_env.Environment], diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index 24cb00008..92ac23d94 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -234,6 +234,7 @@ def _transform_observations( o_t[agent] = tree.map_structure(tf.stop_gradient, o_t[agent]) return o_tm1, o_t + @tf.function def _step( self, ) -> Dict[str, Dict[str, Any]]: @@ -535,7 +536,7 @@ def _transform_observations( obs_target_trans: transformed target network observations """ - # Note (dries): We are assuming that only the policy network + # NOTE We are assuming that only the value network # is recurrent and not the observation network. obs_trans = {} obs_target_trans = {} @@ -555,10 +556,8 @@ def _transform_observations( dims, ) - # This stop_gradient prevents gradients to propagate into the target - # observation network. In addition, since the online policy network is - # evaluated at o_t, this also means the policy loss does not influence - # the observation network training. + # This stop_gradient prevents gradients to propagate into + # the target observation network. obs_target_trans[agent] = tree.map_structure( tf.stop_gradient, obs_target_trans[agent] ) @@ -597,9 +596,6 @@ def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray] """Depricated""" pass - # NOTE (Claude) The recurrent trainer does not start with tf.function - # It does start on SMAC 3m and debug env but not on any other SMAC maps. - # TODO (Claude) get tf.function to work. @tf.function def _step( self, @@ -623,7 +619,6 @@ def _step( # Log losses per agent return train_utils.map_losses_per_agent_value(self.value_losses) - # Forward pass that calculates loss. def _forward(self, inputs: reverb.ReplaySample) -> None: """Trainer forward pass. @@ -741,10 +736,6 @@ def _backward(self) -> None: ) # Compute gradients. - # Note: Warning "WARNING:tensorflow:Calling GradientTape.gradient - # on a persistent tape inside its context is significantly less efficient - # than calling it outside the context." caused by losses.dpg, which calls - # tape.gradient. gradients = tape.gradient(value_losses[agent], variables) # Maybe clip gradients. From 1a8a5b3dbea99c508985242178924e9b72e1d9cc Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 2 Feb 2022 11:04:18 +0200 Subject: [PATCH 36/56] Fix docstring in examples. 
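[Editor's note] Several SMAC examples below in this commit switch the recurrent replay settings to sequence_length=20 with period=10 (and samples_per_insert=4). Assuming the usual Acme/Reverb sequence-adder semantics these systems rely on, that stores overlapping 20-step windows whose starts are 10 steps apart; a quick, purely illustrative sketch of the windowing (episode-end padding ignored):

    # Illustrative only: how sequence_length=20 and period=10 window an episode,
    # assuming Acme/Reverb-style sequence adders (end-of-episode padding ignored).
    sequence_length, period = 20, 10
    episode = list(range(50))                      # toy episode of 50 timesteps
    windows = [episode[start:start + sequence_length]
               for start in range(0, len(episode) - sequence_length + 1, period)]
    # windows[0] covers steps 0-19, windows[1] covers 10-29, windows[2] covers 20-39, ...
    # consecutive windows overlap by sequence_length - period = 10 steps.
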
--- .../feedforward/decentralised/run_madqn.py | 4 +-- .../flatland/recurrent/centralised/run_vdn.py | 10 ++++--- .../recurrent/decentralised/run_madqn.py | 11 ++++---- .../feedforward/decentralised/run_madqn.py | 6 ++--- .../pong/recurrent/decentralised/run_madqn.py | 12 ++++----- .../feedforward/decentralised/run_madqn.py | 26 ++++++++++--------- .../smac/recurrent/centralised/run_qmix.py | 22 +++++++++------- .../smac/recurrent/centralised/run_vdn.py | 23 +++++++++------- .../smac/recurrent/decentralised/run_madqn.py | 22 +++++++++------- .../decentralised/run_madqn_scale_trainers.py | 13 +++++----- 10 files changed, 79 insertions(+), 70 deletions(-) diff --git a/examples/flatland/feedforward/decentralised/run_madqn.py b/examples/flatland/feedforward/decentralised/run_madqn.py index 6bc4cc710..d2eacfeb2 100644 --- a/examples/flatland/feedforward/decentralised/run_madqn.py +++ b/examples/flatland/feedforward/decentralised/run_madqn.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Example running feedforward MADQN on Flatland""" +"""Example running feedforward MADQN on Flatland.""" import functools from datetime import datetime @@ -79,7 +79,7 @@ def main(_: Any) -> None: time_delta=log_every, ) - # distributed program + # Distributed program program = madqn.MADQN( environment_factory=environment_factory, network_factory=network_factory, diff --git a/examples/flatland/recurrent/centralised/run_vdn.py b/examples/flatland/recurrent/centralised/run_vdn.py index 685579f6d..acf4a39a2 100644 --- a/examples/flatland/recurrent/centralised/run_vdn.py +++ b/examples/flatland/recurrent/centralised/run_vdn.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Example running VDN on Flatland.""" import functools from datetime import datetime @@ -29,7 +30,6 @@ from mava.utils.environments.flatland_utils import make_environment from mava.utils.loggers import logger_utils -"""Example running VDN on Flatland.""" FLAGS = flags.FLAGS @@ -98,15 +98,17 @@ def main(_: Any) -> None: batch_size=32, max_gradient_norm=20.0, min_replay_size=32, - max_replay_size=10000, - samples_per_insert=16, + max_replay_size=5000, + samples_per_insert=4, evaluator_interval={"executor_episodes": 2}, ).build() - # launch + # Ensure only trainer runs on gpu, while other processes run on cpu. local_resources = lp_utils.to_device( program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] ) + + # Launch lp.launch( program, lp.LaunchType.LOCAL_MULTI_PROCESSING, diff --git a/examples/flatland/recurrent/decentralised/run_madqn.py b/examples/flatland/recurrent/decentralised/run_madqn.py index b3285721b..0d1824a11 100644 --- a/examples/flatland/recurrent/decentralised/run_madqn.py +++ b/examples/flatland/recurrent/decentralised/run_madqn.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +"""Example running recurrent MADQN on Flatland.""" import functools from datetime import datetime @@ -59,7 +59,7 @@ def main(_: Any) -> None: - """Run example""" + """Run example.""" # Environment. 
environment_factory = functools.partial(make_environment, **env_config) @@ -83,7 +83,7 @@ def main(_: Any) -> None: time_delta=log_every, ) - # distributed program + # Distributed program program = madqn.MADQN( environment_factory=environment_factory, network_factory=network_factory, @@ -92,12 +92,11 @@ def main(_: Any) -> None: exploration_scheduler_fn=LinearExplorationScheduler( epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=1e-5 ), - optimizer=snt.optimizers.Adam(learning_rate=1e-4), batch_size=32, - samples_per_insert=16, + samples_per_insert=4, max_gradient_norm=20.0, min_replay_size=32, - max_replay_size=10000, + max_replay_size=5000, trainer_fn=madqn.MADQNRecurrentTrainer, executor_fn=madqn.MADQNRecurrentExecutor, checkpoint_subpath=checkpoint_dir, diff --git a/examples/openspiel/tic_tac_toe/feedforward/decentralised/run_madqn.py b/examples/openspiel/tic_tac_toe/feedforward/decentralised/run_madqn.py index a9a77fdb8..ad3467efe 100644 --- a/examples/openspiel/tic_tac_toe/feedforward/decentralised/run_madqn.py +++ b/examples/openspiel/tic_tac_toe/feedforward/decentralised/run_madqn.py @@ -12,8 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Example running feedforward MADQN on OpenSpiel's tic_tac_toe.""" -"""Example running MADQN on OpenSpiel's tic_tac_toe.""" import functools from datetime import datetime from typing import Any @@ -58,7 +58,7 @@ def make_environment( def main(_: Any) -> None: - # environment + # Environment environment_factory = functools.partial( make_environment, game=FLAGS.game, @@ -81,7 +81,7 @@ def main(_: Any) -> None: time_delta=log_every, ) - # distributed program + # Distributed program program = madqn.MADQN( environment_factory=environment_factory, network_factory=network_factory, diff --git a/examples/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py b/examples/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py index a7ac7312a..0bab836e3 100644 --- a/examples/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py +++ b/examples/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py @@ -12,8 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -"""Example running MADQN on Atari Pong.""" +"""Example running recurrent MADQN on Atari Pong.""" import functools from datetime import datetime @@ -51,7 +50,7 @@ def main(_: Any) -> None: - # environment + # Environment environment_factory = functools.partial( pettingzoo_utils.make_environment, env_class=FLAGS.env_class, @@ -79,7 +78,7 @@ def main(_: Any) -> None: time_delta=log_every, ) - # distributed program + # Distributed program program = madqn.MADQN( environment_factory=environment_factory, network_factory=network_factory, @@ -90,10 +89,9 @@ def main(_: Any) -> None: ), shared_weights=False, batch_size=32, - max_replay_size=10000, - samples_per_insert=16, + max_replay_size=5000, + samples_per_insert=4, min_replay_size=32, - optimizer=snt.optimizers.Adam(learning_rate=1e-4), executor_fn=madqn.MADQNRecurrentExecutor, trainer_fn=madqn.MADQNRecurrentTrainer, evaluator_interval={"executor_episodes": 2}, diff --git a/examples/smac/feedforward/decentralised/run_madqn.py b/examples/smac/feedforward/decentralised/run_madqn.py index 5f869514d..07e6e37db 100644 --- a/examples/smac/feedforward/decentralised/run_madqn.py +++ b/examples/smac/feedforward/decentralised/run_madqn.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Run feedforward MADQN on SMAC.""" import functools @@ -23,11 +24,11 @@ from absl import app, flags from mava.components.tf.modules.exploration.exploration_scheduling import ( - LinearExplorationTimestepScheduler, + LinearExplorationScheduler, ) from mava.systems.tf import madqn from mava.utils import lp_utils -from mava.utils.environments import pettingzoo_utils +from mava.utils.environments.smac_utils import make_environment from mava.utils.loggers import logger_utils FLAGS = flags.FLAGS @@ -46,11 +47,10 @@ def main(_: Any) -> None: - """Example running MADQN on multi-agent Starcraft 2 (SMAC) environment.""" - # environment - environment_factory = functools.partial( - pettingzoo_utils.make_environment, env_class="smac", env_name=FLAGS.map_name - ) + """Example running feedforward MADQN on SMAC environment.""" + + # Environment + environment_factory = functools.partial(make_environment, map_name=FLAGS.map_name) # Networks. 
network_factory = lp_utils.partial_kwargs( @@ -71,29 +71,31 @@ def main(_: Any) -> None: time_delta=log_every, ) - # distributed program + # Distributed program program = madqn.MADQN( environment_factory=environment_factory, network_factory=network_factory, logger_factory=logger_factory, num_executors=1, - exploration_scheduler_fn=LinearExplorationTimestepScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=50000 + exploration_scheduler_fn=LinearExplorationScheduler( + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=5e-6 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 ), checkpoint_subpath=checkpoint_dir, batch_size=256, - executor_variable_update_period=1000, + executor_variable_update_period=200, target_update_period=200, max_gradient_norm=20.0, ).build() - # launch + # Only the trainer should use the GPU (if available) local_resources = lp_utils.to_device( program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] ) + + # Launch lp.launch( program, lp.LaunchType.LOCAL_MULTI_PROCESSING, diff --git a/examples/smac/recurrent/centralised/run_qmix.py b/examples/smac/recurrent/centralised/run_qmix.py index 767741034..124184023 100644 --- a/examples/smac/recurrent/centralised/run_qmix.py +++ b/examples/smac/recurrent/centralised/run_qmix.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Example running QMIX on SMAC""" import functools @@ -33,7 +34,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - "1c3s5z", + "3m", "Starcraft 2 micromanagement map name (str).", ) @@ -47,8 +48,9 @@ def main(_: Any) -> None: - """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" - # environment + """Example running recurrent QMIX on SMAC environment.""" + + # Environment environment_factory = functools.partial(make_environment, map_name=FLAGS.map_name) # Networks. @@ -70,7 +72,7 @@ def main(_: Any) -> None: time_delta=log_every, ) - # distributed program + # Distributed program program = value_decomposition.ValueDecomposition( environment_factory=environment_factory, network_factory=network_factory, @@ -78,7 +80,7 @@ def main(_: Any) -> None: logger_factory=logger_factory, num_executors=1, exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=5e-6 + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=4e-6 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 @@ -90,16 +92,18 @@ def main(_: Any) -> None: max_gradient_norm=20.0, min_replay_size=32, max_replay_size=5000, - samples_per_insert=16, - sequence_length=200, - period=200, + samples_per_insert=4, + sequence_length=20, + period=10, evaluator_interval={"executor_episodes": 2}, ).build() - # launch + # Only the trainer should use the GPU (if available) local_resources = lp_utils.to_device( program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] ) + + # Launch lp.launch( program, lp.LaunchType.LOCAL_MULTI_PROCESSING, diff --git a/examples/smac/recurrent/centralised/run_vdn.py b/examples/smac/recurrent/centralised/run_vdn.py index ae0df4e31..f79a5700d 100644 --- a/examples/smac/recurrent/centralised/run_vdn.py +++ b/examples/smac/recurrent/centralised/run_vdn.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - +"""Example running VDN on SMAC""" import functools from datetime import datetime @@ -33,7 +33,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( "map_name", - "1c3s5z", + "3m", "Starcraft 2 micromanagement map name (str).", ) @@ -47,8 +47,9 @@ def main(_: Any) -> None: - """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" - # environment + """Example running recurrent VDN on SMAC environment.""" + + # Environment environment_factory = functools.partial(make_environment, map_name=FLAGS.map_name) # Networks. @@ -70,7 +71,7 @@ def main(_: Any) -> None: time_delta=log_every, ) - # distributed program + # Distributed program program = value_decomposition.ValueDecomposition( environment_factory=environment_factory, network_factory=network_factory, @@ -87,17 +88,19 @@ def main(_: Any) -> None: batch_size=32, max_gradient_norm=20.0, min_replay_size=32, - max_replay_size=10000, - samples_per_insert=16, - sequence_length=200, - period=200, + max_replay_size=5000, + samples_per_insert=4, + sequence_length=20, + period=10, evaluator_interval={"executor_episodes": 2}, ).build() - # launch + # Only the trainer should use the GPU (if available) local_resources = lp_utils.to_device( program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] ) + + # Launch lp.launch( program, lp.LaunchType.LOCAL_MULTI_PROCESSING, diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index b907dafe9..d57d21610 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -12,7 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Run example MADQN.""" +"""Example running recurent MADQN on SMAC.""" + import functools from datetime import datetime @@ -32,10 +33,9 @@ from mava.utils.loggers import logger_utils FLAGS = flags.FLAGS - flags.DEFINE_string( "map_name", - "1c3s5z", + "3m", "Starcraft 2 micromanagement map name (str).", ) @@ -44,13 +44,13 @@ str(datetime.now()), "Experiment identifier that can be used to continue experiments.", ) - flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") def main(_: Any) -> None: - """Example running recurrent MADQN on multi-agent Starcraft 2 (SMAC) environment.""" - # environment + """Example running recurrent MADQN on SMAC environment.""" + + # Environment environment_factory = functools.partial(make_environment, map_name=FLAGS.map_name) # Networks. 
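The recurrent SMAC examples above now store much shorter replay sequences: sequence_length=20 written every period=10 steps instead of 200/200, so consecutive sequences overlap by 10 steps. The sketch below only illustrates how an episode gets carved into such overlapping chunks; the real work is done by the sequence adder inside the system.

# Illustrative chunking of one episode into overlapping sequences with
# sequence_length=20 and period=10 (matching the configs above). This is a
# sketch of the idea, not Mava's/Acme's sequence adder.
def chunk_episode(num_steps: int, sequence_length: int = 20, period: int = 10):
    starts = range(0, max(num_steps - sequence_length, 0) + 1, period)
    return [list(range(s, s + sequence_length)) for s in starts]


chunks = chunk_episode(num_steps=40)
for chunk in chunks:
    print(chunk[0], "...", chunk[-1])
# 0 ... 19
# 10 ... 29
# 20 ... 39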
@@ -91,18 +91,20 @@ def main(_: Any) -> None: max_gradient_norm=20.0, min_replay_size=32, max_replay_size=5000, - samples_per_insert=16, - sequence_length=200, - period=200, + samples_per_insert=4, + sequence_length=20, + period=10, evaluator_interval={"executor_episodes": 2}, trainer_fn=madqn.MADQNRecurrentTrainer, executor_fn=madqn.MADQNRecurrentExecutor, ).build() - # launch + # Only the trainer should use the GPU (if available) local_resources = lp_utils.to_device( program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] ) + + # Launch lp.launch( program, lp.LaunchType.LOCAL_MULTI_PROCESSING, diff --git a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py index c4b52a385..a207f76f3 100644 --- a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py +++ b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py @@ -12,9 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - """Example running MADQN on SMAC with multiple trainers.""" + import functools from datetime import datetime from typing import Any @@ -47,7 +46,7 @@ def main(_: Any) -> None: - """Main function.""" + """Run MADQN on SMAC with multiple trainers.""" # Environment. environment_factory = functools.partial( @@ -74,7 +73,7 @@ def main(_: Any) -> None: time_delta=log_every, ) - # distributed program + # Distributed program program = madqn.MADQN( environment_factory=environment_factory, network_factory=network_factory, @@ -83,7 +82,7 @@ def main(_: Any) -> None: exploration_scheduler_fn=LinearExplorationScheduler( epsilon_start=1.0, epsilon_min=0.05, - epsilon_decay=4e-5, + epsilon_decay=5e-6, ), shared_weights=False, trainer_networks=enums.Trainer.one_trainer_per_network, @@ -93,12 +92,12 @@ def main(_: Any) -> None: max_replay_size=5000, min_replay_size=32, batch_size=32, + samples_per_insert=4, evaluator_interval={"executor_episodes": 2}, - optimizer=snt.optimizers.Adam(learning_rate=1e-4), checkpoint_subpath=checkpoint_dir, ).build() - # Ensure only trainer runs on gpu, while other processes run on cpu. + # Only the trainer should use the GPU (if available) local_resources = lp_utils.to_device( program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] ) From 26016ce3826acca5f04ce15f8fa5c7110c7c3bc9 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 2 Feb 2022 11:17:01 +0200 Subject: [PATCH 37/56] Fix docstrings in MADQN system. --- mava/systems/tf/madqn/execution.py | 27 +++++++++++++-------------- mava/systems/tf/madqn/system.py | 18 ++++++------------ mava/systems/tf/madqn/training.py | 20 ++++++++++++-------- 3 files changed, 31 insertions(+), 34 deletions(-) diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index 946ebd629..9a006c546 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """MADQN system executor implementation.""" from typing import Any, Dict, List, Optional, Tuple, Union @@ -23,8 +24,6 @@ import tree from acme import types from acme.specs import EnvironmentSpec - -# Internal imports. 
from acme.tf import utils as tf2_utils from acme.tf import variable_utils as tf2_variable_utils from dm_env import specs @@ -102,9 +101,10 @@ def get_stats(self) -> Dict: class MADQNFeedForwardExecutor(executors.FeedForwardExecutor, DQNExecutor): - """A feed-forward executor for discrete actions. + """A feed-forward executor for MADQN like systems. - An executor based on a feed-forward policy for each agent in the system. + An executor based on a feed-forward epsilon-greedy policy for + each agent in the system. """ def __init__( @@ -166,15 +166,13 @@ def _policy( observation: types.NestedTensor, legal_actions: types.NestedTensor, ) -> types.NestedTensor: - """Agent specific policy function + """Epsilon greedy policy. Args: agent: agent id observation: observation tensor received from the environment. - - Raises: - NotImplementedError: unknown action space + legal_actions: one-hot vector of legal actions. Returns: types.NestedTensor: agent action @@ -202,6 +200,7 @@ def _policy( def _select_actions( self, observations: Dict[str, types.NestedArray] ) -> types.NestedArray: + """The part of select_actions we can do in tf.function""" actions = {} for agent, observation in observations.items(): actions[agent] = self._policy( @@ -287,9 +286,10 @@ def update(self, wait: bool = False) -> None: class MADQNRecurrentExecutor(executors.RecurrentExecutor, DQNExecutor): - """A recurrent executor for MADQN. + """A recurrent executor for MADQN like systems. - An executor based on a recurrent policy for each agent in the system. + An executor based on a recurrent epsilon-greedy policy + for each agent in the system. """ def __init__( @@ -359,17 +359,15 @@ def _policy( legal_actions: types.NestedTensor, state: types.NestedTensor, ) -> Tuple: - """Agent specific policy function. + """Agent epsilon-greedy policy. Args: agent: agent id observation: observation tensor received from the environment. + legal_actions: one-hot vector of legal actions state: recurrent network state. - Raises: - NotImplementedError: unknown action space - Returns: action, policy and new recurrent hidden state """ @@ -397,6 +395,7 @@ def _select_actions( self, observations: Dict[str, types.NestedArray], states: Dict[str, types.NestedArray] ) -> types.NestedArray: + """The part of select_action that we can do inside tf.function""" actions: Dict = {} new_states: Dict = {} for agent, observation in observations.items(): diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index 49a587a17..080e52579 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -14,7 +14,6 @@ # limitations under the License. """MADQN system implementation.""" - import functools from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union @@ -104,7 +103,7 @@ def __init__( # noqa evaluator_interval: Optional[dict] = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Initialise the system + """Initialise the system. Args: environment_factory: function to @@ -187,13 +186,8 @@ def __init__( # noqa happen at every timestep. E.g. to evaluate a system after every 100 executor episodes, evaluator_interval = {"executor_episodes": 100}. - learning_rate_scheduler_fn: dict with two functions/classes (one for the - policy and one for the critic optimizer), that takes in a trainer - step t and returns the current learning rate, - e.g. {"policy": policy_lr_schedule ,"critic": critic_lr_schedule}. 
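The executor docstrings above describe an epsilon-greedy policy restricted to legal actions: illegal actions are masked out before the greedy argmax, and random exploration only draws from the legal set. Below is a minimal numpy sketch of that selection rule, with an illustrative masking value and helper name; the real executor works on TensorFlow tensors inside the agent networks.

import numpy as np


def masked_epsilon_greedy(q_values: np.ndarray,
                          legal_actions: np.ndarray,
                          epsilon: float,
                          rng: np.random.Generator) -> int:
    """Pick a legal action: greedy with prob. 1 - epsilon, else uniform over legals."""
    legal = legal_actions.astype(bool)
    if rng.random() < epsilon:
        # Explore: uniform over legal actions only.
        return int(rng.choice(np.flatnonzero(legal)))
    # Exploit: mask out illegal actions before the argmax.
    masked_q = np.where(legal, q_values, -np.inf)
    return int(np.argmax(masked_q))


rng = np.random.default_rng(0)
q = np.array([0.1, 2.0, -0.5, 0.7], dtype=np.float32)
legals = np.array([1, 0, 1, 1], dtype=np.float32)  # action 1 is illegal
print(masked_epsilon_greedy(q, legals, epsilon=0.05, rng=rng))  # usually 3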
- See - examples/debugging/simple_spread/feedforward/decentralised/run_maddpg_lr_schedule.py - for an example. + learning_rate_scheduler_fn: an optional learning rate scheduler for + the value function optimiser. """ if not environment_spec: @@ -413,7 +407,7 @@ def __init__( # noqa ) def _get_extra_specs(self) -> Any: - """Helper to establish specs for extra information + """Helper to establish specs for extra information. Returns: Dictionary containing extra specs @@ -435,7 +429,7 @@ def _get_extra_specs(self) -> Any: return {"core_states": core_state_specs, "zero_padding_mask": np.array(1)} def replay(self) -> Any: - """Step counter + """Step counter. Args: checkpoint: whether to checkpoint the counter. @@ -590,7 +584,7 @@ def trainer( replay: reverb.Client, variable_source: MavaVariableSource, ) -> mava.core.Trainer: - """System trainer + """System trainer. Args: trainer_id: Id of the trainer being created. diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index 92ac23d94..595de2e88 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -15,7 +15,6 @@ """MADQN trainer implementation.""" - import copy from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -40,7 +39,7 @@ class MADQNTrainer(mava.Trainer): """MADQN trainer. - This is the trainer component of a MADDPG system. IE it takes a dataset as input + This is the trainer component of a MADQN system. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ @@ -210,7 +209,9 @@ def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray] def _transform_observations( self, obs: Dict[str, mava_types.OLT], next_obs: Dict[str, mava_types.OLT] ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: - """Transform the observations using the observation networks of each agent." + """Transform the observations using the observation networks of each agent. + + We assume the observation network is non-recurrent. Args: obs: observations at timestep t-1 @@ -238,7 +239,7 @@ def _transform_observations( def _step( self, ) -> Dict[str, Dict[str, Any]]: - """Trainer forward and backward passes. + """Trainer step. Returns: losses @@ -247,8 +248,10 @@ def _step( # Draw a batch of data from replay. sample: reverb.ReplaySample = next(self._iterator) + # Compute loss self._forward(sample) + # Compute and apply gradients self._backward() # Update the target networks @@ -257,7 +260,6 @@ def _step( # Log losses per agent return train_utils.map_losses_per_agent_value(self.value_losses) - # Forward pass that calculates loss. def _forward(self, inputs: reverb.ReplaySample) -> None: """Trainer forward pass. @@ -322,7 +324,6 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: self.tape = tape - # Backward pass that calculates gradients and updates network. def _backward(self) -> None: """Trainer backward pass updating network parameters""" @@ -528,6 +529,8 @@ def _transform_observations( ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]: """Apply the observation networks to the raw observations from the dataset + We assume that the observation network is non-recurrent. + Args: observations: raw agent observations @@ -600,7 +603,7 @@ def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray] def _step( self, ) -> Dict[str, Dict[str, Any]]: - """Trainer forward and backward passes. + """Trainer step. 
Returns: losses @@ -609,8 +612,10 @@ def _step( # Draw a batch of data from replay. sample: reverb.ReplaySample = next(self._iterator) + # Compute loss self._forward(sample) + # Compute and apply gradients self._backward() # Update the target networks @@ -719,7 +724,6 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: self.tape = tape - # Backward pass that calculates gradients and updates network. def _backward(self) -> None: """Trainer backward pass updating network parameters""" From 55e7694143ae258c75ca9cb071c22548403cb601 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 2 Feb 2022 11:25:36 +0200 Subject: [PATCH 38/56] Fix docstrings in Value Decomposition system. --- mava/systems/tf/madqn/networks.py | 2 ++ mava/systems/tf/value_decomposition/networks.py | 2 ++ mava/systems/tf/value_decomposition/system.py | 12 ++++++++++-- mava/systems/tf/value_decomposition/training.py | 15 +++++++-------- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/mava/systems/tf/madqn/networks.py b/mava/systems/tf/madqn/networks.py index 35c8698e9..ec4615acf 100644 --- a/mava/systems/tf/madqn/networks.py +++ b/mava/systems/tf/madqn/networks.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Default networks for MADQN systems.""" from typing import Dict, Mapping, Optional, Sequence, Union import sonnet as snt diff --git a/mava/systems/tf/value_decomposition/networks.py b/mava/systems/tf/value_decomposition/networks.py index d2f11602f..f59414fa0 100644 --- a/mava/systems/tf/value_decomposition/networks.py +++ b/mava/systems/tf/value_decomposition/networks.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Default networks for Value Decomposition systems""" from typing import Dict, Mapping, Optional, Sequence, Union import sonnet as snt diff --git a/mava/systems/tf/value_decomposition/system.py b/mava/systems/tf/value_decomposition/system.py index feb102669..822216612 100644 --- a/mava/systems/tf/value_decomposition/system.py +++ b/mava/systems/tf/value_decomposition/system.py @@ -12,8 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Value Decomposition system implementation.""" +"""Value Decomposition system implementation.""" from typing import Callable, Dict, Mapping, Optional, Type, Union import dm_env @@ -38,7 +38,11 @@ class ValueDecomposition(MADQN): - """Value Decomposition systems.""" + """Value Decomposition systems. + + + Inherits from recurrent MADQN. + """ def __init__( self, @@ -208,6 +212,9 @@ def __init__( learning_rate_scheduler_fn=learning_rate_scheduler_fn, ) + # NOTE Users can either pass in their own mixer or + # use one of the pre-built ones by passing in a + # string "qmix" or "vdn". 
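As the NOTE above says, mixer can be one of the prebuilt "qmix"/"vdn" options or a user-supplied module (the prebuilt ones are Sonnet modules). The trainer further down calls it as mixer(agent_q_values, states=global_env_state), so a custom mixer only has to map per-agent values plus an optional global state to a single team value. A minimal VDN-style stand-in in numpy, purely for illustration:

import numpy as np


class AdditiveMixer:
    """VDN-style mixer sketch: the team value is the sum of per-agent values."""

    def __call__(self, agent_qs: np.ndarray, states=None) -> np.ndarray:
        # agent_qs: [batch, num_agents]; the global state is unused by VDN.
        return np.sum(agent_qs, axis=-1, keepdims=True)


mixer = AdditiveMixer()
chosen_qs = np.array([[1.0, 0.5, -0.2]])  # one batch entry, three agents
print(mixer(chosen_qs, states=None))      # [[1.3]]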
if isinstance(mixer, str): if mixer == "qmix": env = environment_factory() # type: ignore @@ -264,6 +271,7 @@ def trainer( variable_source=variable_source, ) + # Setup the mixer trainer.setup_mixer(self._mixer, self._mixer_optimizer) # type: ignore return trainer diff --git a/mava/systems/tf/value_decomposition/training.py b/mava/systems/tf/value_decomposition/training.py index ee7e15ede..eb9b9a647 100644 --- a/mava/systems/tf/value_decomposition/training.py +++ b/mava/systems/tf/value_decomposition/training.py @@ -12,8 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Value Decomposition trainer implementation.""" +"""Value Decomposition trainer implementation.""" import copy from typing import Any, Callable, Dict, List, Optional, Union @@ -130,7 +130,7 @@ def setup_mixer(self, mixer: snt.Module, mixer_optimizer: snt.Module) -> None: self._mixer_optimizer = mixer_optimizer def _update_target_networks(self) -> None: - """Update the target networks. + """Update the target networks and the target mixer. Using either target averaging or by directy copying the weights of the online networks every few steps. @@ -170,7 +170,6 @@ def _update_target_networks(self) -> None: self._num_steps.assign_add(1) - # Forward pass that calculates loss. def _forward(self, inputs: reverb.ReplaySample) -> None: """Trainer forward pass. @@ -208,8 +207,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # Do forward passes through the networks and calculate the losses with tf.GradientTape(persistent=True) as tape: - # NOTE (Dries): We are assuming that only the value network - # is recurrent and not the observation network. + obs_trans, target_obs_trans = self._transform_observations(observations) # Lists for stacking tensors later @@ -230,6 +228,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: chosen_action_q_value = trfl.batched_index(q_tm1_values, actions[agent]) # Q-value of the next state + # Legal action masking q_t_selector = tf.where( tf.cast(observations[agent].legal_actions, "bool"), q_tm1_values, @@ -269,12 +268,13 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: max_action_q_value_all_agents, states=global_env_state ) - # NOTE Team reward is just the mean over agents indevidual rewards + # NOTE Weassume team reward is just the mean + # over agents indevidual rewards reward_all_agents = tf.reduce_mean( reward_all_agents, axis=-1, keepdims=True ) # NOTE We assume all agents have the same env discount since - # it is a team game + # it is a team game. env_discount_all_agents = tf.reduce_mean( env_discount_all_agents, axis=-1, keepdims=True ) @@ -307,7 +307,6 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: self.tape = tape - # Backward pass that calculates gradients and updates network. def _backward(self) -> None: """Trainer backward pass updating network parameters""" From 219a277b76b968b56540b80c49edb6cd8584a11b Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 2 Feb 2022 11:31:55 +0200 Subject: [PATCH 39/56] Fix docstrings in wrappers and utils. 
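To summarise the mixed-Q update assembled in the trainer above: per-agent rewards are collapsed to a single team reward by taking the mean, one shared environment discount is assumed, and the mixed chosen-action value is regressed towards reward + discount * gamma * mixed max-Q of the next step. The numpy sketch below illustrates that target under those assumptions, using VDN-style sum mixing; variable names are illustrative and the real trainer uses TensorFlow and trfl.

import numpy as np

gamma = 0.99

# One batch entry, three agents.
chosen_qs = np.array([[1.0, 0.5, -0.2]])     # Q(s, a) per agent (online network)
next_max_qs = np.array([[0.8, 0.6, 0.1]])    # max legal Q(s', .) per agent (target network)
agent_rewards = np.array([[1.0, 1.0, 1.0]])  # identical team reward repeated per agent
agent_discounts = np.array([[1.0, 1.0, 1.0]])

# VDN-style mixing: sum across agents.
q_tot = chosen_qs.sum(axis=-1)
target_q_tot = next_max_qs.sum(axis=-1)

# Team reward and discount: mean over agents, as in the trainer above.
reward = agent_rewards.mean(axis=-1)
discount = agent_discounts.mean(axis=-1)

td_target = reward + gamma * discount * target_q_tot
td_error = td_target - q_tot
loss = 0.5 * np.square(td_error).mean()
print(td_target, loss)  # [2.485] ~0.70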
--- mava/utils/environments/flatland_utils.py | 1 + mava/utils/environments/smac_utils.py | 2 ++ mava/wrappers/env_preprocess_wrappers.py | 6 ++++-- mava/wrappers/smac.py | 8 ++------ 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mava/utils/environments/flatland_utils.py b/mava/utils/environments/flatland_utils.py index 50ca59b5f..e0bf144f8 100644 --- a/mava/utils/environments/flatland_utils.py +++ b/mava/utils/environments/flatland_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Utils for making Flatland environment.""" from typing import Optional from mava.wrappers.env_preprocess_wrappers import ( diff --git a/mava/utils/environments/smac_utils.py b/mava/utils/environments/smac_utils.py index a1d6a4f31..afad2f6e4 100644 --- a/mava/utils/environments/smac_utils.py +++ b/mava/utils/environments/smac_utils.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Utils for SMAC environment.""" from typing import Any, Optional from smac.env import StarCraft2Env diff --git a/mava/wrappers/env_preprocess_wrappers.py b/mava/wrappers/env_preprocess_wrappers.py index 631f2a32c..5c4cb45af 100644 --- a/mava/wrappers/env_preprocess_wrappers.py +++ b/mava/wrappers/env_preprocess_wrappers.py @@ -375,6 +375,7 @@ def __init__(self, environment: Any) -> None: self._num_agents = len(environment.possible_agents) def reset(self) -> dm_env.TimeStep: + """Reset environment and concat agent ID.""" timestep, extras = self._environment.reset() old_observations = timestep.observation @@ -401,6 +402,7 @@ def reset(self) -> dm_env.TimeStep: ) def step(self, actions: Dict) -> dm_env.TimeStep: + """Step the environment and concat agent ID""" timestep, extras = self._environment.step(actions) old_observations = timestep.observation @@ -458,12 +460,11 @@ class ConcatPrevActionToObservation: TODO (Claude) support continuous actions. """ - # Need to get the size of the action space of each agent - def __init__(self, environment: Any): self._environment = environment def reset(self) -> dm_env.TimeStep: + """Reset the environment and add zero action.""" timestep, extras = self._environment.reset() old_observations = timestep.observation action_spec = self._environment.action_spec() @@ -490,6 +491,7 @@ def reset(self) -> dm_env.TimeStep: ) def step(self, actions: Dict) -> dm_env.TimeStep: + """Step the environment and concat prev actions.""" timestep, extras = self._environment.step(actions) old_observations = timestep.observation action_spec = self._environment.action_spec() diff --git a/mava/wrappers/smac.py b/mava/wrappers/smac.py index 46429d64f..3bca332f6 100644 --- a/mava/wrappers/smac.py +++ b/mava/wrappers/smac.py @@ -171,6 +171,7 @@ def _convert_reward(self, reward: float) -> Dict[str, float]: return rewards def _get_legal_actions(self) -> List: + """Get legal actions from the environment.""" legal_actions = [] for i, _ in enumerate(self._agents): legal_actions.append( @@ -181,7 +182,7 @@ def _get_legal_actions(self) -> List: def _convert_observations( self, observations: List, legal_actions: List, done: bool ) -> types.Observation: - """Convert PettingZoo observation so it's dm_env compatible. + """Convert SMAC observation so it's dm_env compatible. Args: observes (Dict[str, np.ndarray]): observations per agent. 
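The ConcatAgentIdToObservation and ConcatPrevActionToObservation wrappers above extend every agent's observation with a one-hot agent ID and a one-hot previous action (all zeros on reset, since there is no previous action yet). The underlying array operation is just a concatenation; here is a small numpy sketch with toy sizes (3 agents, 5 discrete actions), not the wrappers themselves.

from typing import Optional

import numpy as np


def one_hot(index: int, size: int) -> np.ndarray:
    vec = np.zeros(size, dtype=np.float32)
    vec[index] = 1.0
    return vec


def concat_agent_id_and_prev_action(obs: np.ndarray, agent_index: int,
                                    num_agents: int, prev_action: Optional[int],
                                    num_actions: int) -> np.ndarray:
    # On reset the "previous action" is simply the all-zero vector.
    prev_action_vec = (one_hot(prev_action, num_actions)
                       if prev_action is not None
                       else np.zeros(num_actions, dtype=np.float32))
    return np.concatenate([obs, one_hot(agent_index, num_agents), prev_action_vec])


obs = np.ones(4, dtype=np.float32)
extended = concat_agent_id_and_prev_action(obs, agent_index=1, num_agents=3,
                                           prev_action=2, num_actions=5)
print(extended.shape)  # (12,)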
@@ -319,8 +320,3 @@ def __getattr__(self, name: str) -> Any: return self.__getattribute__(name) else: return getattr(self._environment, name) - - -env = StarCraft2Env(map_name="3m") - -wrapped_env = SMACWrapper(env) From 9adb7d08fa73359830f6340239fe8414aa0ab498 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 2 Feb 2022 11:37:24 +0200 Subject: [PATCH 40/56] Add Value Decomposition README. --- mava/systems/tf/value_decomposition/README.md | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 mava/systems/tf/value_decomposition/README.md diff --git a/mava/systems/tf/value_decomposition/README.md b/mava/systems/tf/value_decomposition/README.md new file mode 100644 index 000000000..87df69eb9 --- /dev/null +++ b/mava/systems/tf/value_decomposition/README.md @@ -0,0 +1,7 @@ +# Value Decomposition Methods eg. VDN and QMIX + +This system supports to important Value Decomposition methods, VDN and QMIX. +The design of the system also allows the user to easily include their own mixer in place of the two supported ones. + +[Sunehag et al., 2017]: https://arxiv.org/abs/1706.05296 +[Rashid et al., 2018]: https://arxiv.org/abs/1803.11485 From f66be7f15f78df17f049b7f159b49a7501f40efd Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 2 Feb 2022 12:08:41 +0200 Subject: [PATCH 41/56] Fix imports when users have not installed SMAC or Flatland. --- mava/utils/environments/flatland_utils.py | 222 ++-- mava/utils/environments/smac_utils.py | 36 +- mava/wrappers/__init__.py | 9 +- mava/wrappers/flatland.py | 1124 +++++++++++---------- mava/wrappers/smac.py | 566 ++++++----- 5 files changed, 992 insertions(+), 965 deletions(-) diff --git a/mava/utils/environments/flatland_utils.py b/mava/utils/environments/flatland_utils.py index e0bf144f8..e26cbe788 100644 --- a/mava/utils/environments/flatland_utils.py +++ b/mava/utils/environments/flatland_utils.py @@ -33,119 +33,121 @@ from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator -except ModuleNotFoundError: - pass - - -def _create_rail_env_with_tree_obs( - n_agents: int = 5, - x_dim: int = 30, - y_dim: int = 30, - n_cities: int = 2, - max_rails_between_cities: int = 2, - max_rails_in_city: int = 3, - seed: Optional[int] = 0, - malfunction_rate: float = 1 / 200, - malfunction_min_duration: int = 20, - malfunction_max_duration: int = 50, - observation_max_path_depth: int = 30, - observation_tree_depth: int = 2, -) -> RailEnv: - """Create a Flatland RailEnv with TreeObservation. - - Args: - n_agents: Number of trains. Defaults to 5. - x_dim: Width of map. Defaults to 30. - y_dim: Height of map. Defaults to 30. - n_cities: Number of cities. Defaults to 2. - max_rails_between_cities: Max rails between cities. Defaults to 2. - max_rails_in_city: Max rails in cities. Defaults to 3. - seed: Random seed. Defaults to 0. - malfunction_rate: Malfunction rate. Defaults to 1/200. - malfunction_min_duration: Min malfunction duration. Defaults to 20. - malfunction_max_duration: Max malfunction duration. Defaults to 50. - observation_max_path_depth: Shortest path predictor depth. Defaults to 30. - observation_tree_depth: TreeObs depth. Defaults to 2. - - Returns: - RailEnv: A Flatland RailEnv. 
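The README above cites VDN (Sunehag et al., 2017) and QMIX (Rashid et al., 2018). Where VDN simply sums the agents' utilities, QMIX mixes them with state-conditioned weights that are forced to be non-negative, so the joint value stays monotonic in each agent's value. The numpy sketch below shows that monotonic mixing step with made-up layer sizes and random hypernetwork weights; it illustrates the idea from the paper rather than Mava's QMIX module.

import numpy as np

rng = np.random.default_rng(0)

num_agents, state_dim, hidden_dim = 3, 6, 32

# Hypernetwork parameters (random here, learned in practice).
w1_hyper = rng.normal(size=(state_dim, num_agents * hidden_dim))
b1_hyper = rng.normal(size=(state_dim, hidden_dim))
w2_hyper = rng.normal(size=(state_dim, hidden_dim))
b2_hyper = rng.normal(size=(state_dim, 1))


def qmix_forward(agent_qs: np.ndarray, state: np.ndarray) -> np.ndarray:
    """Monotonic mixing: mixing weights are forced non-negative via abs()."""
    w1 = np.abs(state @ w1_hyper).reshape(num_agents, hidden_dim)
    b1 = state @ b1_hyper
    hidden = np.maximum(agent_qs @ w1 + b1, 0.0)  # the paper uses an ELU; ReLU keeps this short
    w2 = np.abs(state @ w2_hyper)
    b2 = state @ b2_hyper
    return hidden @ w2 + b2  # Q_tot, shape (1,)


agent_qs = np.array([1.0, 0.5, -0.2])
state = rng.normal(size=state_dim)
print(qmix_forward(agent_qs, state))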
- """ - - # Break agents from time to time - malfunction_parameters = MalfunctionParameters( - malfunction_rate=malfunction_rate, - min_duration=malfunction_min_duration, - max_duration=malfunction_max_duration, - ) - - # Observation builder - predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth) - tree_observation = TreeObsForRailEnv( - max_depth=observation_tree_depth, predictor=predictor - ) + _found_flatland = True - rail_env = RailEnv( - width=x_dim, - height=y_dim, - rail_generator=sparse_rail_generator( - max_num_cities=n_cities, - grid_mode=False, +except ModuleNotFoundError: + _found_flatland = False + +if _found_flatland: + + def _create_rail_env_with_tree_obs( + n_agents: int = 5, + x_dim: int = 30, + y_dim: int = 30, + n_cities: int = 2, + max_rails_between_cities: int = 2, + max_rails_in_city: int = 3, + seed: Optional[int] = 0, + malfunction_rate: float = 1 / 200, + malfunction_min_duration: int = 20, + malfunction_max_duration: int = 50, + observation_max_path_depth: int = 30, + observation_tree_depth: int = 2, + ) -> RailEnv: + """Create a Flatland RailEnv with TreeObservation. + + Args: + n_agents: Number of trains. Defaults to 5. + x_dim: Width of map. Defaults to 30. + y_dim: Height of map. Defaults to 30. + n_cities: Number of cities. Defaults to 2. + max_rails_between_cities: Max rails between cities. Defaults to 2. + max_rails_in_city: Max rails in cities. Defaults to 3. + seed: Random seed. Defaults to 0. + malfunction_rate: Malfunction rate. Defaults to 1/200. + malfunction_min_duration: Min malfunction duration. Defaults to 20. + malfunction_max_duration: Max malfunction duration. Defaults to 50. + observation_max_path_depth: Shortest path predictor depth. Defaults to 30. + observation_tree_depth: TreeObs depth. Defaults to 2. + + Returns: + RailEnv: A Flatland RailEnv. 
+ """ + + # Break agents from time to time + malfunction_parameters = MalfunctionParameters( + malfunction_rate=malfunction_rate, + min_duration=malfunction_min_duration, + max_duration=malfunction_max_duration, + ) + + # Observation builder + predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth) + tree_observation = TreeObsForRailEnv( + max_depth=observation_tree_depth, predictor=predictor + ) + + rail_env = RailEnv( + width=x_dim, + height=y_dim, + rail_generator=sparse_rail_generator( + max_num_cities=n_cities, + grid_mode=False, + max_rails_between_cities=max_rails_between_cities, + max_rail_pairs_in_city=max_rails_in_city // 2, + ), + line_generator=sparse_line_generator(), + number_of_agents=n_agents, + malfunction_generator=ParamMalfunctionGen(malfunction_parameters), + obs_builder_object=tree_observation, + random_seed=seed, + ) + + return rail_env + + def make_environment( + n_agents: int = 10, + x_dim: int = 30, + y_dim: int = 30, + n_cities: int = 2, + max_rails_between_cities: int = 2, + max_rails_in_city: int = 3, + seed: int = 0, + malfunction_rate: float = 1 / 200, + malfunction_min_duration: int = 20, + malfunction_max_duration: int = 50, + observation_max_path_depth: int = 30, + observation_tree_depth: int = 2, + concat_prev_actions: bool = True, + concat_agent_id: bool = False, + evaluation: bool = False, + random_seed: Optional[int] = None, + ) -> FlatlandEnvWrapper: + """Loads a flatand environment and wraps it using the flatland wrapper""" + + del evaluation # since it has same behaviour for both train and eval + + env = _create_rail_env_with_tree_obs( + n_agents=n_agents, + x_dim=x_dim, + y_dim=y_dim, + n_cities=n_cities, max_rails_between_cities=max_rails_between_cities, - max_rail_pairs_in_city=max_rails_in_city // 2, - ), - line_generator=sparse_line_generator(), - number_of_agents=n_agents, - malfunction_generator=ParamMalfunctionGen(malfunction_parameters), - obs_builder_object=tree_observation, - random_seed=seed, - ) - - return rail_env - - -def make_environment( - n_agents: int = 10, - x_dim: int = 30, - y_dim: int = 30, - n_cities: int = 2, - max_rails_between_cities: int = 2, - max_rails_in_city: int = 3, - seed: int = 0, - malfunction_rate: float = 1 / 200, - malfunction_min_duration: int = 20, - malfunction_max_duration: int = 50, - observation_max_path_depth: int = 30, - observation_tree_depth: int = 2, - concat_prev_actions: bool = True, - concat_agent_id: bool = False, - evaluation: bool = False, - random_seed: Optional[int] = None, -) -> FlatlandEnvWrapper: - """Loads a flatand environment and wraps it using the flatland wrapper""" - - del evaluation # since it has same behaviour for both train and eval - - env = _create_rail_env_with_tree_obs( - n_agents=n_agents, - x_dim=x_dim, - y_dim=y_dim, - n_cities=n_cities, - max_rails_between_cities=max_rails_between_cities, - max_rails_in_city=max_rails_in_city, - seed=random_seed, - malfunction_rate=malfunction_rate, - malfunction_min_duration=malfunction_min_duration, - malfunction_max_duration=malfunction_max_duration, - observation_max_path_depth=observation_max_path_depth, - observation_tree_depth=observation_tree_depth, - ) + max_rails_in_city=max_rails_in_city, + seed=random_seed, + malfunction_rate=malfunction_rate, + malfunction_min_duration=malfunction_min_duration, + malfunction_max_duration=malfunction_max_duration, + observation_max_path_depth=observation_max_path_depth, + observation_tree_depth=observation_tree_depth, + ) - env = FlatlandEnvWrapper(env) + env = 
FlatlandEnvWrapper(env) - if concat_prev_actions: - env = ConcatPrevActionToObservation(env) + if concat_prev_actions: + env = ConcatPrevActionToObservation(env) - if concat_agent_id: - env = ConcatAgentIdToObservation(env) + if concat_agent_id: + env = ConcatAgentIdToObservation(env) - return env + return env diff --git a/mava/utils/environments/smac_utils.py b/mava/utils/environments/smac_utils.py index afad2f6e4..18ecbf029 100644 --- a/mava/utils/environments/smac_utils.py +++ b/mava/utils/environments/smac_utils.py @@ -16,7 +16,12 @@ """Utils for SMAC environment.""" from typing import Any, Optional -from smac.env import StarCraft2Env +try: + from smac.env import StarCraft2Env + + _found_smac = True +except ModuleNotFoundError: + _found_smac = False from mava.wrappers import SMACWrapper from mava.wrappers.env_preprocess_wrappers import ( @@ -24,22 +29,23 @@ ConcatPrevActionToObservation, ) +if _found_smac: -def make_environment( - map_name: str = "3m", - concat_prev_actions: bool = True, - concat_agent_id: bool = True, - evaluation: bool = False, - random_seed: Optional[int] = None, -) -> Any: - env = StarCraft2Env(map_name=map_name, seed=random_seed) + def make_environment( + map_name: str = "3m", + concat_prev_actions: bool = True, + concat_agent_id: bool = True, + evaluation: bool = False, + random_seed: Optional[int] = None, + ) -> Any: + env = StarCraft2Env(map_name=map_name, seed=random_seed) - env = SMACWrapper(env) + env = SMACWrapper(env) - if concat_prev_actions: - env = ConcatPrevActionToObservation(env) + if concat_prev_actions: + env = ConcatPrevActionToObservation(env) - if concat_agent_id: - env = ConcatAgentIdToObservation(env) + if concat_agent_id: + env = ConcatAgentIdToObservation(env) - return env + return env diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index d90c841a1..ec8e13198 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -26,8 +26,13 @@ ) from mava.wrappers.robocup import RoboCupWrapper -# from mava.wrappers.flatland import FlatlandEnvWrapper -from mava.wrappers.smac import SMACWrapper +try: + # The user might not have installed Flatland or SMAC + from mava.wrappers.flatland import FlatlandEnvWrapper + from mava.wrappers.smac import SMACWrapper +except ModuleNotFoundError: + pass + from mava.wrappers.system_trainer_statistics import ( DetailedTrainerStatistics, NetworkStatisticsActorCritic, diff --git a/mava/wrappers/flatland.py b/mava/wrappers/flatland.py index 31a19a0fa..ddd462269 100644 --- a/mava/wrappers/flatland.py +++ b/mava/wrappers/flatland.py @@ -14,7 +14,6 @@ # limitations under the License. """Wraps a Flatland MARL environment to be used as a dm_env environment.""" - import types as tp from functools import partial from typing import Any, Callable, Dict, List, Tuple, Union @@ -23,10 +22,6 @@ import numpy as np from acme import specs from acme.wrappers.gym_wrapper import _convert_to_spec -from flatland.envs.observations import GlobalObsForRailEnv, Node, TreeObsForRailEnv -from flatland.envs.rail_env import RailEnv -from flatland.envs.step_utils.states import TrainState -from flatland.utils.rendertools import AgentRenderVariant, RenderTool from gym.spaces import Discrete from gym.spaces.box import Box @@ -39,601 +34,612 @@ ) from mava.wrappers.env_wrappers import ParallelEnvWrapper - -class FlatlandEnvWrapper(ParallelEnvWrapper): - """Environment wrapper for Flatland environments. - - All environments would require an observation preprocessor, except for - 'GlobalObsForRailEnv'. 
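The pattern this commit applies to both SMAC and Flatland is to guard the optional import, record whether it succeeded in a module-level flag, and only define the dependent factories and wrappers when the flag is set. In isolation, with a placeholder package name, it looks like this:

# Generic optional-dependency guard, mirroring the _found_smac / _found_flatland
# flags introduced in this commit. "some_optional_package" is a placeholder.
try:
    import some_optional_package  # noqa: F401

    _found_package = True
except ModuleNotFoundError:
    _found_package = False

if _found_package:

    def make_environment():
        """Only defined when the optional dependency is installed."""
        return some_optional_package.Env()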
This is because flatland gives users the - flexibility of designing custom observation builders. 'TreeObsForRailEnv' - would use the normalize_observation function from the flatland baselines - if none is supplied. - The supplied preprocessor should return either an array, tuple of arrays or - a dictionary of arrays for an observation input. - The obervation, for an agent, returned by this wrapper could consist of both - the agent observation and agent info. This is because flatland also provides - informationn about the agents at each step. This information include; - 'action_required', 'malfunction', 'speed', and 'status', and it can be appended - to the observation, by this wrapper, as an array. action_required is a boolean, - malfunction is an int denoting the number of steps for which the agent would - remain motionless, speed is a float and status can be any of the below; - READY_TO_DEPART = 0 - ACTIVE = 1 - DONE = 2 - DONE_REMOVED = 3 - This would be included in the observation if agent_info is set to True - """ - - # Note: we don't inherit from base.EnvironmentWrapper because that class - # assumes that the wrapped environment is a dm_env.Environment. - def __init__( - self, - environment: RailEnv, - preprocessor: Callable[ - [Any], Union[np.ndarray, Tuple[np.ndarray], Dict[str, np.ndarray]] - ] = None, - agent_info: bool = False, - ): - """Wrap Flatland environment. - - Args: - environment: underlying RailEnv - preprocessor: optional preprocessor. Defaults to None. - agent_info: include agent info. Defaults to True. +try: # noqa + from flatland.envs.observations import GlobalObsForRailEnv, Node, TreeObsForRailEnv + from flatland.envs.rail_env import RailEnv + from flatland.envs.step_utils.states import TrainState + from flatland.utils.rendertools import AgentRenderVariant, RenderTool + + class FlatlandEnvWrapper(ParallelEnvWrapper): + """Environment wrapper for Flatland environments. + + All environments would require an observation preprocessor, except for + 'GlobalObsForRailEnv'. This is because flatland gives users the + flexibility of designing custom observation builders. 'TreeObsForRailEnv' + would use the normalize_observation function from the flatland baselines + if none is supplied. + The supplied preprocessor should return either an array, tuple of arrays or + a dictionary of arrays for an observation input. + The obervation, for an agent, returned by this wrapper could consist of both + the agent observation and agent info. This is because flatland also provides + informationn about the agents at each step. This information include; + 'action_required', 'malfunction', 'speed', and 'status', and it can be appended + to the observation, by this wrapper, as an array. 
action_required is a boolean, + malfunction is an int denoting the number of steps for which the agent would + remain motionless, speed is a float and status can be any of the below; + READY_TO_DEPART = 0 + ACTIVE = 1 + DONE = 2 + DONE_REMOVED = 3 + This would be included in the observation if agent_info is set to True """ - self._environment = environment - decorate_step_method(self._environment) - - self._agents = [get_agent_id(i) for i in range(self.num_agents)] - self._possible_agents = self.agents[:] - - self._reset_next_step = True - self._step_type = dm_env.StepType.FIRST - self.num_actions = 5 - - self.action_spaces = { - agent: Discrete(self.num_actions) for agent in self.possible_agents - } - - # preprocessor must be for observation builders other than global obs - # treeobs builders would use the default preprocessor if none is - # supplied - self.preprocessor: Callable[ - [Dict[int, Any]], Dict[int, Any] - ] = self._obtain_preprocessor(preprocessor) - - self._include_agent_info = agent_info - - # observation space: - # flatland defines no observation space for an agent. Here we try - # to define the observation space. All agents are identical and would - # have the same observation space. - # Infer observation space based on returned observation - obs, _ = self._environment.reset() - obs = self.preprocessor(obs) - self.observation_spaces = { - get_agent_id(i): infer_observation_space(ob) for i, ob in obs.items() - } - - self._env_renderer = RenderTool( - self._environment, - agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, - show_debug=False, - screen_height=800, # Adjust these parameters to fit your resolution - screen_width=800, - ) # Adjust these parameters to fit your resolution - - @property - def agents(self) -> List[str]: - """Return list of active agents.""" - return self._agents - - @property - def possible_agents(self) -> List[str]: - """Return list of all possible agents.""" - return self._possible_agents - - def _update_stats(self, info: Dict, rewards: Dict) -> None: - """Update flatland stats.""" - episode_return = sum(list(rewards.values())) - tasks_finished = sum( - [1 if state == TrainState.DONE else 0 for state in info["state"].values()] - ) - completion = tasks_finished / len(self._agents) - normalized_score = episode_return / ( - self._environment._max_episode_steps * len(self._agents) - ) - self._latest_score = normalized_score - self._latest_completion = completion + # Note: we don't inherit from base.EnvironmentWrapper because that class + # assumes that the wrapped environment is a dm_env.Environment. + def __init__( + self, + environment: RailEnv, + preprocessor: Callable[ + [Any], Union[np.ndarray, Tuple[np.ndarray], Dict[str, np.ndarray]] + ] = None, + agent_info: bool = False, + ): + """Wrap Flatland environment. + + Args: + environment: underlying RailEnv + preprocessor: optional preprocessor. Defaults to None. + agent_info: include agent info. Defaults to True. 
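When agent_info=True, the per-agent info described above ('action_required', 'malfunction', 'speed', 'status') is flattened into a float vector in a deterministic key order and attached to the observation. Below is a tiny sketch of that packing step with made-up values; the wrapper itself orders keys with Mava's sort_str_num, for which plain sorted() stands in here.

import numpy as np

# Example Flatland-style info dict: outer keys are info fields,
# inner keys are agent handles. Values here are made up.
info = {
    "action_required": {0: True, 1: False},
    "malfunction": {0: 0, 1: 3},
    "speed": {0: 1.0, 1: 0.5},
    "status": {0: 1, 1: 0},
}

agent_handle = 0
agent_info = np.array(
    [info[key][agent_handle] for key in sorted(info.keys())], dtype=np.float32
)
print(agent_info)  # [1. 0. 1. 1.]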
+ """ + self._environment = environment + decorate_step_method(self._environment) + + self._agents = [get_agent_id(i) for i in range(self.num_agents)] + self._possible_agents = self.agents[:] - def get_stats(self) -> Dict: - """Get flatland specific stats.""" - if self._latest_completion is not None and self._latest_score is not None: - return { - "score": self._latest_score, - "completion": self._latest_completion, + self._reset_next_step = True + self._step_type = dm_env.StepType.FIRST + self.num_actions = 5 + + self.action_spaces = { + agent: Discrete(self.num_actions) for agent in self.possible_agents } - else: - return {} - def render(self, mode: str = "human") -> np.ndarray: - """Renders the environment.""" - if mode == "human": - show = True - else: - show = False + # preprocessor must be for observation builders other than global obs + # treeobs builders would use the default preprocessor if none is + # supplied + self.preprocessor: Callable[ + [Dict[int, Any]], Dict[int, Any] + ] = self._obtain_preprocessor(preprocessor) + + self._include_agent_info = agent_info + + # observation space: + # flatland defines no observation space for an agent. Here we try + # to define the observation space. All agents are identical and would + # have the same observation space. + # Infer observation space based on returned observation + obs, _ = self._environment.reset() + obs = self.preprocessor(obs) + self.observation_spaces = { + get_agent_id(i): infer_observation_space(ob) for i, ob in obs.items() + } - return self._env_renderer.render_env( - show=show, - show_observations=False, - show_predictions=False, - return_image=True, - ) + self._env_renderer = RenderTool( + self._environment, + agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, + show_debug=False, + screen_height=800, # Adjust these parameters to fit your resolution + screen_width=800, + ) # Adjust these parameters to fit your resolution + + @property + def agents(self) -> List[str]: + """Return list of active agents.""" + return self._agents + + @property + def possible_agents(self) -> List[str]: + """Return list of all possible agents.""" + return self._possible_agents + + def _update_stats(self, info: Dict, rewards: Dict) -> None: + """Update flatland stats.""" + episode_return = sum(list(rewards.values())) + tasks_finished = sum( + [ + 1 if state == TrainState.DONE else 0 + for state in info["state"].values() + ] + ) + completion = tasks_finished / len(self._agents) + normalized_score = episode_return / ( + self._environment._max_episode_steps * len(self._agents) + ) + + self._latest_score = normalized_score + self._latest_completion = completion + + def get_stats(self) -> Dict: + """Get flatland specific stats.""" + if self._latest_completion is not None and self._latest_score is not None: + return { + "score": self._latest_score, + "completion": self._latest_completion, + } + else: + return {} + + def render(self, mode: str = "human") -> np.ndarray: + """Renders the environment.""" + if mode == "human": + show = True + else: + show = False + + return self._env_renderer.render_env( + show=show, + show_observations=False, + show_predictions=False, + return_image=True, + ) + + def env_done(self) -> bool: + """Checks if the environment is done.""" + return self._environment.dones["__all__"] or not self.agents + + def reset(self) -> dm_env.TimeStep: + """Resets the episode.""" + # Reset the rendering sytem + self._env_renderer.reset() - def env_done(self) -> bool: - """Checks if the environment is done.""" - return 
self._environment.dones["__all__"] or not self.agents - - def reset(self) -> dm_env.TimeStep: - """Resets the episode.""" - # Reset the rendering sytem - self._env_renderer.reset() - - self._reset_next_step = False - self._agents = self.possible_agents[:] - self._discounts = { - agent: np.dtype("float32").type(1.0) for agent in self.agents - } - observe, info = self._environment.reset() - observations = self._create_observations(observe, info, self._environment.dones) - rewards_spec = self.reward_spec() - rewards = { - agent: convert_np_type(rewards_spec[agent].dtype, 0) - for agent in self.possible_agents - } - - discount_spec = self.discount_spec() - discounts = { - agent: convert_np_type(discount_spec[agent].dtype, 1) - for agent in self.possible_agents - } - return parameterized_restart(rewards, discounts, observations), {} - - def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: - """Steps the environment.""" - self._pre_step() - - if self._reset_next_step: - return self.reset() - - self._agents = [ - agent - for agent in self.agents - if not self._environment.dones[get_agent_handle(agent)] - ] - - observations, rewards, dones, infos = self._environment.step(actions) - - rewards_spec = self.reward_spec() - # Handle empty rewards - if not rewards: + self._reset_next_step = False + self._agents = self.possible_agents[:] + self._discounts = { + agent: np.dtype("float32").type(1.0) for agent in self.agents + } + observe, info = self._environment.reset() + observations = self._create_observations( + observe, info, self._environment.dones + ) + rewards_spec = self.reward_spec() rewards = { agent: convert_np_type(rewards_spec[agent].dtype, 0) for agent in self.possible_agents } - else: - rewards = { - get_agent_id(agent): convert_np_type( - rewards_spec[get_agent_id(agent)].dtype, reward - ) - for agent, reward in rewards.items() - } - - if observations: - observations = self._create_observations(observations, infos, dones) - if self.env_done(): - self._step_type = dm_env.StepType.LAST - self._reset_next_step = True + discount_spec = self.discount_spec() discounts = { - agent: convert_np_type( - self.discount_spec()[agent].dtype, 0 - ) # Zero discount on final step + agent: convert_np_type(discount_spec[agent].dtype, 1) for agent in self.possible_agents } - self._update_stats(infos, rewards) - # TODO (Claude) zero discount! - else: - self._step_type = dm_env.StepType.MID - discounts = { - agent: convert_np_type( - self.discount_spec()[agent].dtype, 1 - ) # discount = 1 - for agent in self.possible_agents - } - - return ( - dm_env.TimeStep( - observation=observations, - reward=rewards, - discount=discounts, - step_type=self._step_type, - ), - {}, - ) - - # Convert Flatland observation so it's dm_env compatible. Also, the list - # of legal actions must be converted to a legal actions mask. 
- def _convert_observations( - self, observes: Dict[str, Tuple[np.ndarray, np.ndarray]], dones: Dict[str, bool] - ) -> Observation: - """Convert observation""" - return convert_dm_compatible_observations( - observes, - dones, - self.observation_spec(), - self.env_done(), - self.possible_agents, - ) + return parameterized_restart(rewards, discounts, observations), {} + + def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: + """Steps the environment.""" + self._pre_step() + + if self._reset_next_step: + return self.reset() + + self._agents = [ + agent + for agent in self.agents + if not self._environment.dones[get_agent_handle(agent)] + ] + + observations, rewards, dones, infos = self._environment.step(actions) + + rewards_spec = self.reward_spec() + # Handle empty rewards + if not rewards: + rewards = { + agent: convert_np_type(rewards_spec[agent].dtype, 0) + for agent in self.possible_agents + } + else: + rewards = { + get_agent_id(agent): convert_np_type( + rewards_spec[get_agent_id(agent)].dtype, reward + ) + for agent, reward in rewards.items() + } + + if observations: + observations = self._create_observations(observations, infos, dones) + + if self.env_done(): + self._step_type = dm_env.StepType.LAST + self._reset_next_step = True + discounts = { + agent: convert_np_type( + self.discount_spec()[agent].dtype, 0 + ) # Zero discount on final step + for agent in self.possible_agents + } + self._update_stats(infos, rewards) + # TODO (Claude) zero discount! + else: + self._step_type = dm_env.StepType.MID + discounts = { + agent: convert_np_type( + self.discount_spec()[agent].dtype, 1 + ) # discount = 1 + for agent in self.possible_agents + } + + return ( + dm_env.TimeStep( + observation=observations, + reward=rewards, + discount=discounts, + step_type=self._step_type, + ), + {}, + ) - # collate agent info and observation into a tuple, making the agents obervation to - # be a tuple of the observation from the env and the agent info - def _collate_obs_and_info( - self, observes: Dict[int, np.ndarray], info: Dict[str, Dict[int, Any]] - ) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: - """Combine observation and info.""" - observations: Dict = {} - observes = self.preprocessor(observes) - for agent, obs in observes.items(): - agent_id = get_agent_id(agent) - agent_info = np.array( - [info[k][agent] for k in sort_str_num(info.keys())], dtype=np.float32 + # Convert Flatland observation so it's dm_env compatible. Also, the list + # of legal actions must be converted to a legal actions mask. + def _convert_observations( + self, + observes: Dict[str, Tuple[np.ndarray, np.ndarray]], + dones: Dict[str, bool], + ) -> Observation: + """Convert observation""" + return convert_dm_compatible_observations( + observes, + dones, + self.observation_spec(), + self.env_done(), + self.possible_agents, ) - new_obs = (obs, agent_info) if self._include_agent_info else obs - observations[agent_id] = new_obs - - return observations - - def _create_observations( - self, - obs: Dict[int, np.ndarray], - info: Dict[str, Dict[int, Any]], - dones: Dict[int, bool], - ) -> Observation: - """Convert observation.""" - observations_ = self._collate_obs_and_info(obs, info) - dones_ = {get_agent_id(k): v for k, v in dones.items()} - observations = self._convert_observations(observations_, dones_) - return observations - - def _obtain_preprocessor( - self, preprocessor: Any - ) -> Callable[[Dict[int, Any]], Dict[int, np.ndarray]]: - """Obtains the actual preprocessor. 
- - Obtains the actual preprocessor to be used based on the supplied - preprocessor and the env's obs_builder object - """ - if not isinstance(self.obs_builder, GlobalObsForRailEnv): - _preprocessor = preprocessor if preprocessor else lambda x: x - if isinstance(self.obs_builder, TreeObsForRailEnv): - _preprocessor = ( - partial( - normalize_observation, tree_depth=self.obs_builder.max_depth - ) - if not preprocessor - else preprocessor - ) - assert _preprocessor is not None - else: - def _preprocessor( - x: Tuple[np.ndarray, np.ndarray, np.ndarray] - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - return x - - def returned_preprocessor(obs: Dict[int, Any]) -> Dict[int, np.ndarray]: - """Return preprocessor.""" - temp_obs = {} - for agent_id, ob in obs.items(): - temp_obs[agent_id] = _preprocessor(ob) - return temp_obs - - return returned_preprocessor - - # set all parameters that should be available before an environment step - # if no available agent, then environment is done and should be reset - def _pre_step(self) -> None: - """Pre-step.""" - if not self.agents: - self._step_type = dm_env.StepType.LAST - - def observation_spec(self) -> Dict[str, OLT]: - """Return observation spec.""" - observation_specs = {} - for agent in self.agents: - # Legal actions - action_spec = _convert_to_spec(self.action_spaces[agent]) - legals = np.ones(shape=action_spec.num_values, dtype=action_spec.dtype) - - observation_specs[agent] = OLT( - observation=tuple( - ( - _convert_to_spec(self.observation_spaces[agent]), - agent_info_spec(), + # collate agent info and observation into a tuple, + # making the agents obervation to + # be a tuple of the observation from the env and the agent info + def _collate_obs_and_info( + self, observes: Dict[int, np.ndarray], info: Dict[str, Dict[int, Any]] + ) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: + """Combine observation and info.""" + observations: Dict = {} + observes = self.preprocessor(observes) + for agent, obs in observes.items(): + agent_id = get_agent_id(agent) + agent_info = np.array( + [info[k][agent] for k in sort_str_num(info.keys())], + dtype=np.float32, + ) + new_obs = (obs, agent_info) if self._include_agent_info else obs + observations[agent_id] = new_obs + + return observations + + def _create_observations( + self, + obs: Dict[int, np.ndarray], + info: Dict[str, Dict[int, Any]], + dones: Dict[int, bool], + ) -> Observation: + """Convert observation.""" + observations_ = self._collate_obs_and_info(obs, info) + dones_ = {get_agent_id(k): v for k, v in dones.items()} + observations = self._convert_observations(observations_, dones_) + return observations + + def _obtain_preprocessor( + self, preprocessor: Any + ) -> Callable[[Dict[int, Any]], Dict[int, np.ndarray]]: + """Obtains the actual preprocessor. 
+ + Obtains the actual preprocessor to be used based on the supplied + preprocessor and the env's obs_builder object + """ + if not isinstance(self.obs_builder, GlobalObsForRailEnv): + _preprocessor = preprocessor if preprocessor else lambda x: x + if isinstance(self.obs_builder, TreeObsForRailEnv): + _preprocessor = ( + partial( + normalize_observation, tree_depth=self.obs_builder.max_depth + ) + if not preprocessor + else preprocessor + ) + assert _preprocessor is not None + else: + + def _preprocessor( + x: Tuple[np.ndarray, np.ndarray, np.ndarray] + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + return x + + def returned_preprocessor(obs: Dict[int, Any]) -> Dict[int, np.ndarray]: + """Return preprocessor.""" + temp_obs = {} + for agent_id, ob in obs.items(): + temp_obs[agent_id] = _preprocessor(ob) + return temp_obs + + return returned_preprocessor + + # set all parameters that should be available before an environment step + # if no available agent, then environment is done and should be reset + def _pre_step(self) -> None: + """Pre-step.""" + if not self.agents: + self._step_type = dm_env.StepType.LAST + + def observation_spec(self) -> Dict[str, OLT]: + """Return observation spec.""" + observation_specs = {} + for agent in self.agents: + # Legal actions + action_spec = _convert_to_spec(self.action_spaces[agent]) + legals = np.ones(shape=action_spec.num_values, dtype=action_spec.dtype) + + observation_specs[agent] = OLT( + observation=tuple( + ( + _convert_to_spec(self.observation_spaces[agent]), + agent_info_spec(), + ) ) + if self._include_agent_info + else _convert_to_spec(self.observation_spaces[agent]), + legal_actions=legals, + terminal=specs.Array((1,), np.float32), ) - if self._include_agent_info - else _convert_to_spec(self.observation_spaces[agent]), - legal_actions=legals, - terminal=specs.Array((1,), np.float32), + return observation_specs + + def action_spec( + self, + ) -> Dict[str, Union[specs.DiscreteArray, specs.BoundedArray]]: + """Get action spec.""" + action_specs = {} + action_spaces = self.action_spaces + for agent in self.possible_agents: + action_specs[agent] = _convert_to_spec(action_spaces[agent]) + return action_specs + + def reward_spec(self) -> Dict[str, specs.Array]: + """Get the reward spec.""" + reward_specs = {} + for agent in self.possible_agents: + reward_specs[agent] = specs.Array((), np.float32) + return reward_specs + + def discount_spec(self) -> Dict[str, specs.BoundedArray]: + """Get the discount spec.""" + discount_specs = {} + for agent in self.possible_agents: + discount_specs[agent] = specs.BoundedArray( + (), np.float32, minimum=0, maximum=1.0 + ) + return discount_specs + + def extra_spec(self) -> Dict[str, specs.BoundedArray]: + """Get the extras spec.""" + return {} + + def seed(self, seed: int = None) -> None: + """Seed the environment.""" + self._environment._seed(seed) + + @property + def environment(self) -> RailEnv: + """Returns the wrapped environment.""" + return self._environment + + @property + def num_agents(self) -> int: + """Returns the number of trains/agents in the flatland environment""" + print(self._environment.number_of_agents) + return int(self._environment.number_of_agents) + + def __getattr__(self, name: str) -> Any: + """Expose any other attributes of the underlying environment.""" + return getattr(self._environment, name) + + # Utility functions + + def infer_observation_space( + obs: Union[tuple, np.ndarray, dict] + ) -> Union[Box, tuple, dict]: + """Infer a gym Observation space from a sample observation from 
flatland""" + if isinstance(obs, np.ndarray): + return Box( + -np.inf, + np.inf, + shape=obs.shape, + dtype=obs.dtype, ) - return observation_specs - - def action_spec(self) -> Dict[str, Union[specs.DiscreteArray, specs.BoundedArray]]: - """Get action spec.""" - action_specs = {} - action_spaces = self.action_spaces - for agent in self.possible_agents: - action_specs[agent] = _convert_to_spec(action_spaces[agent]) - return action_specs - - def reward_spec(self) -> Dict[str, specs.Array]: - """Get the reward spec.""" - reward_specs = {} - for agent in self.possible_agents: - reward_specs[agent] = specs.Array((), np.float32) - return reward_specs - - def discount_spec(self) -> Dict[str, specs.BoundedArray]: - """Get the discount spec.""" - discount_specs = {} - for agent in self.possible_agents: - discount_specs[agent] = specs.BoundedArray( - (), np.float32, minimum=0, maximum=1.0 + elif isinstance(obs, tuple): + return tuple(infer_observation_space(o) for o in obs) + elif isinstance(obs, dict): + return {key: infer_observation_space(value) for key, value in obs.items()} + else: + raise ValueError( + f"Unexpected observation type: {type(obs)}. " + f"Observation should be of either of this types " + f"(np.ndarray, tuple, or dict)" ) - return discount_specs - def extra_spec(self) -> Dict[str, specs.BoundedArray]: - """Get the extras spec.""" - return {} + def agent_info_spec() -> specs.BoundedArray: + """Create the spec for the agent_info part of the observation""" + return specs.BoundedArray((4,), dtype=np.float32, minimum=0.0, maximum=10) - def seed(self, seed: int = None) -> None: - """Seed the environment.""" - self._environment._seed(seed) + def get_agent_id(handle: int) -> str: + """Obtain the string that constitutes the agent id from an agent handle""" + return f"train_{handle}" - @property - def environment(self) -> RailEnv: - """Returns the wrapped environment.""" - return self._environment + def get_agent_handle(id: str) -> int: + """Obtain an agents handle given its id""" + return int(id.split("_")[-1]) - @property - def num_agents(self) -> int: - """Returns the number of trains/agents in the flatland environment""" - print(self._environment.number_of_agents) - return int(self._environment.number_of_agents) + def decorate_step_method(env: RailEnv) -> None: + """Step method decorator. - def __getattr__(self, name: str) -> Any: - """Expose any other attributes of the underlying environment.""" - return getattr(self._environment, name) + Enable the step method of the env to take action dictionaries where + agent keys are the agent ids. Flatland uses the agent handles as + keys instead. This function decorates the step method so that it + accepts an action dict where the keys are the agent ids. + """ + env.step_ = env.step + def _step( + self: RailEnv, actions: Dict[str, Union[int, float, Any]] + ) -> dm_env.TimeStep: + actions_ = {get_agent_handle(k): int(v) for k, v in actions.items()} + return self.step_(actions_) -# Utility functions + env.step = tp.MethodType(_step, env) + # The block of code below is obtained from the flatland starter-kit + # at https://gitlab.aicrowd.com/flatland/flatland-starter-kit/-/blob/master/ + # utils/observation_utils.py + # this is done just to obtain the normalize_observation function that would + # serve as the default preprocessor for the Tree obs builder. 
-def infer_observation_space( - obs: Union[tuple, np.ndarray, dict] -) -> Union[Box, tuple, dict]: - """Infer a gym Observation space from a sample observation from flatland""" - if isinstance(obs, np.ndarray): - return Box( - -np.inf, - np.inf, - shape=obs.shape, - dtype=obs.dtype, - ) - elif isinstance(obs, tuple): - return tuple(infer_observation_space(o) for o in obs) - elif isinstance(obs, dict): - return {key: infer_observation_space(value) for key, value in obs.items()} - else: - raise ValueError( - f"Unexpected observation type: {type(obs)}. " - f"Observation should be of either of this types " - f"(np.ndarray, tuple, or dict)" - ) + def max_lt(seq: np.ndarray, val: Any) -> Any: + """Get max in sequence. + Return greatest item in seq for which item < val applies. + None is returned if seq was empty or all items in seq were >= val. + """ + max = 0 + idx = len(seq) - 1 + while idx >= 0: + if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max: + max = seq[idx] + idx -= 1 + return max + + def min_gt(seq: np.ndarray, val: Any) -> Any: + """Gets min in a sequence. + + Return smallest item in seq for which item > val applies. + None is returned if seq was empty or all items in seq were >= val. + """ + min = np.inf + idx = len(seq) - 1 + while idx >= 0: + if seq[idx] >= val and seq[idx] < min: + min = seq[idx] + idx -= 1 + return min + + def norm_obs_clip( + obs: np.ndarray, + clip_min: int = -1, + clip_max: int = 1, + fixed_radius: int = 0, + normalize_to_range: bool = False, + ) -> np.ndarray: + """Normalize observation. + + This function returns the difference between min and max value of an observation + :param obs: Observation that should be normalized + :param clip_min: min value where observation will be clipped + :param clip_max: max value where observation will be clipped + :return: returnes normalized and clipped observatoin + """ + if fixed_radius > 0: + max_obs = fixed_radius + else: + max_obs = max(1, max_lt(obs, 1000)) + 1 + + min_obs = 0 # min(max_obs, min_gt(obs, 0)) + if normalize_to_range: + min_obs = min_gt(obs, 0) + if min_obs > max_obs: + min_obs = max_obs + if max_obs == min_obs: + return np.clip(np.array(obs) / max_obs, clip_min, clip_max) + norm = np.abs(max_obs - min_obs) + return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max) + + def _split_node_into_feature_groups( + node: Node, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Splits node into features.""" + data = np.zeros(6) + distance = np.zeros(1) + agent_data = np.zeros(4) + + data[0] = node.dist_own_target_encountered + data[1] = node.dist_other_target_encountered + data[2] = node.dist_other_agent_encountered + data[3] = node.dist_potential_conflict + data[4] = node.dist_unusable_switch + data[5] = node.dist_to_next_branch + + distance[0] = node.dist_min_to_target + + agent_data[0] = node.num_agents_same_direction + agent_data[1] = node.num_agents_opposite_direction + agent_data[2] = node.num_agents_malfunctioning + agent_data[3] = node.speed_min_fractional -def agent_info_spec() -> specs.BoundedArray: - """Create the spec for the agent_info part of the observation""" - return specs.BoundedArray((4,), dtype=np.float32, minimum=0.0, maximum=10) - - -def get_agent_id(handle: int) -> str: - """Obtain the string that constitutes the agent id from an agent handle - an int""" - return f"train_{handle}" - - -def get_agent_handle(id: str) -> int: - """Obtain an agents handle given its id""" - return int(id.split("_")[-1]) - - -def decorate_step_method(env: RailEnv) -> None: - """Step method 
decorator. - - Enable the step method of the env to take action dictionaries where agent keys - are the agent ids. Flatland uses the agent handles as keys instead. This function - decorates the step method so that it accepts an action dict where the keys are the - agent ids. - """ - env.step_ = env.step - - def _step( - self: RailEnv, actions: Dict[str, Union[int, float, Any]] - ) -> dm_env.TimeStep: - actions_ = {get_agent_handle(k): int(v) for k, v in actions.items()} - return self.step_(actions_) - - env.step = tp.MethodType(_step, env) - - -# The block of code below is obtained from the flatland starter-kit -# at https://gitlab.aicrowd.com/flatland/flatland-starter-kit/-/blob/master/ -# utils/observation_utils.py -# this is done just to obtain the normalize_observation function that would -# serve as the default preprocessor for the Tree obs builder. - - -def max_lt(seq: np.ndarray, val: Any) -> Any: - """Get max in sequence. - - Return greatest item in seq for which item < val applies. - None is returned if seq was empty or all items in seq were >= val. - """ - max = 0 - idx = len(seq) - 1 - while idx >= 0: - if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max: - max = seq[idx] - idx -= 1 - return max - - -def min_gt(seq: np.ndarray, val: Any) -> Any: - """Gets min in a sequence. - - Return smallest item in seq for which item > val applies. - None is returned if seq was empty or all items in seq were >= val. - """ - min = np.inf - idx = len(seq) - 1 - while idx >= 0: - if seq[idx] >= val and seq[idx] < min: - min = seq[idx] - idx -= 1 - return min - - -def norm_obs_clip( - obs: np.ndarray, - clip_min: int = -1, - clip_max: int = 1, - fixed_radius: int = 0, - normalize_to_range: bool = False, -) -> np.ndarray: - """Normalize observation. - - This function returns the difference between min and max value of an observation - :param obs: Observation that should be normalized - :param clip_min: min value where observation will be clipped - :param clip_max: max value where observation will be clipped - :return: returnes normalized and clipped observatoin - """ - if fixed_radius > 0: - max_obs = fixed_radius - else: - max_obs = max(1, max_lt(obs, 1000)) + 1 - - min_obs = 0 # min(max_obs, min_gt(obs, 0)) - if normalize_to_range: - min_obs = min_gt(obs, 0) - if min_obs > max_obs: - min_obs = max_obs - if max_obs == min_obs: - return np.clip(np.array(obs) / max_obs, clip_min, clip_max) - norm = np.abs(max_obs - min_obs) - return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max) - - -def _split_node_into_feature_groups( - node: Node, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Splits node into features.""" - data = np.zeros(6) - distance = np.zeros(1) - agent_data = np.zeros(4) - - data[0] = node.dist_own_target_encountered - data[1] = node.dist_other_target_encountered - data[2] = node.dist_other_agent_encountered - data[3] = node.dist_potential_conflict - data[4] = node.dist_unusable_switch - data[5] = node.dist_to_next_branch - - distance[0] = node.dist_min_to_target - - agent_data[0] = node.num_agents_same_direction - agent_data[1] = node.num_agents_opposite_direction - agent_data[2] = node.num_agents_malfunctioning - agent_data[3] = node.speed_min_fractional - - return data, distance, agent_data - - -def _split_subtree_into_feature_groups( - node: Node, current_tree_depth: int, max_tree_depth: int -) -> Tuple: - """Split subtree.""" - if node == -np.inf: - remaining_depth = max_tree_depth - current_tree_depth - # reference: - # 
https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure - num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1)) - return ( - [-np.inf] * num_remaining_nodes * 6, - [-np.inf] * num_remaining_nodes, - [-np.inf] * num_remaining_nodes * 4, - ) + return data, distance, agent_data - data, distance, agent_data = _split_node_into_feature_groups(node) + def _split_subtree_into_feature_groups( + node: Node, current_tree_depth: int, max_tree_depth: int + ) -> Tuple: + """Split subtree.""" + if node == -np.inf: + remaining_depth = max_tree_depth - current_tree_depth + # reference: + # https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure + num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1)) + return ( + [-np.inf] * num_remaining_nodes * 6, + [-np.inf] * num_remaining_nodes, + [-np.inf] * num_remaining_nodes * 4, + ) - if not node.childs: - return data, distance, agent_data + data, distance, agent_data = _split_node_into_feature_groups(node) - for direction in TreeObsForRailEnv.tree_explored_actions_char: - sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups( - node.childs[direction], current_tree_depth + 1, max_tree_depth - ) - data = np.concatenate((data, sub_data)) - distance = np.concatenate((distance, sub_distance)) - agent_data = np.concatenate((agent_data, sub_agent_data)) + if not node.childs: + return data, distance, agent_data + + for direction in TreeObsForRailEnv.tree_explored_actions_char: + sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups( + node.childs[direction], current_tree_depth + 1, max_tree_depth + ) + data = np.concatenate((data, sub_data)) + distance = np.concatenate((distance, sub_distance)) + agent_data = np.concatenate((agent_data, sub_agent_data)) + + return data, distance, agent_data - return data, distance, agent_data + def split_tree_into_feature_groups( + tree: Node, max_tree_depth: int + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """This function splits the tree into three difference arrays.""" + data, distance, agent_data = _split_node_into_feature_groups(tree) + for direction in TreeObsForRailEnv.tree_explored_actions_char: + sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups( + tree.childs[direction], 1, max_tree_depth + ) + data = np.concatenate((data, sub_data)) + distance = np.concatenate((distance, sub_distance)) + agent_data = np.concatenate((agent_data, sub_agent_data)) -def split_tree_into_feature_groups( - tree: Node, max_tree_depth: int -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """This function splits the tree into three difference arrays.""" - data, distance, agent_data = _split_node_into_feature_groups(tree) + return data, distance, agent_data - for direction in TreeObsForRailEnv.tree_explored_actions_char: - sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups( - tree.childs[direction], 1, max_tree_depth + def normalize_observation( + observation: Node, tree_depth: int, observation_radius: int = 0 + ) -> np.ndarray: + """This function normalizes the observation used by the RL algorithm.""" + if observation is None: + return np.zeros( + 11 * sum(np.power(4, i) for i in range(tree_depth + 1)), + dtype=np.float32, + ) + data, distance, agent_data = split_tree_into_feature_groups( + observation, tree_depth ) - data = np.concatenate((data, sub_data)) - distance = np.concatenate((distance, sub_distance)) - agent_data = np.concatenate((agent_data, 
sub_agent_data)) - return data, distance, agent_data + data = norm_obs_clip(data, fixed_radius=observation_radius) + distance = norm_obs_clip(distance, normalize_to_range=True) + agent_data = np.clip(agent_data, -1, 1) + normalized_obs = np.array( + np.concatenate((np.concatenate((data, distance)), agent_data)), + dtype=np.float32, + ) + return normalized_obs -def normalize_observation( - observation: Node, tree_depth: int, observation_radius: int = 0 -) -> np.ndarray: - """This function normalizes the observation used by the RL algorithm.""" - if observation is None: - return np.zeros( - 11 * sum(np.power(4, i) for i in range(tree_depth + 1)), dtype=np.float32 - ) - data, distance, agent_data = split_tree_into_feature_groups(observation, tree_depth) - - data = norm_obs_clip(data, fixed_radius=observation_radius) - distance = norm_obs_clip(distance, normalize_to_range=True) - agent_data = np.clip(agent_data, -1, 1) - normalized_obs = np.array( - np.concatenate((np.concatenate((data, distance)), agent_data)), dtype=np.float32 - ) - return normalized_obs +except ModuleNotFoundError: + # Incase users have not installed Flatland + pass diff --git a/mava/wrappers/smac.py b/mava/wrappers/smac.py index 3bca332f6..cecc0bd03 100644 --- a/mava/wrappers/smac.py +++ b/mava/wrappers/smac.py @@ -19,304 +19,312 @@ import dm_env import numpy as np from acme import specs -from smac.env import StarCraft2Env from mava import types from mava.utils.wrapper_utils import convert_np_type, parameterized_restart from mava.wrappers.env_wrappers import ParallelEnvWrapper +try: # noqa + from smac.env import StarCraft2Env + + class SMACWrapper(ParallelEnvWrapper): + """Environment wrapper for PettingZoo MARL environments.""" + + def __init__( + self, + environment: StarCraft2Env, + return_state_info: bool = True, + ): + """Constructor for parallel PZ wrapper. + + Args: + environment (ParallelEnv): parallel PZ env. + env_preprocess_wrappers (Optional[List], optional): Wrappers + that preprocess envs. + Format (env_preprocessor, dict_with_preprocessor_params). + """ + self._environment = environment + self._return_state_info = return_state_info + self._agents = [f"agent_{n}" for n in range(self._environment.n_agents)] -class SMACWrapper(ParallelEnvWrapper): - """Environment wrapper for PettingZoo MARL environments.""" - - def __init__( - self, - environment: StarCraft2Env, - return_state_info: bool = True, - ): - """Constructor for parallel PZ wrapper. - - Args: - environment (ParallelEnv): parallel PZ env. - env_preprocess_wrappers (Optional[List], optional): Wrappers - that preprocess envs. - Format (env_preprocessor, dict_with_preprocessor_params). - """ - self._environment = environment - self._return_state_info = return_state_info - self._agents = [f"agent_{n}" for n in range(self._environment.n_agents)] - - self._reset_next_step = True - self._done = False - - def reset(self) -> dm_env.TimeStep: - """Resets the env. - - Returns: - dm_env.TimeStep: dm timestep. 
- """ - # Reset the environment - self._environment.reset() - self._done = False - - self._reset_next_step = False - self._step_type = dm_env.StepType.FIRST - - # Get observation from env - observation = self.environment.get_obs() - legal_actions = self._get_legal_actions() - observations = self._convert_observations( - observation, legal_actions, self._done - ) - - # Set env discount to 1 for all agents - discount_spec = self.discount_spec() - self._discounts = { - agent: convert_np_type(discount_spec[agent].dtype, 1) - for agent in self._agents - } - - # Set reward to zero for all agents - rewards_spec = self.reward_spec() - rewards = { - agent: convert_np_type(rewards_spec[agent].dtype, 0) - for agent in self._agents - } - - # Possibly add state information to extras - if self._return_state_info: - state = self.get_state() - extras = {"s_t": state} - else: - extras = {} - - return parameterized_restart(rewards, self._discounts, observations), extras - - def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: - """Steps in env. - - Args: - actions (Dict[str, np.ndarray]): actions per agent. - - Returns: - dm_env.TimeStep: dm timestep - """ - # Possibly reset the environment - if self._reset_next_step: - return self.reset() - - # Convert dict of actions to list for SMAC - smac_actions = list(actions.values()) - - # Step the SMAC environment - reward, self._done, self._info = self._environment.step(smac_actions) - - # Get the next observations - next_observations = self._environment.get_obs() - legal_actions = self._get_legal_actions() - next_observations = self._convert_observations( - next_observations, legal_actions, self._done - ) - - # Convert team reward to agent-wise rewards - rewards = self._convert_reward(reward) - - # Possibly add state information to extras - if self._return_state_info: - state = self.get_state() - extras = {"s_t": state} - else: - extras = {} - - if self._done: - self._step_type = dm_env.StepType.LAST self._reset_next_step = True + self._done = False + + def reset(self) -> dm_env.TimeStep: + """Resets the env. + + Returns: + dm_env.TimeStep: dm timestep. + """ + # Reset the environment + self._environment.reset() + self._done = False + + self._reset_next_step = False + self._step_type = dm_env.StepType.FIRST + + # Get observation from env + observation = self.environment.get_obs() + legal_actions = self._get_legal_actions() + observations = self._convert_observations( + observation, legal_actions, self._done + ) - # Discount on last timestep set to zero + # Set env discount to 1 for all agents + discount_spec = self.discount_spec() self._discounts = { - agent: convert_np_type(self.discount_spec()[agent].dtype, 0.0) + agent: convert_np_type(discount_spec[agent].dtype, 1) for agent in self._agents } - else: - self._step_type = dm_env.StepType.MID - - # Create timestep object - timestep = dm_env.TimeStep( - observation=next_observations, - reward=rewards, - discount=self._discounts, - step_type=self._step_type, - ) - - return timestep, extras - - def env_done(self) -> bool: - """Check if env is done. - - Returns: - bool: bool indicating if env is done. - """ - return self._done - - def _convert_reward(self, reward: float) -> Dict[str, float]: - """Convert rewards to be dm_env compatible. - - Args: - rewards: rewards per agent. 
- """ - rewards_spec = self.reward_spec() - rewards = {} - for agent in self._agents: - rewards[agent] = convert_np_type(rewards_spec[agent].dtype, reward) - return rewards - - def _get_legal_actions(self) -> List: - """Get legal actions from the environment.""" - legal_actions = [] - for i, _ in enumerate(self._agents): - legal_actions.append( - np.array(self._environment.get_avail_agent_actions(i), dtype="int") - ) - return legal_actions - - def _convert_observations( - self, observations: List, legal_actions: List, done: bool - ) -> types.Observation: - """Convert SMAC observation so it's dm_env compatible. - - Args: - observes (Dict[str, np.ndarray]): observations per agent. - dones (Dict[str, bool]): dones per agent. - - Returns: - types.Observation: dm compatible observations. - """ - olt_observations = {} - for i, agent in enumerate(self._agents): - - olt_observations[agent] = types.OLT( - observation=observations[i], - legal_actions=legal_actions[i], - terminal=np.asarray([done], dtype=np.float32), - ) - - return olt_observations - - def extra_spec(self) -> Dict[str, specs.BoundedArray]: - """Function returns extra spec (format) of the env. - Returns: - Dict[str, specs.BoundedArray]: extra spec. - """ - if self._return_state_info: - return {"s_t": self._environment.get_state()} - else: - return {} + # Set reward to zero for all agents + rewards_spec = self.reward_spec() + rewards = { + agent: convert_np_type(rewards_spec[agent].dtype, 0) + for agent in self._agents + } - def observation_spec(self) -> Dict[str, types.OLT]: - """Observation spec. + # Possibly add state information to extras + if self._return_state_info: + state = self.get_state() + extras = {"s_t": state} + else: + extras = {} - Returns: - types.Observation: spec for environment. - """ - self._environment.reset() + return parameterized_restart(rewards, self._discounts, observations), extras - observations = self._environment.get_obs() - legal_actions = self._get_legal_actions() + def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: + """Steps in env. - observation_specs = {} - for i, agent in enumerate(self._agents): + Args: + actions (Dict[str, np.ndarray]): actions per agent. - observation_specs[agent] = types.OLT( - observation=observations[i], - legal_actions=legal_actions[i], - terminal=np.asarray([True], dtype=np.float32), - ) + Returns: + dm_env.TimeStep: dm timestep + """ + # Possibly reset the environment + if self._reset_next_step: + return self.reset() - return observation_specs + # Convert dict of actions to list for SMAC + smac_actions = list(actions.values()) - def action_spec(self) -> Dict[str, Union[specs.DiscreteArray, specs.BoundedArray]]: - """Action spec. + # Step the SMAC environment + reward, self._done, self._info = self._environment.step(smac_actions) - Returns: - Dict[str, Union[specs.DiscreteArray, specs.BoundedArray]]: spec for actions. - """ - action_specs = {} - for agent in self._agents: - action_specs[agent] = specs.DiscreteArray( - num_values=self._environment.n_actions, dtype=int + # Get the next observations + next_observations = self._environment.get_obs() + legal_actions = self._get_legal_actions() + next_observations = self._convert_observations( + next_observations, legal_actions, self._done ) - return action_specs - - def reward_spec(self) -> Dict[str, specs.Array]: - """Reward spec. - - Returns: - Dict[str, specs.Array]: spec for rewards. 
- """ - reward_specs = {} - for agent in self._agents: - reward_specs[agent] = specs.Array((), np.float32) - return reward_specs - - def discount_spec(self) -> Dict[str, specs.BoundedArray]: - """Discount spec. - - Returns: - Dict[str, specs.BoundedArray]: spec for discounts. - """ - discount_specs = {} - for agent in self._agents: - discount_specs[agent] = specs.BoundedArray( - (), np.float32, minimum=0, maximum=1.0 + + # Convert team reward to agent-wise rewards + rewards = self._convert_reward(reward) + + # Possibly add state information to extras + if self._return_state_info: + state = self.get_state() + extras = {"s_t": state} + else: + extras = {} + + if self._done: + self._step_type = dm_env.StepType.LAST + self._reset_next_step = True + + # Discount on last timestep set to zero + self._discounts = { + agent: convert_np_type(self.discount_spec()[agent].dtype, 0.0) + for agent in self._agents + } + else: + self._step_type = dm_env.StepType.MID + + # Create timestep object + timestep = dm_env.TimeStep( + observation=next_observations, + reward=rewards, + discount=self._discounts, + step_type=self._step_type, ) - return discount_specs - - def get_stats(self) -> Optional[Dict]: - """Return extra stats to be logged. - - Returns: - extra stats to be logged. - """ - return self._environment.get_stats() - - @property - def agents(self) -> List: - """Agents still alive in env (not done). - - Returns: - List: alive agents in env. - """ - return self._agents - - @property - def possible_agents(self) -> List: - """All possible agents in env. - - Returns: - List: all possible agents in env. - """ - return self._agents - - @property - def environment(self) -> StarCraft2Env: - """Returns the wrapped environment. - - Returns: - ParallelEnv: parallel env. - """ - return self._environment - - def __getattr__(self, name: str) -> Any: - """Expose any other attributes of the underlying environment. - - Args: - name (str): attribute. - - Returns: - Any: return attribute from env or underlying env. - """ - if hasattr(self.__class__, name): - return self.__getattribute__(name) - else: - return getattr(self._environment, name) + + return timestep, extras + + def env_done(self) -> bool: + """Check if env is done. + + Returns: + bool: bool indicating if env is done. + """ + return self._done + + def _convert_reward(self, reward: float) -> Dict[str, float]: + """Convert rewards to be dm_env compatible. + + Args: + rewards: rewards per agent. + """ + rewards_spec = self.reward_spec() + rewards = {} + for agent in self._agents: + rewards[agent] = convert_np_type(rewards_spec[agent].dtype, reward) + return rewards + + def _get_legal_actions(self) -> List: + """Get legal actions from the environment.""" + legal_actions = [] + for i, _ in enumerate(self._agents): + legal_actions.append( + np.array(self._environment.get_avail_agent_actions(i), dtype="int") + ) + return legal_actions + + def _convert_observations( + self, observations: List, legal_actions: List, done: bool + ) -> types.Observation: + """Convert SMAC observation so it's dm_env compatible. + + Args: + observes (Dict[str, np.ndarray]): observations per agent. + dones (Dict[str, bool]): dones per agent. + + Returns: + types.Observation: dm compatible observations. 
+ """ + olt_observations = {} + for i, agent in enumerate(self._agents): + + olt_observations[agent] = types.OLT( + observation=observations[i], + legal_actions=legal_actions[i], + terminal=np.asarray([done], dtype=np.float32), + ) + + return olt_observations + + def extra_spec(self) -> Dict[str, specs.BoundedArray]: + """Function returns extra spec (format) of the env. + + Returns: + Dict[str, specs.BoundedArray]: extra spec. + """ + if self._return_state_info: + return {"s_t": self._environment.get_state()} + else: + return {} + + def observation_spec(self) -> Dict[str, types.OLT]: + """Observation spec. + + Returns: + types.Observation: spec for environment. + """ + self._environment.reset() + + observations = self._environment.get_obs() + legal_actions = self._get_legal_actions() + + observation_specs = {} + for i, agent in enumerate(self._agents): + + observation_specs[agent] = types.OLT( + observation=observations[i], + legal_actions=legal_actions[i], + terminal=np.asarray([True], dtype=np.float32), + ) + + return observation_specs + + def action_spec( + self, + ) -> Dict[str, Union[specs.DiscreteArray, specs.BoundedArray]]: + """Action spec. + + Returns: + spec for actions. + """ + action_specs = {} + for agent in self._agents: + action_specs[agent] = specs.DiscreteArray( + num_values=self._environment.n_actions, dtype=int + ) + return action_specs + + def reward_spec(self) -> Dict[str, specs.Array]: + """Reward spec. + + Returns: + Dict[str, specs.Array]: spec for rewards. + """ + reward_specs = {} + for agent in self._agents: + reward_specs[agent] = specs.Array((), np.float32) + return reward_specs + + def discount_spec(self) -> Dict[str, specs.BoundedArray]: + """Discount spec. + + Returns: + Dict[str, specs.BoundedArray]: spec for discounts. + """ + discount_specs = {} + for agent in self._agents: + discount_specs[agent] = specs.BoundedArray( + (), np.float32, minimum=0, maximum=1.0 + ) + return discount_specs + + def get_stats(self) -> Optional[Dict]: + """Return extra stats to be logged. + + Returns: + extra stats to be logged. + """ + return self._environment.get_stats() + + @property + def agents(self) -> List: + """Agents still alive in env (not done). + + Returns: + List: alive agents in env. + """ + return self._agents + + @property + def possible_agents(self) -> List: + """All possible agents in env. + + Returns: + List: all possible agents in env. + """ + return self._agents + + @property + def environment(self) -> StarCraft2Env: + """Returns the wrapped environment. + + Returns: + ParallelEnv: parallel env. + """ + return self._environment + + def __getattr__(self, name: str) -> Any: + """Expose any other attributes of the underlying environment. + + Args: + name (str): attribute. + + Returns: + Any: return attribute from env or underlying env. + """ + if hasattr(self.__class__, name): + return self.__getattribute__(name) + else: + return getattr(self._environment, name) + + +except ModuleNotFoundError: + # Incase users have not installed SMAC + pass From afb5eabcd6e97fb3921eea61f4fd523a56161559 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 2 Feb 2022 12:10:40 +0200 Subject: [PATCH 42/56] Small typo in RAEDME. 
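For reference, a rough usage sketch of the SMACWrapper added earlier in this series
(assuming the `smac` package and StarCraft II with the `3m` map are installed; the
random-action loop below is purely illustrative):

    import numpy as np
    from smac.env import StarCraft2Env
    from mava.wrappers.smac import SMACWrapper

    env = SMACWrapper(StarCraft2Env(map_name="3m"), return_state_info=True)
    timestep, extras = env.reset()  # extras carries the global state under "s_t"
    while not env.env_done():
        # Pick a random legal action for every agent from its legal action mask.
        actions = {
            agent: np.random.choice(
                np.nonzero(timestep.observation[agent].legal_actions)[0]
            )
            for agent in env.agents
        }
        timestep, extras = env.step(actions)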
--- mava/systems/tf/value_decomposition/README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mava/systems/tf/value_decomposition/README.md b/mava/systems/tf/value_decomposition/README.md index 87df69eb9..e2886b50e 100644 --- a/mava/systems/tf/value_decomposition/README.md +++ b/mava/systems/tf/value_decomposition/README.md @@ -1,7 +1,8 @@ # Value Decomposition Methods eg. VDN and QMIX -This system supports to important Value Decomposition methods, VDN and QMIX. +This system supports two important Value Decomposition methods, VDN and QMIX. The design of the system also allows the user to easily include their own mixer in place of the two supported ones. -[Sunehag et al., 2017]: https://arxiv.org/abs/1706.05296 -[Rashid et al., 2018]: https://arxiv.org/abs/1803.11485 +VDN, (Sunehag et al., 2017), https://arxiv.org/abs/1706.05296 + +QMIX, (Rashid et al., 2018), https://arxiv.org/abs/1803.11485 From 18c156a9c7f18de8bfcb757ffb7afa403799a6d8 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 2 Feb 2022 12:15:45 +0200 Subject: [PATCH 43/56] Fix mixer docstrings. --- mava/components/tf/modules/mixing/mixers.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/mava/components/tf/modules/mixing/mixers.py b/mava/components/tf/modules/mixing/mixers.py index 7039e3f7c..21e11f456 100644 --- a/mava/components/tf/modules/mixing/mixers.py +++ b/mava/components/tf/modules/mixing/mixers.py @@ -10,9 +10,11 @@ class BaseMixer(snt.Module): """ def __init__(self) -> None: + """Initialise base mixer.""" super().__init__() def __call__(self, agent_qs: tf.Tensor, states: tf.Tensor) -> tf.Tensor: + """Call method.""" return agent_qs @@ -20,19 +22,13 @@ class VDN(BaseMixer): """VDN mixing network.""" def __init__(self) -> None: + """Initialise VDN mixer.""" super().__init__() def __call__(self, agent_qs: tf.Tensor, states: tf.Tensor) -> tf.Tensor: + """Call method.""" return tf.reduce_sum(agent_qs, axis=-1, keepdims=True) - """Initialize VDN class - Args: - agent_qs: Tensor containing the q-values of actions chosen by agents - states: Tensor containing global environment state. - Returns: - Tensor with total q-value. - """ - class QMIX(BaseMixer): """QMIX mixing network.""" @@ -45,8 +41,9 @@ def __init__( Args: num_agents: Number of agents in the enviroment state_dim: Dimensions of the global environment state - embed_dim: TODO (Ruan): Cluade please add - hypernet_embed: TODO (Ruan): Claude Please add + embed_dim: The dimension of the output of the first layer + of the mixer. + hypernet_embed: Number of units in the hyper network """ super().__init__() @@ -73,6 +70,7 @@ def __init__( self.V = snt.Sequential([snt.Linear(self.embed_dim), tf.nn.relu, snt.Linear(1)]) def __call__(self, agent_qs: tf.Tensor, states: tf.Tensor) -> tf.Tensor: + """Call method.""" bs = agent_qs.shape[1] state_dim = states.shape[-1] From f59545ca76b0ed47ccbb4c050192433e8a722df8 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 2 Feb 2022 12:43:22 +0200 Subject: [PATCH 44/56] Fix import error in test. 
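On the mixer interface touched by the two patches above: as the value decomposition
README notes, a user-supplied mixer can stand in for VDN or QMIX as long as it
follows the `BaseMixer` contract, i.e. a Sonnet module mapping per-agent Q-values
(plus the global state) to a joint Q-value. A minimal sketch under that assumption
(the weighting scheme is purely illustrative and not shipped with Mava):

    import tensorflow as tf

    from mava.components.tf.modules.mixing.mixers import BaseMixer

    class WeightedSumMixer(BaseMixer):
        """Illustrative mixer: one learnable positive weight per agent."""

        def __init__(self, num_agents: int) -> None:
            """Initialise the mixer with a logit per agent."""
            super().__init__()
            self._logits = tf.Variable(tf.zeros(num_agents), name="mixer_logits")

        def __call__(self, agent_qs: tf.Tensor, states: tf.Tensor) -> tf.Tensor:
            """Mix per-agent Q-values; `states` is unused by this simple mixer."""
            # Softplus keeps the weights positive, preserving monotonic mixing.
            weights = tf.nn.softplus(self._logits)
            return tf.reduce_sum(agent_qs * weights, axis=-1, keepdims=True)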
--- tests/conftest.py | 39 +++++++++++++++++---------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 459e6957c..78e5265da 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,13 +23,9 @@ import numpy.testing as npt import pytest -try: - from mava.utils.environments import flatland_utils - from mava.wrappers.flatland import FlatlandEnvWrapper +from mava.utils.environments import flatland_utils +from mava.wrappers.flatland import FlatlandEnvWrapper - _has_flatland = True -except ModuleNotFoundError: - _has_flatland = False try: from pettingzoo.utils.env import AECEnv, ParallelEnv except ModuleNotFoundError: @@ -57,22 +53,21 @@ SequentialMADiscreteEnvironment, ) -if _has_flatland: - # flatland environment config - flatland_env_config = { - "n_agents": 3, - "x_dim": 30, - "y_dim": 30, - "n_cities": 2, - "max_rails_between_cities": 2, - "max_rails_in_city": 3, - "seed": 0, - "malfunction_rate": 1 / 200, - "malfunction_min_duration": 20, - "malfunction_max_duration": 50, - "observation_max_path_depth": 30, - "observation_tree_depth": 2, - } +# flatland environment config +flatland_env_config = { + "n_agents": 3, + "x_dim": 30, + "y_dim": 30, + "n_cities": 2, + "max_rails_between_cities": 2, + "max_rails_in_city": 3, + "seed": 0, + "malfunction_rate": 1 / 200, + "malfunction_min_duration": 20, + "malfunction_max_duration": 50, + "observation_max_path_depth": 30, + "observation_tree_depth": 2, +} """ Helpers contains re-usable test functions. From 7dd5c206de2f9db0be26902bf9ac1672ebb569cf Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 2 Feb 2022 12:51:00 +0200 Subject: [PATCH 45/56] More import fixes for flatland and smac. --- mava/wrappers/flatland.py | 1129 +++++++++++++++++++------------------ mava/wrappers/smac.py | 569 +++++++++---------- tests/conftest.py | 7 +- 3 files changed, 852 insertions(+), 853 deletions(-) diff --git a/mava/wrappers/flatland.py b/mava/wrappers/flatland.py index ddd462269..851f4a58d 100644 --- a/mava/wrappers/flatland.py +++ b/mava/wrappers/flatland.py @@ -22,6 +22,10 @@ import numpy as np from acme import specs from acme.wrappers.gym_wrapper import _convert_to_spec +from flatland.envs.observations import GlobalObsForRailEnv, Node, TreeObsForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.step_utils.states import TrainState +from flatland.utils.rendertools import AgentRenderVariant, RenderTool from gym.spaces import Discrete from gym.spaces.box import Box @@ -34,612 +38,609 @@ ) from mava.wrappers.env_wrappers import ParallelEnvWrapper -try: # noqa - from flatland.envs.observations import GlobalObsForRailEnv, Node, TreeObsForRailEnv - from flatland.envs.rail_env import RailEnv - from flatland.envs.step_utils.states import TrainState - from flatland.utils.rendertools import AgentRenderVariant, RenderTool - - class FlatlandEnvWrapper(ParallelEnvWrapper): - """Environment wrapper for Flatland environments. - - All environments would require an observation preprocessor, except for - 'GlobalObsForRailEnv'. This is because flatland gives users the - flexibility of designing custom observation builders. 'TreeObsForRailEnv' - would use the normalize_observation function from the flatland baselines - if none is supplied. - The supplied preprocessor should return either an array, tuple of arrays or - a dictionary of arrays for an observation input. 
- The obervation, for an agent, returned by this wrapper could consist of both - the agent observation and agent info. This is because flatland also provides - informationn about the agents at each step. This information include; - 'action_required', 'malfunction', 'speed', and 'status', and it can be appended - to the observation, by this wrapper, as an array. action_required is a boolean, - malfunction is an int denoting the number of steps for which the agent would - remain motionless, speed is a float and status can be any of the below; - READY_TO_DEPART = 0 - ACTIVE = 1 - DONE = 2 - DONE_REMOVED = 3 - This would be included in the observation if agent_info is set to True - """ - - # Note: we don't inherit from base.EnvironmentWrapper because that class - # assumes that the wrapped environment is a dm_env.Environment. - def __init__( - self, - environment: RailEnv, - preprocessor: Callable[ - [Any], Union[np.ndarray, Tuple[np.ndarray], Dict[str, np.ndarray]] - ] = None, - agent_info: bool = False, - ): - """Wrap Flatland environment. - - Args: - environment: underlying RailEnv - preprocessor: optional preprocessor. Defaults to None. - agent_info: include agent info. Defaults to True. - """ - self._environment = environment - decorate_step_method(self._environment) - - self._agents = [get_agent_id(i) for i in range(self.num_agents)] - self._possible_agents = self.agents[:] - self._reset_next_step = True - self._step_type = dm_env.StepType.FIRST - self.num_actions = 5 +class FlatlandEnvWrapper(ParallelEnvWrapper): + """Environment wrapper for Flatland environments. + + All environments would require an observation preprocessor, except for + 'GlobalObsForRailEnv'. This is because flatland gives users the + flexibility of designing custom observation builders. 'TreeObsForRailEnv' + would use the normalize_observation function from the flatland baselines + if none is supplied. + The supplied preprocessor should return either an array, tuple of arrays or + a dictionary of arrays for an observation input. + The obervation, for an agent, returned by this wrapper could consist of both + the agent observation and agent info. This is because flatland also provides + informationn about the agents at each step. This information include; + 'action_required', 'malfunction', 'speed', and 'status', and it can be appended + to the observation, by this wrapper, as an array. action_required is a boolean, + malfunction is an int denoting the number of steps for which the agent would + remain motionless, speed is a float and status can be any of the below; + READY_TO_DEPART = 0 + ACTIVE = 1 + DONE = 2 + DONE_REMOVED = 3 + This would be included in the observation if agent_info is set to True + """ + + # Note: we don't inherit from base.EnvironmentWrapper because that class + # assumes that the wrapped environment is a dm_env.Environment. + def __init__( + self, + environment: RailEnv, + preprocessor: Callable[ + [Any], Union[np.ndarray, Tuple[np.ndarray], Dict[str, np.ndarray]] + ] = None, + agent_info: bool = False, + ): + """Wrap Flatland environment. + + Args: + environment: underlying RailEnv + preprocessor: optional preprocessor. Defaults to None. + agent_info: include agent info. Defaults to True. 
+ """ + self._environment = environment + decorate_step_method(self._environment) + + self._agents = [get_agent_id(i) for i in range(self.num_agents)] + self._possible_agents = self.agents[:] + + self._reset_next_step = True + self._step_type = dm_env.StepType.FIRST + self.num_actions = 5 + + self.action_spaces = { + agent: Discrete(self.num_actions) for agent in self.possible_agents + } + + # preprocessor must be for observation builders other than global obs + # treeobs builders would use the default preprocessor if none is + # supplied + self.preprocessor: Callable[ + [Dict[int, Any]], Dict[int, Any] + ] = self._obtain_preprocessor(preprocessor) + + self._include_agent_info = agent_info + + # observation space: + # flatland defines no observation space for an agent. Here we try + # to define the observation space. All agents are identical and would + # have the same observation space. + # Infer observation space based on returned observation + obs, _ = self._environment.reset() + obs = self.preprocessor(obs) + self.observation_spaces = { + get_agent_id(i): infer_observation_space(ob) for i, ob in obs.items() + } + + self._env_renderer = RenderTool( + self._environment, + agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, + show_debug=False, + screen_height=800, # Adjust these parameters to fit your resolution + screen_width=800, + ) # Adjust these parameters to fit your resolution + + @property + def agents(self) -> List[str]: + """Return list of active agents.""" + return self._agents + + @property + def possible_agents(self) -> List[str]: + """Return list of all possible agents.""" + return self._possible_agents + + def _update_stats(self, info: Dict, rewards: Dict) -> None: + """Update flatland stats.""" + episode_return = sum(list(rewards.values())) + tasks_finished = sum( + [1 if state == TrainState.DONE else 0 for state in info["state"].values()] + ) + completion = tasks_finished / len(self._agents) + normalized_score = episode_return / ( + self._environment._max_episode_steps * len(self._agents) + ) - self.action_spaces = { - agent: Discrete(self.num_actions) for agent in self.possible_agents - } + self._latest_score = normalized_score + self._latest_completion = completion - # preprocessor must be for observation builders other than global obs - # treeobs builders would use the default preprocessor if none is - # supplied - self.preprocessor: Callable[ - [Dict[int, Any]], Dict[int, Any] - ] = self._obtain_preprocessor(preprocessor) - - self._include_agent_info = agent_info - - # observation space: - # flatland defines no observation space for an agent. Here we try - # to define the observation space. All agents are identical and would - # have the same observation space. 
- # Infer observation space based on returned observation - obs, _ = self._environment.reset() - obs = self.preprocessor(obs) - self.observation_spaces = { - get_agent_id(i): infer_observation_space(ob) for i, ob in obs.items() + def get_stats(self) -> Dict: + """Get flatland specific stats.""" + if self._latest_completion is not None and self._latest_score is not None: + return { + "score": self._latest_score, + "completion": self._latest_completion, } + else: + return {} - self._env_renderer = RenderTool( - self._environment, - agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, - show_debug=False, - screen_height=800, # Adjust these parameters to fit your resolution - screen_width=800, - ) # Adjust these parameters to fit your resolution - - @property - def agents(self) -> List[str]: - """Return list of active agents.""" - return self._agents - - @property - def possible_agents(self) -> List[str]: - """Return list of all possible agents.""" - return self._possible_agents - - def _update_stats(self, info: Dict, rewards: Dict) -> None: - """Update flatland stats.""" - episode_return = sum(list(rewards.values())) - tasks_finished = sum( - [ - 1 if state == TrainState.DONE else 0 - for state in info["state"].values() - ] - ) - completion = tasks_finished / len(self._agents) - normalized_score = episode_return / ( - self._environment._max_episode_steps * len(self._agents) - ) - - self._latest_score = normalized_score - self._latest_completion = completion - - def get_stats(self) -> Dict: - """Get flatland specific stats.""" - if self._latest_completion is not None and self._latest_score is not None: - return { - "score": self._latest_score, - "completion": self._latest_completion, - } - else: - return {} - - def render(self, mode: str = "human") -> np.ndarray: - """Renders the environment.""" - if mode == "human": - show = True - else: - show = False - - return self._env_renderer.render_env( - show=show, - show_observations=False, - show_predictions=False, - return_image=True, - ) - - def env_done(self) -> bool: - """Checks if the environment is done.""" - return self._environment.dones["__all__"] or not self.agents + def render(self, mode: str = "human") -> np.ndarray: + """Renders the environment.""" + if mode == "human": + show = True + else: + show = False - def reset(self) -> dm_env.TimeStep: - """Resets the episode.""" - # Reset the rendering sytem - self._env_renderer.reset() + return self._env_renderer.render_env( + show=show, + show_observations=False, + show_predictions=False, + return_image=True, + ) - self._reset_next_step = False - self._agents = self.possible_agents[:] - self._discounts = { - agent: np.dtype("float32").type(1.0) for agent in self.agents - } - observe, info = self._environment.reset() - observations = self._create_observations( - observe, info, self._environment.dones - ) - rewards_spec = self.reward_spec() + def env_done(self) -> bool: + """Checks if the environment is done.""" + return self._environment.dones["__all__"] or not self.agents + + def reset(self) -> dm_env.TimeStep: + """Resets the episode.""" + # Reset the rendering sytem + self._env_renderer.reset() + + self._reset_next_step = False + self._agents = self.possible_agents[:] + self._discounts = { + agent: np.dtype("float32").type(1.0) for agent in self.agents + } + observe, info = self._environment.reset() + observations = self._create_observations(observe, info, self._environment.dones) + rewards_spec = self.reward_spec() + rewards = { + agent: convert_np_type(rewards_spec[agent].dtype, 0) + 
for agent in self.possible_agents + } + + discount_spec = self.discount_spec() + discounts = { + agent: convert_np_type(discount_spec[agent].dtype, 1) + for agent in self.possible_agents + } + return parameterized_restart(rewards, discounts, observations), {} + + def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: + """Steps the environment.""" + self._pre_step() + + if self._reset_next_step: + return self.reset() + + self._agents = [ + agent + for agent in self.agents + if not self._environment.dones[get_agent_handle(agent)] + ] + + observations, rewards, dones, infos = self._environment.step(actions) + + rewards_spec = self.reward_spec() + # Handle empty rewards + if not rewards: rewards = { agent: convert_np_type(rewards_spec[agent].dtype, 0) for agent in self.possible_agents } + else: + rewards = { + get_agent_id(agent): convert_np_type( + rewards_spec[get_agent_id(agent)].dtype, reward + ) + for agent, reward in rewards.items() + } - discount_spec = self.discount_spec() + if observations: + observations = self._create_observations(observations, infos, dones) + + if self.env_done(): + self._step_type = dm_env.StepType.LAST + self._reset_next_step = True discounts = { - agent: convert_np_type(discount_spec[agent].dtype, 1) + agent: convert_np_type( + self.discount_spec()[agent].dtype, 0 + ) # Zero discount on final step + for agent in self.possible_agents + } + self._update_stats(infos, rewards) + # TODO (Claude) zero discount! + else: + self._step_type = dm_env.StepType.MID + discounts = { + agent: convert_np_type( + self.discount_spec()[agent].dtype, 1 + ) # discount = 1 for agent in self.possible_agents } - return parameterized_restart(rewards, discounts, observations), {} - - def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: - """Steps the environment.""" - self._pre_step() - - if self._reset_next_step: - return self.reset() - - self._agents = [ - agent - for agent in self.agents - if not self._environment.dones[get_agent_handle(agent)] - ] - - observations, rewards, dones, infos = self._environment.step(actions) - - rewards_spec = self.reward_spec() - # Handle empty rewards - if not rewards: - rewards = { - agent: convert_np_type(rewards_spec[agent].dtype, 0) - for agent in self.possible_agents - } - else: - rewards = { - get_agent_id(agent): convert_np_type( - rewards_spec[get_agent_id(agent)].dtype, reward - ) - for agent, reward in rewards.items() - } - - if observations: - observations = self._create_observations(observations, infos, dones) - - if self.env_done(): - self._step_type = dm_env.StepType.LAST - self._reset_next_step = True - discounts = { - agent: convert_np_type( - self.discount_spec()[agent].dtype, 0 - ) # Zero discount on final step - for agent in self.possible_agents - } - self._update_stats(infos, rewards) - # TODO (Claude) zero discount! - else: - self._step_type = dm_env.StepType.MID - discounts = { - agent: convert_np_type( - self.discount_spec()[agent].dtype, 1 - ) # discount = 1 - for agent in self.possible_agents - } - - return ( - dm_env.TimeStep( - observation=observations, - reward=rewards, - discount=discounts, - step_type=self._step_type, - ), - {}, - ) - # Convert Flatland observation so it's dm_env compatible. Also, the list - # of legal actions must be converted to a legal actions mask. 
- def _convert_observations( - self, - observes: Dict[str, Tuple[np.ndarray, np.ndarray]], - dones: Dict[str, bool], - ) -> Observation: - """Convert observation""" - return convert_dm_compatible_observations( - observes, - dones, - self.observation_spec(), - self.env_done(), - self.possible_agents, - ) + return ( + dm_env.TimeStep( + observation=observations, + reward=rewards, + discount=discounts, + step_type=self._step_type, + ), + {}, + ) - # collate agent info and observation into a tuple, - # making the agents obervation to - # be a tuple of the observation from the env and the agent info - def _collate_obs_and_info( - self, observes: Dict[int, np.ndarray], info: Dict[str, Dict[int, Any]] - ) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: - """Combine observation and info.""" - observations: Dict = {} - observes = self.preprocessor(observes) - for agent, obs in observes.items(): - agent_id = get_agent_id(agent) - agent_info = np.array( - [info[k][agent] for k in sort_str_num(info.keys())], - dtype=np.float32, - ) - new_obs = (obs, agent_info) if self._include_agent_info else obs - observations[agent_id] = new_obs - - return observations - - def _create_observations( - self, - obs: Dict[int, np.ndarray], - info: Dict[str, Dict[int, Any]], - dones: Dict[int, bool], - ) -> Observation: - """Convert observation.""" - observations_ = self._collate_obs_and_info(obs, info) - dones_ = {get_agent_id(k): v for k, v in dones.items()} - observations = self._convert_observations(observations_, dones_) - return observations - - def _obtain_preprocessor( - self, preprocessor: Any - ) -> Callable[[Dict[int, Any]], Dict[int, np.ndarray]]: - """Obtains the actual preprocessor. - - Obtains the actual preprocessor to be used based on the supplied - preprocessor and the env's obs_builder object - """ - if not isinstance(self.obs_builder, GlobalObsForRailEnv): - _preprocessor = preprocessor if preprocessor else lambda x: x - if isinstance(self.obs_builder, TreeObsForRailEnv): - _preprocessor = ( - partial( - normalize_observation, tree_depth=self.obs_builder.max_depth - ) - if not preprocessor - else preprocessor - ) - assert _preprocessor is not None - else: - - def _preprocessor( - x: Tuple[np.ndarray, np.ndarray, np.ndarray] - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - return x - - def returned_preprocessor(obs: Dict[int, Any]) -> Dict[int, np.ndarray]: - """Return preprocessor.""" - temp_obs = {} - for agent_id, ob in obs.items(): - temp_obs[agent_id] = _preprocessor(ob) - return temp_obs - - return returned_preprocessor - - # set all parameters that should be available before an environment step - # if no available agent, then environment is done and should be reset - def _pre_step(self) -> None: - """Pre-step.""" - if not self.agents: - self._step_type = dm_env.StepType.LAST - - def observation_spec(self) -> Dict[str, OLT]: - """Return observation spec.""" - observation_specs = {} - for agent in self.agents: - # Legal actions - action_spec = _convert_to_spec(self.action_spaces[agent]) - legals = np.ones(shape=action_spec.num_values, dtype=action_spec.dtype) - - observation_specs[agent] = OLT( - observation=tuple( - ( - _convert_to_spec(self.observation_spaces[agent]), - agent_info_spec(), - ) + # Convert Flatland observation so it's dm_env compatible. Also, the list + # of legal actions must be converted to a legal actions mask. 
+ def _convert_observations( + self, + observes: Dict[str, Tuple[np.ndarray, np.ndarray]], + dones: Dict[str, bool], + ) -> Observation: + """Convert observation""" + return convert_dm_compatible_observations( + observes, + dones, + self.observation_spec(), + self.env_done(), + self.possible_agents, + ) + + # collate agent info and observation into a tuple, + # making the agents obervation to + # be a tuple of the observation from the env and the agent info + def _collate_obs_and_info( + self, observes: Dict[int, np.ndarray], info: Dict[str, Dict[int, Any]] + ) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: + """Combine observation and info.""" + observations: Dict = {} + observes = self.preprocessor(observes) + for agent, obs in observes.items(): + agent_id = get_agent_id(agent) + agent_info = np.array( + [info[k][agent] for k in sort_str_num(info.keys())], + dtype=np.float32, + ) + new_obs = (obs, agent_info) if self._include_agent_info else obs + observations[agent_id] = new_obs + + return observations + + def _create_observations( + self, + obs: Dict[int, np.ndarray], + info: Dict[str, Dict[int, Any]], + dones: Dict[int, bool], + ) -> Observation: + """Convert observation.""" + observations_ = self._collate_obs_and_info(obs, info) + dones_ = {get_agent_id(k): v for k, v in dones.items()} + observations = self._convert_observations(observations_, dones_) + return observations + + def _obtain_preprocessor( + self, preprocessor: Any + ) -> Callable[[Dict[int, Any]], Dict[int, np.ndarray]]: + """Obtains the actual preprocessor. + + Obtains the actual preprocessor to be used based on the supplied + preprocessor and the env's obs_builder object + """ + if not isinstance(self.obs_builder, GlobalObsForRailEnv): + _preprocessor = preprocessor if preprocessor else lambda x: x + if isinstance(self.obs_builder, TreeObsForRailEnv): + _preprocessor = ( + partial( + normalize_observation, tree_depth=self.obs_builder.max_depth ) - if self._include_agent_info - else _convert_to_spec(self.observation_spaces[agent]), - legal_actions=legals, - terminal=specs.Array((1,), np.float32), + if not preprocessor + else preprocessor ) - return observation_specs - - def action_spec( - self, - ) -> Dict[str, Union[specs.DiscreteArray, specs.BoundedArray]]: - """Get action spec.""" - action_specs = {} - action_spaces = self.action_spaces - for agent in self.possible_agents: - action_specs[agent] = _convert_to_spec(action_spaces[agent]) - return action_specs - - def reward_spec(self) -> Dict[str, specs.Array]: - """Get the reward spec.""" - reward_specs = {} - for agent in self.possible_agents: - reward_specs[agent] = specs.Array((), np.float32) - return reward_specs - - def discount_spec(self) -> Dict[str, specs.BoundedArray]: - """Get the discount spec.""" - discount_specs = {} - for agent in self.possible_agents: - discount_specs[agent] = specs.BoundedArray( - (), np.float32, minimum=0, maximum=1.0 - ) - return discount_specs - - def extra_spec(self) -> Dict[str, specs.BoundedArray]: - """Get the extras spec.""" - return {} + assert _preprocessor is not None + else: - def seed(self, seed: int = None) -> None: - """Seed the environment.""" - self._environment._seed(seed) - - @property - def environment(self) -> RailEnv: - """Returns the wrapped environment.""" - return self._environment - - @property - def num_agents(self) -> int: - """Returns the number of trains/agents in the flatland environment""" - print(self._environment.number_of_agents) - return int(self._environment.number_of_agents) - - def __getattr__(self, 
name: str) -> Any: - """Expose any other attributes of the underlying environment.""" - return getattr(self._environment, name) - - # Utility functions - - def infer_observation_space( - obs: Union[tuple, np.ndarray, dict] - ) -> Union[Box, tuple, dict]: - """Infer a gym Observation space from a sample observation from flatland""" - if isinstance(obs, np.ndarray): - return Box( - -np.inf, - np.inf, - shape=obs.shape, - dtype=obs.dtype, + def _preprocessor( + x: Tuple[np.ndarray, np.ndarray, np.ndarray] + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + return x + + def returned_preprocessor(obs: Dict[int, Any]) -> Dict[int, np.ndarray]: + """Return preprocessor.""" + temp_obs = {} + for agent_id, ob in obs.items(): + temp_obs[agent_id] = _preprocessor(ob) + return temp_obs + + return returned_preprocessor + + # set all parameters that should be available before an environment step + # if no available agent, then environment is done and should be reset + def _pre_step(self) -> None: + """Pre-step.""" + if not self.agents: + self._step_type = dm_env.StepType.LAST + + def observation_spec(self) -> Dict[str, OLT]: + """Return observation spec.""" + observation_specs = {} + for agent in self.agents: + # Legal actions + action_spec = _convert_to_spec(self.action_spaces[agent]) + legals = np.ones(shape=action_spec.num_values, dtype=action_spec.dtype) + + observation_specs[agent] = OLT( + observation=tuple( + ( + _convert_to_spec(self.observation_spaces[agent]), + agent_info_spec(), + ) + ) + if self._include_agent_info + else _convert_to_spec(self.observation_spaces[agent]), + legal_actions=legals, + terminal=specs.Array((1,), np.float32), ) - elif isinstance(obs, tuple): - return tuple(infer_observation_space(o) for o in obs) - elif isinstance(obs, dict): - return {key: infer_observation_space(value) for key, value in obs.items()} - else: - raise ValueError( - f"Unexpected observation type: {type(obs)}. 
" - f"Observation should be of either of this types " - f"(np.ndarray, tuple, or dict)" + return observation_specs + + def action_spec( + self, + ) -> Dict[str, Union[specs.DiscreteArray, specs.BoundedArray]]: + """Get action spec.""" + action_specs = {} + action_spaces = self.action_spaces + for agent in self.possible_agents: + action_specs[agent] = _convert_to_spec(action_spaces[agent]) + return action_specs + + def reward_spec(self) -> Dict[str, specs.Array]: + """Get the reward spec.""" + reward_specs = {} + for agent in self.possible_agents: + reward_specs[agent] = specs.Array((), np.float32) + return reward_specs + + def discount_spec(self) -> Dict[str, specs.BoundedArray]: + """Get the discount spec.""" + discount_specs = {} + for agent in self.possible_agents: + discount_specs[agent] = specs.BoundedArray( + (), np.float32, minimum=0, maximum=1.0 ) + return discount_specs - def agent_info_spec() -> specs.BoundedArray: - """Create the spec for the agent_info part of the observation""" - return specs.BoundedArray((4,), dtype=np.float32, minimum=0.0, maximum=10) + def extra_spec(self) -> Dict[str, specs.BoundedArray]: + """Get the extras spec.""" + return {} - def get_agent_id(handle: int) -> str: - """Obtain the string that constitutes the agent id from an agent handle""" - return f"train_{handle}" + def seed(self, seed: int = None) -> None: + """Seed the environment.""" + self._environment._seed(seed) - def get_agent_handle(id: str) -> int: - """Obtain an agents handle given its id""" - return int(id.split("_")[-1]) + @property + def environment(self) -> RailEnv: + """Returns the wrapped environment.""" + return self._environment - def decorate_step_method(env: RailEnv) -> None: - """Step method decorator. + @property + def num_agents(self) -> int: + """Returns the number of trains/agents in the flatland environment""" + print(self._environment.number_of_agents) + return int(self._environment.number_of_agents) - Enable the step method of the env to take action dictionaries where - agent keys are the agent ids. Flatland uses the agent handles as - keys instead. This function decorates the step method so that it - accepts an action dict where the keys are the agent ids. - """ - env.step_ = env.step + def __getattr__(self, name: str) -> Any: + """Expose any other attributes of the underlying environment.""" + return getattr(self._environment, name) - def _step( - self: RailEnv, actions: Dict[str, Union[int, float, Any]] - ) -> dm_env.TimeStep: - actions_ = {get_agent_handle(k): int(v) for k, v in actions.items()} - return self.step_(actions_) - env.step = tp.MethodType(_step, env) +# Utility functions - # The block of code below is obtained from the flatland starter-kit - # at https://gitlab.aicrowd.com/flatland/flatland-starter-kit/-/blob/master/ - # utils/observation_utils.py - # this is done just to obtain the normalize_observation function that would - # serve as the default preprocessor for the Tree obs builder. - def max_lt(seq: np.ndarray, val: Any) -> Any: - """Get max in sequence. 
+def infer_observation_space( + obs: Union[tuple, np.ndarray, dict] +) -> Union[Box, tuple, dict]: + """Infer a gym Observation space from a sample observation from flatland""" + if isinstance(obs, np.ndarray): + return Box( + -np.inf, + np.inf, + shape=obs.shape, + dtype=obs.dtype, + ) + elif isinstance(obs, tuple): + return tuple(infer_observation_space(o) for o in obs) + elif isinstance(obs, dict): + return {key: infer_observation_space(value) for key, value in obs.items()} + else: + raise ValueError( + f"Unexpected observation type: {type(obs)}. " + f"Observation should be of either of this types " + f"(np.ndarray, tuple, or dict)" + ) - Return greatest item in seq for which item < val applies. - None is returned if seq was empty or all items in seq were >= val. - """ - max = 0 - idx = len(seq) - 1 - while idx >= 0: - if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max: - max = seq[idx] - idx -= 1 - return max - - def min_gt(seq: np.ndarray, val: Any) -> Any: - """Gets min in a sequence. - - Return smallest item in seq for which item > val applies. - None is returned if seq was empty or all items in seq were >= val. - """ - min = np.inf - idx = len(seq) - 1 - while idx >= 0: - if seq[idx] >= val and seq[idx] < min: - min = seq[idx] - idx -= 1 - return min - - def norm_obs_clip( - obs: np.ndarray, - clip_min: int = -1, - clip_max: int = 1, - fixed_radius: int = 0, - normalize_to_range: bool = False, - ) -> np.ndarray: - """Normalize observation. - - This function returns the difference between min and max value of an observation - :param obs: Observation that should be normalized - :param clip_min: min value where observation will be clipped - :param clip_max: max value where observation will be clipped - :return: returnes normalized and clipped observatoin - """ - if fixed_radius > 0: - max_obs = fixed_radius - else: - max_obs = max(1, max_lt(obs, 1000)) + 1 - - min_obs = 0 # min(max_obs, min_gt(obs, 0)) - if normalize_to_range: - min_obs = min_gt(obs, 0) - if min_obs > max_obs: - min_obs = max_obs - if max_obs == min_obs: - return np.clip(np.array(obs) / max_obs, clip_min, clip_max) - norm = np.abs(max_obs - min_obs) - return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max) - - def _split_node_into_feature_groups( - node: Node, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Splits node into features.""" - data = np.zeros(6) - distance = np.zeros(1) - agent_data = np.zeros(4) - - data[0] = node.dist_own_target_encountered - data[1] = node.dist_other_target_encountered - data[2] = node.dist_other_agent_encountered - data[3] = node.dist_potential_conflict - data[4] = node.dist_unusable_switch - data[5] = node.dist_to_next_branch - - distance[0] = node.dist_min_to_target - - agent_data[0] = node.num_agents_same_direction - agent_data[1] = node.num_agents_opposite_direction - agent_data[2] = node.num_agents_malfunctioning - agent_data[3] = node.speed_min_fractional - return data, distance, agent_data +def agent_info_spec() -> specs.BoundedArray: + """Create the spec for the agent_info part of the observation""" + return specs.BoundedArray((4,), dtype=np.float32, minimum=0.0, maximum=10) + + +def get_agent_id(handle: int) -> str: + """Obtain the string that constitutes the agent id from an agent handle""" + return f"train_{handle}" + + +def get_agent_handle(id: str) -> int: + """Obtain an agents handle given its id""" + return int(id.split("_")[-1]) + + +def decorate_step_method(env: RailEnv) -> None: + """Step method decorator. 
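Editorial aside, purely illustrative: the module-level helpers hoisted out of the wrapper above (infer_observation_space, agent_info_spec, get_agent_id, get_agent_handle) compose as in the sketch below. The observation shapes are invented, and it assumes flatland's optional dependencies plus gym and numpy are installed so the module imports cleanly.

import numpy as np
from gym.spaces import Box
from mava.wrappers.flatland import (
    infer_observation_space,
    get_agent_id,
    get_agent_handle,
)

# A fake flatland-style observation: a dict mapping agent handles to arrays.
# The shape is made up purely for illustration.
sample_obs = {
    0: np.zeros((5, 5), dtype=np.float32),
    1: np.zeros((5, 5), dtype=np.float32),
}

# infer_observation_space recurses over dicts/tuples and wraps arrays in gym Boxes.
space = infer_observation_space(sample_obs)
assert isinstance(space[0], Box) and space[0].shape == (5, 5)

# Agent handles (ints) and agent ids (strings) round-trip through the helpers.
assert get_agent_id(3) == "train_3"
assert get_agent_handle(get_agent_id(3)) == 3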
+ + Enable the step method of the env to take action dictionaries where + agent keys are the agent ids. Flatland uses the agent handles as + keys instead. This function decorates the step method so that it + accepts an action dict where the keys are the agent ids. + """ + env.step_ = env.step + + def _step( + self: RailEnv, actions: Dict[str, Union[int, float, Any]] + ) -> dm_env.TimeStep: + actions_ = {get_agent_handle(k): int(v) for k, v in actions.items()} + return self.step_(actions_) + + env.step = tp.MethodType(_step, env) + + +# The block of code below is obtained from the flatland starter-kit +# at https://gitlab.aicrowd.com/flatland/flatland-starter-kit/-/blob/master/ +# utils/observation_utils.py +# this is done just to obtain the normalize_observation function that would +# serve as the default preprocessor for the Tree obs builder. + + +def max_lt(seq: np.ndarray, val: Any) -> Any: + """Get max in sequence. + + Return greatest item in seq for which item < val applies. + None is returned if seq was empty or all items in seq were >= val. + """ + max = 0 + idx = len(seq) - 1 + while idx >= 0: + if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max: + max = seq[idx] + idx -= 1 + return max + + +def min_gt(seq: np.ndarray, val: Any) -> Any: + """Gets min in a sequence. + + Return smallest item in seq for which item > val applies. + None is returned if seq was empty or all items in seq were >= val. + """ + min = np.inf + idx = len(seq) - 1 + while idx >= 0: + if seq[idx] >= val and seq[idx] < min: + min = seq[idx] + idx -= 1 + return min + + +def norm_obs_clip( + obs: np.ndarray, + clip_min: int = -1, + clip_max: int = 1, + fixed_radius: int = 0, + normalize_to_range: bool = False, +) -> np.ndarray: + """Normalize observation. + + This function returns the difference between min and max value of an observation + :param obs: Observation that should be normalized + :param clip_min: min value where observation will be clipped + :param clip_max: max value where observation will be clipped + :return: returnes normalized and clipped observatoin + """ + if fixed_radius > 0: + max_obs = fixed_radius + else: + max_obs = max(1, max_lt(obs, 1000)) + 1 + + min_obs = 0 # min(max_obs, min_gt(obs, 0)) + if normalize_to_range: + min_obs = min_gt(obs, 0) + if min_obs > max_obs: + min_obs = max_obs + if max_obs == min_obs: + return np.clip(np.array(obs) / max_obs, clip_min, clip_max) + norm = np.abs(max_obs - min_obs) + return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max) + + +def _split_node_into_feature_groups( + node: Node, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Splits node into features.""" + data = np.zeros(6) + distance = np.zeros(1) + agent_data = np.zeros(4) + + data[0] = node.dist_own_target_encountered + data[1] = node.dist_other_target_encountered + data[2] = node.dist_other_agent_encountered + data[3] = node.dist_potential_conflict + data[4] = node.dist_unusable_switch + data[5] = node.dist_to_next_branch + + distance[0] = node.dist_min_to_target + + agent_data[0] = node.num_agents_same_direction + agent_data[1] = node.num_agents_opposite_direction + agent_data[2] = node.num_agents_malfunctioning + agent_data[3] = node.speed_min_fractional + + return data, distance, agent_data + + +def _split_subtree_into_feature_groups( + node: Node, current_tree_depth: int, max_tree_depth: int +) -> Tuple: + """Split subtree.""" + if node == -np.inf: + remaining_depth = max_tree_depth - current_tree_depth + # reference: + # 
https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure + num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1)) + return ( + [-np.inf] * num_remaining_nodes * 6, + [-np.inf] * num_remaining_nodes, + [-np.inf] * num_remaining_nodes * 4, + ) - def _split_subtree_into_feature_groups( - node: Node, current_tree_depth: int, max_tree_depth: int - ) -> Tuple: - """Split subtree.""" - if node == -np.inf: - remaining_depth = max_tree_depth - current_tree_depth - # reference: - # https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure - num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1)) - return ( - [-np.inf] * num_remaining_nodes * 6, - [-np.inf] * num_remaining_nodes, - [-np.inf] * num_remaining_nodes * 4, - ) + data, distance, agent_data = _split_node_into_feature_groups(node) - data, distance, agent_data = _split_node_into_feature_groups(node) + if not node.childs: + return data, distance, agent_data - if not node.childs: - return data, distance, agent_data + for direction in TreeObsForRailEnv.tree_explored_actions_char: + sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups( + node.childs[direction], current_tree_depth + 1, max_tree_depth + ) + data = np.concatenate((data, sub_data)) + distance = np.concatenate((distance, sub_distance)) + agent_data = np.concatenate((agent_data, sub_agent_data)) - for direction in TreeObsForRailEnv.tree_explored_actions_char: - sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups( - node.childs[direction], current_tree_depth + 1, max_tree_depth - ) - data = np.concatenate((data, sub_data)) - distance = np.concatenate((distance, sub_distance)) - agent_data = np.concatenate((agent_data, sub_agent_data)) + return data, distance, agent_data - return data, distance, agent_data - def split_tree_into_feature_groups( - tree: Node, max_tree_depth: int - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """This function splits the tree into three difference arrays.""" - data, distance, agent_data = _split_node_into_feature_groups(tree) +def split_tree_into_feature_groups( + tree: Node, max_tree_depth: int +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """This function splits the tree into three difference arrays.""" + data, distance, agent_data = _split_node_into_feature_groups(tree) - for direction in TreeObsForRailEnv.tree_explored_actions_char: - sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups( - tree.childs[direction], 1, max_tree_depth - ) - data = np.concatenate((data, sub_data)) - distance = np.concatenate((distance, sub_distance)) - agent_data = np.concatenate((agent_data, sub_agent_data)) + for direction in TreeObsForRailEnv.tree_explored_actions_char: + sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups( + tree.childs[direction], 1, max_tree_depth + ) + data = np.concatenate((data, sub_data)) + distance = np.concatenate((distance, sub_distance)) + agent_data = np.concatenate((agent_data, sub_agent_data)) - return data, distance, agent_data + return data, distance, agent_data - def normalize_observation( - observation: Node, tree_depth: int, observation_radius: int = 0 - ) -> np.ndarray: - """This function normalizes the observation used by the RL algorithm.""" - if observation is None: - return np.zeros( - 11 * sum(np.power(4, i) for i in range(tree_depth + 1)), - dtype=np.float32, - ) - data, distance, agent_data = split_tree_into_feature_groups( - 
observation, tree_depth - ) - data = norm_obs_clip(data, fixed_radius=observation_radius) - distance = norm_obs_clip(distance, normalize_to_range=True) - agent_data = np.clip(agent_data, -1, 1) - normalized_obs = np.array( - np.concatenate((np.concatenate((data, distance)), agent_data)), +def normalize_observation( + observation: Node, tree_depth: int, observation_radius: int = 0 +) -> np.ndarray: + """This function normalizes the observation used by the RL algorithm.""" + if observation is None: + return np.zeros( + 11 * sum(np.power(4, i) for i in range(tree_depth + 1)), dtype=np.float32, ) - return normalized_obs - - -except ModuleNotFoundError: - # Incase users have not installed Flatland - pass + data, distance, agent_data = split_tree_into_feature_groups(observation, tree_depth) + + data = norm_obs_clip(data, fixed_radius=observation_radius) + distance = norm_obs_clip(distance, normalize_to_range=True) + agent_data = np.clip(agent_data, -1, 1) + normalized_obs = np.array( + np.concatenate((np.concatenate((data, distance)), agent_data)), + dtype=np.float32, + ) + return normalized_obs diff --git a/mava/wrappers/smac.py b/mava/wrappers/smac.py index cecc0bd03..bfd3f77d3 100644 --- a/mava/wrappers/smac.py +++ b/mava/wrappers/smac.py @@ -19,312 +19,307 @@ import dm_env import numpy as np from acme import specs +from smac.env import StarCraft2Env from mava import types from mava.utils.wrapper_utils import convert_np_type, parameterized_restart from mava.wrappers.env_wrappers import ParallelEnvWrapper -try: # noqa - from smac.env import StarCraft2Env - - class SMACWrapper(ParallelEnvWrapper): - """Environment wrapper for PettingZoo MARL environments.""" - - def __init__( - self, - environment: StarCraft2Env, - return_state_info: bool = True, - ): - """Constructor for parallel PZ wrapper. - - Args: - environment (ParallelEnv): parallel PZ env. - env_preprocess_wrappers (Optional[List], optional): Wrappers - that preprocess envs. - Format (env_preprocessor, dict_with_preprocessor_params). - """ - self._environment = environment - self._return_state_info = return_state_info - self._agents = [f"agent_{n}" for n in range(self._environment.n_agents)] +class SMACWrapper(ParallelEnvWrapper): + """Environment wrapper for PettingZoo MARL environments.""" + + def __init__( + self, + environment: StarCraft2Env, + return_state_info: bool = True, + ): + """Constructor for parallel PZ wrapper. + + Args: + environment (ParallelEnv): parallel PZ env. + env_preprocess_wrappers (Optional[List], optional): Wrappers + that preprocess envs. + Format (env_preprocessor, dict_with_preprocessor_params). + return_state_info: return extra state info + """ + self._environment = environment + self._return_state_info = return_state_info + self._agents = [f"agent_{n}" for n in range(self._environment.n_agents)] + + self._reset_next_step = True + self._done = False + + def reset(self) -> dm_env.TimeStep: + """Resets the env. + + Returns: + dm_env.TimeStep: dm timestep. 
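Stepping back briefly to the Flatland tree-observation utilities above: each tree Node contributes 6 + 1 + 4 = 11 features, so the flattened vector size grows with the tree depth. The snippet below is an illustrative sanity check only.

def tree_obs_size(tree_depth: int) -> int:
    # A depth-d tree has sum(4**i for i in range(d + 1)) nodes, 11 features each.
    return 11 * sum(4 ** i for i in range(tree_depth + 1))

print(tree_obs_size(1))  # 55
print(tree_obs_size(2))  # 231, e.g. TreeObsForRailEnv(max_depth=2)

# normalize_observation above falls back to an all-zero vector of exactly this
# length whenever an agent's observation is None.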
+ """ + # Reset the environment + self._environment.reset() + self._done = False + + self._reset_next_step = False + self._step_type = dm_env.StepType.FIRST + + # Get observation from env + observation = self.environment.get_obs() + legal_actions = self._get_legal_actions() + observations = self._convert_observations( + observation, legal_actions, self._done + ) + + # Set env discount to 1 for all agents + discount_spec = self.discount_spec() + self._discounts = { + agent: convert_np_type(discount_spec[agent].dtype, 1) + for agent in self._agents + } + + # Set reward to zero for all agents + rewards_spec = self.reward_spec() + rewards = { + agent: convert_np_type(rewards_spec[agent].dtype, 0) + for agent in self._agents + } + + # Possibly add state information to extras + if self._return_state_info: + state = self.get_state() + extras = {"s_t": state} + else: + extras = {} + + return parameterized_restart(rewards, self._discounts, observations), extras + + def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: + """Steps in env. + + Args: + actions (Dict[str, np.ndarray]): actions per agent. + + Returns: + dm_env.TimeStep: dm timestep + """ + # Possibly reset the environment + if self._reset_next_step: + return self.reset() + + # Convert dict of actions to list for SMAC + smac_actions = list(actions.values()) + + # Step the SMAC environment + reward, self._done, self._info = self._environment.step(smac_actions) + + # Get the next observations + next_observations = self._environment.get_obs() + legal_actions = self._get_legal_actions() + next_observations = self._convert_observations( + next_observations, legal_actions, self._done + ) + + # Convert team reward to agent-wise rewards + rewards = self._convert_reward(reward) + + # Possibly add state information to extras + if self._return_state_info: + state = self.get_state() + extras = {"s_t": state} + else: + extras = {} + + if self._done: + self._step_type = dm_env.StepType.LAST self._reset_next_step = True - self._done = False - - def reset(self) -> dm_env.TimeStep: - """Resets the env. - - Returns: - dm_env.TimeStep: dm timestep. - """ - # Reset the environment - self._environment.reset() - self._done = False - - self._reset_next_step = False - self._step_type = dm_env.StepType.FIRST - - # Get observation from env - observation = self.environment.get_obs() - legal_actions = self._get_legal_actions() - observations = self._convert_observations( - observation, legal_actions, self._done - ) - # Set env discount to 1 for all agents - discount_spec = self.discount_spec() + # Discount on last timestep set to zero self._discounts = { - agent: convert_np_type(discount_spec[agent].dtype, 1) - for agent in self._agents - } - - # Set reward to zero for all agents - rewards_spec = self.reward_spec() - rewards = { - agent: convert_np_type(rewards_spec[agent].dtype, 0) + agent: convert_np_type(self.discount_spec()[agent].dtype, 0.0) for agent in self._agents } + else: + self._step_type = dm_env.StepType.MID + + # Create timestep object + timestep = dm_env.TimeStep( + observation=next_observations, + reward=rewards, + discount=self._discounts, + step_type=self._step_type, + ) + + return timestep, extras + + def env_done(self) -> bool: + """Check if env is done. + + Returns: + bool: bool indicating if env is done. + """ + return self._done + + def _convert_reward(self, reward: float) -> Dict[str, float]: + """Convert rewards to be dm_env compatible. + + Args: + reward: rewards per agent. 
+ """ + rewards_spec = self.reward_spec() + rewards = {} + for agent in self._agents: + rewards[agent] = convert_np_type(rewards_spec[agent].dtype, reward) + return rewards + + def _get_legal_actions(self) -> List: + """Get legal actions from the environment.""" + legal_actions = [] + for i, _ in enumerate(self._agents): + legal_actions.append( + np.array(self._environment.get_avail_agent_actions(i), dtype="int") + ) + return legal_actions + + def _convert_observations( + self, observations: List, legal_actions: List, done: bool + ) -> types.Observation: + """Convert SMAC observation so it's dm_env compatible. + + Args: + observes (Dict[str, np.ndarray]): observations per agent. + dones (Dict[str, bool]): dones per agent. + + Returns: + types.Observation: dm compatible observations. + """ + olt_observations = {} + for i, agent in enumerate(self._agents): + + olt_observations[agent] = types.OLT( + observation=observations[i], + legal_actions=legal_actions[i], + terminal=np.asarray([done], dtype=np.float32), + ) - # Possibly add state information to extras - if self._return_state_info: - state = self.get_state() - extras = {"s_t": state} - else: - extras = {} + return olt_observations - return parameterized_restart(rewards, self._discounts, observations), extras + def extra_spec(self) -> Dict[str, specs.BoundedArray]: + """Function returns extra spec (format) of the env. - def step(self, actions: Dict[str, np.ndarray]) -> dm_env.TimeStep: - """Steps in env. + Returns: + Dict[str, specs.BoundedArray]: extra spec. + """ + if self._return_state_info: + return {"s_t": self._environment.get_state()} + else: + return {} - Args: - actions (Dict[str, np.ndarray]): actions per agent. + def observation_spec(self) -> Dict[str, types.OLT]: + """Observation spec. - Returns: - dm_env.TimeStep: dm timestep - """ - # Possibly reset the environment - if self._reset_next_step: - return self.reset() + Returns: + types.Observation: spec for environment. 
+ """ + self._environment.reset() - # Convert dict of actions to list for SMAC - smac_actions = list(actions.values()) + observations = self._environment.get_obs() + legal_actions = self._get_legal_actions() - # Step the SMAC environment - reward, self._done, self._info = self._environment.step(smac_actions) + observation_specs = {} + for i, agent in enumerate(self._agents): - # Get the next observations - next_observations = self._environment.get_obs() - legal_actions = self._get_legal_actions() - next_observations = self._convert_observations( - next_observations, legal_actions, self._done + observation_specs[agent] = types.OLT( + observation=observations[i], + legal_actions=legal_actions[i], + terminal=np.asarray([True], dtype=np.float32), ) - # Convert team reward to agent-wise rewards - rewards = self._convert_reward(reward) - - # Possibly add state information to extras - if self._return_state_info: - state = self.get_state() - extras = {"s_t": state} - else: - extras = {} - - if self._done: - self._step_type = dm_env.StepType.LAST - self._reset_next_step = True - - # Discount on last timestep set to zero - self._discounts = { - agent: convert_np_type(self.discount_spec()[agent].dtype, 0.0) - for agent in self._agents - } - else: - self._step_type = dm_env.StepType.MID - - # Create timestep object - timestep = dm_env.TimeStep( - observation=next_observations, - reward=rewards, - discount=self._discounts, - step_type=self._step_type, - ) + return observation_specs + + def action_spec( + self, + ) -> Dict[str, Union[specs.DiscreteArray, specs.BoundedArray]]: + """Action spec. - return timestep, extras - - def env_done(self) -> bool: - """Check if env is done. - - Returns: - bool: bool indicating if env is done. - """ - return self._done - - def _convert_reward(self, reward: float) -> Dict[str, float]: - """Convert rewards to be dm_env compatible. - - Args: - rewards: rewards per agent. - """ - rewards_spec = self.reward_spec() - rewards = {} - for agent in self._agents: - rewards[agent] = convert_np_type(rewards_spec[agent].dtype, reward) - return rewards - - def _get_legal_actions(self) -> List: - """Get legal actions from the environment.""" - legal_actions = [] - for i, _ in enumerate(self._agents): - legal_actions.append( - np.array(self._environment.get_avail_agent_actions(i), dtype="int") - ) - return legal_actions - - def _convert_observations( - self, observations: List, legal_actions: List, done: bool - ) -> types.Observation: - """Convert SMAC observation so it's dm_env compatible. - - Args: - observes (Dict[str, np.ndarray]): observations per agent. - dones (Dict[str, bool]): dones per agent. - - Returns: - types.Observation: dm compatible observations. - """ - olt_observations = {} - for i, agent in enumerate(self._agents): - - olt_observations[agent] = types.OLT( - observation=observations[i], - legal_actions=legal_actions[i], - terminal=np.asarray([done], dtype=np.float32), - ) - - return olt_observations - - def extra_spec(self) -> Dict[str, specs.BoundedArray]: - """Function returns extra spec (format) of the env. - - Returns: - Dict[str, specs.BoundedArray]: extra spec. - """ - if self._return_state_info: - return {"s_t": self._environment.get_state()} - else: - return {} - - def observation_spec(self) -> Dict[str, types.OLT]: - """Observation spec. - - Returns: - types.Observation: spec for environment. 
- """ - self._environment.reset() - - observations = self._environment.get_obs() - legal_actions = self._get_legal_actions() - - observation_specs = {} - for i, agent in enumerate(self._agents): - - observation_specs[agent] = types.OLT( - observation=observations[i], - legal_actions=legal_actions[i], - terminal=np.asarray([True], dtype=np.float32), - ) - - return observation_specs - - def action_spec( - self, - ) -> Dict[str, Union[specs.DiscreteArray, specs.BoundedArray]]: - """Action spec. - - Returns: - spec for actions. - """ - action_specs = {} - for agent in self._agents: - action_specs[agent] = specs.DiscreteArray( - num_values=self._environment.n_actions, dtype=int - ) - return action_specs - - def reward_spec(self) -> Dict[str, specs.Array]: - """Reward spec. - - Returns: - Dict[str, specs.Array]: spec for rewards. - """ - reward_specs = {} - for agent in self._agents: - reward_specs[agent] = specs.Array((), np.float32) - return reward_specs - - def discount_spec(self) -> Dict[str, specs.BoundedArray]: - """Discount spec. - - Returns: - Dict[str, specs.BoundedArray]: spec for discounts. - """ - discount_specs = {} - for agent in self._agents: - discount_specs[agent] = specs.BoundedArray( - (), np.float32, minimum=0, maximum=1.0 - ) - return discount_specs - - def get_stats(self) -> Optional[Dict]: - """Return extra stats to be logged. - - Returns: - extra stats to be logged. - """ - return self._environment.get_stats() - - @property - def agents(self) -> List: - """Agents still alive in env (not done). - - Returns: - List: alive agents in env. - """ - return self._agents - - @property - def possible_agents(self) -> List: - """All possible agents in env. - - Returns: - List: all possible agents in env. - """ - return self._agents - - @property - def environment(self) -> StarCraft2Env: - """Returns the wrapped environment. - - Returns: - ParallelEnv: parallel env. - """ - return self._environment - - def __getattr__(self, name: str) -> Any: - """Expose any other attributes of the underlying environment. - - Args: - name (str): attribute. - - Returns: - Any: return attribute from env or underlying env. - """ - if hasattr(self.__class__, name): - return self.__getattribute__(name) - else: - return getattr(self._environment, name) - - -except ModuleNotFoundError: - # Incase users have not installed SMAC - pass + Returns: + spec for actions. + """ + action_specs = {} + for agent in self._agents: + action_specs[agent] = specs.DiscreteArray( + num_values=self._environment.n_actions, dtype=int + ) + return action_specs + + def reward_spec(self) -> Dict[str, specs.Array]: + """Reward spec. + + Returns: + Dict[str, specs.Array]: spec for rewards. + """ + reward_specs = {} + for agent in self._agents: + reward_specs[agent] = specs.Array((), np.float32) + return reward_specs + + def discount_spec(self) -> Dict[str, specs.BoundedArray]: + """Discount spec. + + Returns: + Dict[str, specs.BoundedArray]: spec for discounts. + """ + discount_specs = {} + for agent in self._agents: + discount_specs[agent] = specs.BoundedArray( + (), np.float32, minimum=0, maximum=1.0 + ) + return discount_specs + + def get_stats(self) -> Optional[Dict]: + """Return extra stats to be logged. + + Returns: + extra stats to be logged. + """ + return self._environment.get_stats() + + @property + def agents(self) -> List: + """Agents still alive in env (not done). + + Returns: + List: alive agents in env. + """ + return self._agents + + @property + def possible_agents(self) -> List: + """All possible agents in env. 
+ + Returns: + List: all possible agents in env. + """ + return self._agents + + @property + def environment(self) -> StarCraft2Env: + """Returns the wrapped environment. + + Returns: + ParallelEnv: parallel env. + """ + return self._environment + + def __getattr__(self, name: str) -> Any: + """Expose any other attributes of the underlying environment. + + Args: + name (str): attribute. + + Returns: + Any: return attribute from env or underlying env. + """ + if hasattr(self.__class__, name): + return self.__getattribute__(name) + else: + return getattr(self._environment, name) diff --git a/tests/conftest.py b/tests/conftest.py index 78e5265da..9b8992c6b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,8 +23,11 @@ import numpy.testing as npt import pytest -from mava.utils.environments import flatland_utils -from mava.wrappers.flatland import FlatlandEnvWrapper +try: + from mava.utils.environments import flatland_utils + from mava.wrappers.flatland import FlatlandEnvWrapper +except ModuleNotFoundError: + pass try: from pettingzoo.utils.env import AECEnv, ParallelEnv From 51a94ab5434155b005ecf2d2cf59edaed1df0362 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 2 Feb 2022 13:12:22 +0200 Subject: [PATCH 46/56] Reformating error. --- examples/flatland/recurrent/centralised/run_vdn.py | 2 +- examples/flatland/recurrent/decentralised/run_madqn.py | 1 - .../atari/pong/recurrent/decentralised/run_madqn.py | 2 +- .../recurrent/decentralised/run_madqn_scale_trainers.py | 1 - mava/systems/tf/madqn/execution.py | 9 +++++---- mava/systems/tf/madqn/system.py | 2 +- mava/systems/tf/madqn/training.py | 2 +- mava/systems/tf/value_decomposition/system.py | 7 +++---- mava/systems/tf/value_decomposition/training.py | 6 +++--- 9 files changed, 15 insertions(+), 17 deletions(-) diff --git a/examples/flatland/recurrent/centralised/run_vdn.py b/examples/flatland/recurrent/centralised/run_vdn.py index acf4a39a2..ce5f5b6cf 100644 --- a/examples/flatland/recurrent/centralised/run_vdn.py +++ b/examples/flatland/recurrent/centralised/run_vdn.py @@ -30,7 +30,6 @@ from mava.utils.environments.flatland_utils import make_environment from mava.utils.loggers import logger_utils - FLAGS = flags.FLAGS flags.DEFINE_string( @@ -58,6 +57,7 @@ def main(_: Any) -> None: + """Run example.""" # Environment. 
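Before the run script continues, a note on the SMACWrapper added above: both reset() and step() return a (timestep, extras) pair rather than a bare dm_env.TimeStep, with extras carrying the global state under "s_t" when return_state_info=True. A minimal random-action rollout might look like the sketch below; it assumes SMAC and StarCraft II are installed and that the wrapper is importable from mava.wrappers.smac, and it is illustrative rather than part of this patch.

import numpy as np
from smac.env import StarCraft2Env
from mava.wrappers.smac import SMACWrapper

env = SMACWrapper(StarCraft2Env(map_name="3m"), return_state_info=True)

timestep, extras = env.reset()
while not env.env_done():
    actions = {}
    for agent in env.possible_agents:
        olt = timestep.observation[agent]
        # Sample uniformly over the actions that are currently legal for this agent.
        legal = np.nonzero(olt.legal_actions)[0]
        actions[agent] = int(np.random.choice(legal))
    timestep, extras = env.step(actions)

print(env.get_stats())
env.close()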
environment_factory = functools.partial(make_environment, **env_config) diff --git a/examples/flatland/recurrent/decentralised/run_madqn.py b/examples/flatland/recurrent/decentralised/run_madqn.py index 0d1824a11..3c1d340c2 100644 --- a/examples/flatland/recurrent/decentralised/run_madqn.py +++ b/examples/flatland/recurrent/decentralised/run_madqn.py @@ -19,7 +19,6 @@ from typing import Any, Dict import launchpad as lp -import sonnet as snt from absl import app, flags from mava.components.tf.modules.exploration.exploration_scheduling import ( diff --git a/examples/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py b/examples/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py index 0bab836e3..be53a215c 100644 --- a/examples/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py +++ b/examples/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py @@ -19,7 +19,6 @@ from typing import Any import launchpad as lp -import sonnet as snt from absl import app, flags from mava.components.tf.modules.exploration import LinearExplorationScheduler @@ -49,6 +48,7 @@ def main(_: Any) -> None: + """Run example.""" # Environment environment_factory = functools.partial( diff --git a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py index a207f76f3..66e86caf4 100644 --- a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py +++ b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py @@ -19,7 +19,6 @@ from typing import Any import launchpad as lp -import sonnet as snt from absl import app, flags from mava.components.tf.modules.exploration import LinearExplorationScheduler diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index 9a006c546..d9e007140 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -103,7 +103,7 @@ def get_stats(self) -> Dict: class MADQNFeedForwardExecutor(executors.FeedForwardExecutor, DQNExecutor): """A feed-forward executor for MADQN like systems. - An executor based on a feed-forward epsilon-greedy policy for + An executor based on a feed-forward epsilon-greedy policy for each agent in the system. """ @@ -288,7 +288,7 @@ def update(self, wait: bool = False) -> None: class MADQNRecurrentExecutor(executors.RecurrentExecutor, DQNExecutor): """A recurrent executor for MADQN like systems. - An executor based on a recurrent epsilon-greedy policy + An executor based on a recurrent epsilon-greedy policy for each agent in the system. """ @@ -392,8 +392,9 @@ def _policy( @tf.function def _select_actions( - self, observations: Dict[str, types.NestedArray], - states: Dict[str, types.NestedArray] + self, + observations: Dict[str, types.NestedArray], + states: Dict[str, types.NestedArray], ) -> types.NestedArray: """The part of select_action that we can do inside tf.function""" actions: Dict = {} diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index 080e52579..6718f4e3a 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -186,7 +186,7 @@ def __init__( # noqa happen at every timestep. E.g. to evaluate a system after every 100 executor episodes, evaluator_interval = {"executor_episodes": 100}. - learning_rate_scheduler_fn: an optional learning rate scheduler for + learning_rate_scheduler_fn: an optional learning rate scheduler for the value function optimiser. 
""" diff --git a/mava/systems/tf/madqn/training.py b/mava/systems/tf/madqn/training.py index 595de2e88..bbee3c59a 100644 --- a/mava/systems/tf/madqn/training.py +++ b/mava/systems/tf/madqn/training.py @@ -559,7 +559,7 @@ def _transform_observations( dims, ) - # This stop_gradient prevents gradients to propagate into + # This stop_gradient prevents gradients to propagate into # the target observation network. obs_target_trans[agent] = tree.map_structure( tf.stop_gradient, obs_target_trans[agent] diff --git a/mava/systems/tf/value_decomposition/system.py b/mava/systems/tf/value_decomposition/system.py index 822216612..8bf848748 100644 --- a/mava/systems/tf/value_decomposition/system.py +++ b/mava/systems/tf/value_decomposition/system.py @@ -39,8 +39,7 @@ class ValueDecomposition(MADQN): """Value Decomposition systems. - - + Inherits from recurrent MADQN. """ @@ -212,8 +211,8 @@ def __init__( learning_rate_scheduler_fn=learning_rate_scheduler_fn, ) - # NOTE Users can either pass in their own mixer or - # use one of the pre-built ones by passing in a + # NOTE Users can either pass in their own mixer or + # use one of the pre-built ones by passing in a # string "qmix" or "vdn". if isinstance(mixer, str): if mixer == "qmix": diff --git a/mava/systems/tf/value_decomposition/training.py b/mava/systems/tf/value_decomposition/training.py index eb9b9a647..b3d076921 100644 --- a/mava/systems/tf/value_decomposition/training.py +++ b/mava/systems/tf/value_decomposition/training.py @@ -207,7 +207,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # Do forward passes through the networks and calculate the losses with tf.GradientTape(persistent=True) as tape: - + obs_trans, target_obs_trans = self._transform_observations(observations) # Lists for stacking tensors later @@ -268,13 +268,13 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: max_action_q_value_all_agents, states=global_env_state ) - # NOTE Weassume team reward is just the mean + # NOTE Weassume team reward is just the mean # over agents indevidual rewards reward_all_agents = tf.reduce_mean( reward_all_agents, axis=-1, keepdims=True ) # NOTE We assume all agents have the same env discount since - # it is a team game. + # it is a team game. env_discount_all_agents = tf.reduce_mean( env_discount_all_agents, axis=-1, keepdims=True ) From c5bab208da1a36463bced1bbd05fe28139428de3 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Wed, 2 Feb 2022 14:25:52 +0200 Subject: [PATCH 47/56] fix: Updated dockerfile for missing updates. 
--- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 6ccd7485c..3476e89f2 100755 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ ARG record # Ensure no installs try launch interactive screen ARG DEBIAN_FRONTEND=noninteractive # Update packages -RUN apt-get update -y && apt-get install -y python3-pip && apt-get install -y python3-venv +RUN apt-get update --fix-missing -y && apt-get install -y python3-pip && apt-get install -y python3-venv # Update python path RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.8 10 &&\ rm -rf /root/.cache && apt-get clean From d77b2db73fa4774568d6ddab849a4262c2cacf6d Mon Sep 17 00:00:00 2001 From: RuanJohn Date: Mon, 7 Feb 2022 09:14:33 +0200 Subject: [PATCH 48/56] Added random seed back to madqn system and executors --- mava/systems/tf/madqn/builder.py | 7 ++++++- mava/systems/tf/madqn/system.py | 7 +++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/mava/systems/tf/madqn/builder.py b/mava/systems/tf/madqn/builder.py index d58fb036d..7c9616a66 100644 --- a/mava/systems/tf/madqn/builder.py +++ b/mava/systems/tf/madqn/builder.py @@ -426,6 +426,7 @@ def make_executor( adder: Optional[adders.ReverbParallelAdder] = None, variable_source: Optional[MavaVariableSource] = None, evaluator: bool = False, + seed: Optional[int] = None, ) -> core.Executor: """Create an executor instance. @@ -439,6 +440,7 @@ def make_executor( Defaults to None. evaluator: boolean indicator if the executor is used for for evaluation only. + seed: seed for reproducible sampling. Returns: system executor, a collection of agents making up the part @@ -486,7 +488,10 @@ def make_executor( # Pass scheduler and initialize action selectors action_selectors_with_scheduler = initialize_epsilon_schedulers( - exploration_schedules, networks["selectors"], self._config.agent_net_keys + exploration_schedules, + networks["selectors"], + self._config.agent_net_keys, + seed=seed, ) # Create the actor which defines how we take actions. diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index 6718f4e3a..f26397fb0 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -102,6 +102,7 @@ def __init__( # noqa termination_condition: Optional[Dict[str, int]] = None, evaluator_interval: Optional[dict] = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + seed: Optional[int] = None, ): """Initialise the system. @@ -188,6 +189,9 @@ def __init__( # noqa evaluator_interval = {"executor_episodes": 100}. learning_rate_scheduler_fn: an optional learning rate scheduler for the value function optimiser. + seed: seed for reproducible sampling (used for epsilon + greedy action selection). + """ if not environment_spec: @@ -359,6 +363,7 @@ def __init__( # noqa self._eval_loop_fn = eval_loop_fn self._eval_loop_fn_kwargs = eval_loop_fn_kwargs self._evaluator_interval = evaluator_interval + self._seed = seed extra_specs = {} if issubclass(executor_fn, executors.RecurrentExecutor): @@ -499,6 +504,7 @@ def executor( adder=self._builder.make_adder(replay), variable_source=variable_source, evaluator=False, + seed=self._seed, ) # TODO (Arnu): figure out why factory function are giving type errors @@ -553,6 +559,7 @@ def evaluator( for agent in self._environment_spec.get_agent_ids() }, variable_source=variable_source, + seed=self._seed, ) # Make the environment. 
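Since the seed commit above only shows the plumbing, the self-contained sketch below (not Mava's actual scheduler or action-selector classes) illustrates why seeding the epsilon-greedy action selection makes exploration reproducible.

import numpy as np

def epsilon_greedy(q_values: np.ndarray, legal: np.ndarray, epsilon: float,
                   rng: np.random.Generator) -> int:
    """Pick a legal action: random with probability epsilon, else greedy over legal Q-values."""
    legal_idx = np.nonzero(legal)[0]
    if rng.random() < epsilon:
        return int(rng.choice(legal_idx))
    masked = np.where(legal.astype(bool), q_values, -np.inf)
    return int(np.argmax(masked))

rng_a = np.random.default_rng(42)
rng_b = np.random.default_rng(42)
q = np.array([0.1, 0.5, 0.2, 0.9])
legal = np.array([1, 1, 1, 0])
picks_a = [epsilon_greedy(q, legal, 0.3, rng_a) for _ in range(5)]
picks_b = [epsilon_greedy(q, legal, 0.3, rng_b) for _ in range(5)]
assert picks_a == picks_b  # same seed, identical exploration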
From f4e193d5582eefb263d78a51563066720a151fa9 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Mon, 7 Feb 2022 12:32:06 +0200 Subject: [PATCH 49/56] fix: main README system implementation table --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 10576c793..f6ca844b0 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ To read more about the motivation behind Mava, please see our [blog post][blog],
-👋 **UPDATE**: The team has been hard at work over the past few months to improve Mava's systems performance, stability and robustness. These efforts include extensively benchmarking system implementations, fixing bugs and profiling performance and speed. The culmination of this work will be reflected in our next stable release. However, during this period, we have learned a lot about what works and what doesn't. In particular, our current base system design allows for a decent amount of flexibility but quickly becomes difficult to maintain with growing signatures and system constructors as additional modules get added. Our class designs are also overly reliant on wrappers and inheritance which do not scale as well as we would like with increases in system complexity. Furthermore, our original motivation for choosing Tensorflow 2 (TF2) as our deep learning backend was to align with Acme's large repository of RL abstractions and tools for TF2. These were very useful for initially building our systems. But since then, we have found TF2 less performant and flexible than we desire given alternative frameworks. Acme has also affirmed their support of Jax underlying much of the DeepMind RL ecosystem. Therefore, in the coming months, following our stable release, **we plan to rollout a more modular and flexible build system specifically for Jax-based systems.** Please note that all TF2-based systems using the old build system will be maintained during the rollout. However, once a stable Jax release has been made with the new build system, Mava will only support a single DL backend, namely Jax, and we will begin to deprecate all TF2 systems and building support. That said, we will make sure to communicate clearly and often during the migration from TF2 to Jax. +👋 **UPDATE**: The team has been hard at work over the past few months to improve Mava's systems performance, stability and robustness. These efforts include extensively benchmarking system implementations, fixing bugs and profiling performance and speed. The culmination of this work will be reflected in our next stable release. However, during this period, we have learned a lot about what works and what doesn't. In particular, our current base system design allows for a decent amount of flexibility but quickly becomes difficult to maintain with growing signatures and system constructors as additional modules get added. Our class designs are also overly reliant on wrappers and inheritance which do not scale as well as we would like with increases in system complexity. Furthermore, our original motivation for choosing Tensorflow 2 (TF2) as our deep learning backend was to align with Acme's large repository of RL abstractions and tools for TF2. These were very useful for initially building our systems. But since then, we have found TF2 less performant and flexible than we desire given alternative frameworks. Acme has also affirmed their support of Jax underlying much of the DeepMind RL ecosystem. Therefore, in the coming months, following our stable release, **we plan to rollout a more modular and flexible build system specifically for Jax-based systems.** Please note that all TF2-based systems using the old build system will be maintained during the rollout. However, once a stable Jax release has been made with the new build system, Mava will only support a single DL backend, namely Jax, and we will begin to deprecate all TF2 systems and building support. 
That said, we will make sure to communicate clearly and often during the migration from TF2 to Jax.
@@ -65,12 +65,12 @@ For details on how to add your own environment, see [here](https://github.com/in | **Name** | **Recurrent** | **Continuous** | **Discrete** | **Centralised training** | **Multi Processing** | | ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------- | -| MADQN | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| MADQN | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :heavy_check_mark: | | MADDPG | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | MAD4PG | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | MAPPO | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | -| VDN | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | -| QMIX | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| VDN | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| QMIX | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | As we develop Mava further, we aim to have all systems well tested on a wide variety of environments. From 746722e455fd90291fb6ce945072a04a3df020ad Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Mon, 7 Feb 2022 12:50:02 +0200 Subject: [PATCH 50/56] fix: mad4pg docstrings --- mava/systems/tf/mad4pg/training.py | 309 +++++++++++++++++++++++------ mava/systems/tf/madqn/builder.py | 10 +- 2 files changed, 255 insertions(+), 64 deletions(-) diff --git a/mava/systems/tf/mad4pg/training.py b/mava/systems/tf/mad4pg/training.py index aeed32a86..3237fa245 100644 --- a/mava/systems/tf/mad4pg/training.py +++ b/mava/systems/tf/mad4pg/training.py @@ -76,6 +76,7 @@ def __init__( learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): """Initialise MAD4PG trainer + Args: agents: agent ids, e.g. "agent_0". agent_types: agent types, e.g. "speaker" or "listener". @@ -85,10 +86,8 @@ def __init__( each agent in the system. target_policy_networks: target policy networks. target_critic_networks: target critic networks. - policy_optimizer: - optimizer(s) for updating policy networks. - critic_optimizer: - optimizer for updating critic networks. + policy_optimizer: optimizer(s) for updating policy networks. + critic_optimizer: optimizer for updating critic networks. discount: discount factor for TD updates. target_averaging: whether to use polyak averaging for target network updates. @@ -252,7 +251,41 @@ def __init__( logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Initialise the decentralised MAD4PG trainer.""" + """Initialise decentralised MAD4PG trainer + + Args: + agents: agent ids, e.g. "agent_0". + agent_types: agent types, e.g. "speaker" or "listener". + policy_networks: policy networks for each agent in + the system. + critic_networks: critic network(s), shared or for + each agent in the system. + target_policy_networks: target policy networks. + target_critic_networks: target critic networks. + policy_optimizer: optimizer(s) for updating policy networks. + critic_optimizer: optimizer for updating critic networks. + discount: discount factor for TD updates. + target_averaging: whether to use polyak averaging for target network + updates. 
+ target_update_period: number of steps before target networks are + updated. + target_update_rate: update rate when using averaging. + dataset: training dataset. + observation_networks: network for feature + extraction from raw observation. + target_observation_networks: target observation + network. + variable_client: The client used to manage the variables. + counts: step counter object. + agent_net_keys: specifies what network each agent uses. + max_gradient_norm: maximum allowed norm for gradients + before clipping is applied. + logger: logger object for logging trainer + statistics. + learning_rate_scheduler_fn: dict with two functions (one for the policy and + one for the critic optimizer), that takes in a trainer step t and + returns the current learning rate. + """ super().__init__( agents=agents, @@ -306,7 +339,41 @@ def __init__( logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Initialise the centralised MAD4PG trainer.""" + """Initialise centralised MAD4PG trainer + + Args: + agents: agent ids, e.g. "agent_0". + agent_types: agent types, e.g. "speaker" or "listener". + policy_networks: policy networks for each agent in + the system. + critic_networks: critic network(s), shared or for + each agent in the system. + target_policy_networks: target policy networks. + target_critic_networks: target critic networks. + policy_optimizer: optimizer(s) for updating policy networks. + critic_optimizer: optimizer for updating critic networks. + discount: discount factor for TD updates. + target_averaging: whether to use polyak averaging for target network + updates. + target_update_period: number of steps before target networks are + updated. + target_update_rate: update rate when using averaging. + dataset: training dataset. + observation_networks: network for feature + extraction from raw observation. + target_observation_networks: target observation + network. + variable_client: The client used to manage the variables. + counts: step counter object. + agent_net_keys: specifies what network each agent uses. + max_gradient_norm: maximum allowed norm for gradients + before clipping is applied. + logger: logger object for logging trainer + statistics. + learning_rate_scheduler_fn: dict with two functions (one for the policy and + one for the critic optimizer), that takes in a trainer step t and + returns the current learning rate. + """ super().__init__( agents=agents, @@ -360,7 +427,41 @@ def __init__( logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Initialise the state-based MAD4PG trainer.""" + """Initialise state-based MAD4PG trainer + + Args: + agents: agent ids, e.g. "agent_0". + agent_types: agent types, e.g. "speaker" or "listener". + policy_networks: policy networks for each agent in + the system. + critic_networks: critic network(s), shared or for + each agent in the system. + target_policy_networks: target policy networks. + target_critic_networks: target critic networks. + policy_optimizer: optimizer(s) for updating policy networks. + critic_optimizer: optimizer for updating critic networks. + discount: discount factor for TD updates. + target_averaging: whether to use polyak averaging for target network + updates. + target_update_period: number of steps before target networks are + updated. + target_update_rate: update rate when using averaging. + dataset: training dataset. 
+ observation_networks: network for feature + extraction from raw observation. + target_observation_networks: target observation + network. + variable_client: The client used to manage the variables. + counts: step counter object. + agent_net_keys: specifies what network each agent uses. + max_gradient_norm: maximum allowed norm for gradients + before clipping is applied. + logger: logger object for logging trainer + statistics. + learning_rate_scheduler_fn: dict with two functions (one for the policy and + one for the critic optimizer), that takes in a trainer step t and + returns the current learning rate. + """ super().__init__( agents=agents, @@ -390,7 +491,7 @@ def __init__( class MAD4PGBaseRecurrentTrainer(MADDPGBaseRecurrentTrainer): """Recurrent MAD4PG trainer. - This is the trainer component of a MADDPG system. IE it takes a dataset as input + This is the trainer component of a MAD4PG system. IE it takes a dataset as input and implements update functionality to learn from this dataset. """ @@ -430,10 +531,8 @@ def __init__( each agent in the system. target_policy_networks: target policy networks. target_critic_networks: target critic networks. - policy_optimizer: - optimizer(s) for updating policy networks. - critic_optimizer: - optimizer for updating critic networks. + policy_optimizer: optimizer(s) for updating policy networks. + critic_optimizer: optimizer for updating critic networks. discount: discount factor for TD updates. target_averaging: whether to use polyak averaging for target network updates. @@ -445,6 +544,7 @@ def __init__( extraction from raw observation. target_observation_networks: target observation network. + bootstrap_n: number of timestepsto use for bootstrapping. variable_client: The client used to manage the variables. counts: step counter object. agent_net_keys: specifies what network each agent uses. @@ -652,31 +752,41 @@ def __init__( bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Init trainer. + """Initialise Recurrent MAD4PG trainer Args: - agents: [description] - agent_types: [description] - policy_networks: [description] - critic_networks: [description] - target_policy_networks: [description] - target_critic_networks: [description] - policy_optimizer: [description] - critic_optimizer: [description] - discount: [description] - target_averaging: [description] - target_update_period: [description] - target_update_rate: [description] - dataset: [description] - observation_networks: [description] - target_observation_networks: [description] - variable_client: [description] - counts: [description] - agent_net_keys: [description] - max_gradient_norm: [description]. Defaults to None. - logger: [description]. Defaults to None. - bootstrap_n: [description]. Defaults to 10. - learning_rate_scheduler_fn: [description]. Defaults to None. + agents: agent ids, e.g. "agent_0". + agent_types: agent types, e.g. "speaker" or "listener". + policy_networks: policy networks for each agent in + the system. + critic_networks: critic network(s), shared or for + each agent in the system. + target_policy_networks: target policy networks. + target_critic_networks: target critic networks. + policy_optimizer: optimizer(s) for updating policy networks. + critic_optimizer: optimizer for updating critic networks. + discount: discount factor for TD updates. + target_averaging: whether to use polyak averaging for target network + updates. + target_update_period: number of steps before target networks are + updated. 
+ target_update_rate: update rate when using averaging. + dataset: training dataset. + observation_networks: network for feature + extraction from raw observation. + target_observation_networks: target observation + network. + bootstrap_n: number of timestepsto use for bootstrapping. + variable_client: The client used to manage the variables. + counts: step counter object. + agent_net_keys: specifies what network each agent uses. + max_gradient_norm: maximum allowed norm for gradients + before clipping is applied. + logger: logger object for logging trainer + statistics. + learning_rate_scheduler_fn: dict with two functions (one for the policy and + one for the critic optimizer), that takes in a trainer step t and + returns the current learning rate. """ super().__init__( @@ -735,7 +845,42 @@ def __init__( bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Init trainer.""" + """Initialise Recurrent MAD4PG trainer + + Args: + agents: agent ids, e.g. "agent_0". + agent_types: agent types, e.g. "speaker" or "listener". + policy_networks: policy networks for each agent in + the system. + critic_networks: critic network(s), shared or for + each agent in the system. + target_policy_networks: target policy networks. + target_critic_networks: target critic networks. + policy_optimizer: optimizer(s) for updating policy networks. + critic_optimizer: optimizer for updating critic networks. + discount: discount factor for TD updates. + target_averaging: whether to use polyak averaging for target network + updates. + target_update_period: number of steps before target networks are + updated. + target_update_rate: update rate when using averaging. + dataset: training dataset. + observation_networks: network for feature + extraction from raw observation. + target_observation_networks: target observation + network. + bootstrap_n: number of timestepsto use for bootstrapping. + variable_client: The client used to manage the variables. + counts: step counter object. + agent_net_keys: specifies what network each agent uses. + max_gradient_norm: maximum allowed norm for gradients + before clipping is applied. + logger: logger object for logging trainer + statistics. + learning_rate_scheduler_fn: dict with two functions (one for the policy and + one for the critic optimizer), that takes in a trainer step t and + returns the current learning rate. + """ super().__init__( agents=agents, @@ -793,7 +938,42 @@ def __init__( bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): - """Init trainer.""" + """Initialise Recurrent MAD4PG trainer + + Args: + agents: agent ids, e.g. "agent_0". + agent_types: agent types, e.g. "speaker" or "listener". + policy_networks: policy networks for each agent in + the system. + critic_networks: critic network(s), shared or for + each agent in the system. + target_policy_networks: target policy networks. + target_critic_networks: target critic networks. + policy_optimizer: optimizer(s) for updating policy networks. + critic_optimizer: optimizer for updating critic networks. + discount: discount factor for TD updates. + target_averaging: whether to use polyak averaging for target network + updates. + target_update_period: number of steps before target networks are + updated. + target_update_rate: update rate when using averaging. + dataset: training dataset. + observation_networks: network for feature + extraction from raw observation. 
+ target_observation_networks: target observation + network. + bootstrap_n: number of timestepsto use for bootstrapping. + variable_client: The client used to manage the variables. + counts: step counter object. + agent_net_keys: specifies what network each agent uses. + max_gradient_norm: maximum allowed norm for gradients + before clipping is applied. + logger: logger object for logging trainer + statistics. + learning_rate_scheduler_fn: dict with two functions (one for the policy and + one for the critic optimizer), that takes in a trainer step t and + returns the current learning rate. + """ super().__init__( agents=agents, @@ -850,30 +1030,41 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, ): - """Initialise trainer. + """Initialise State-Based Recurrent MAD4PG trainer Args: - agents: [description] - agent_types: [description] - policy_networks: [description] - critic_networks: [description] - target_policy_networks: [description] - target_critic_networks: [description] - policy_optimizer: [description] - critic_optimizer: [description] - discount: [description] - target_averaging: [description] - target_update_period: [description] - target_update_rate: [description] - dataset: [description] - observation_networks: [description] - target_observation_networks: [description] - variable_client: [description] - counts: [description] - agent_net_keys: [description] - max_gradient_norm: [description]. Defaults to None. - logger: [description]. Defaults to None. - bootstrap_n: [description]. Defaults to 10. + agents: agent ids, e.g. "agent_0". + agent_types: agent types, e.g. "speaker" or "listener". + policy_networks: policy networks for each agent in + the system. + critic_networks: critic network(s), shared or for + each agent in the system. + target_policy_networks: target policy networks. + target_critic_networks: target critic networks. + policy_optimizer: optimizer(s) for updating policy networks. + critic_optimizer: optimizer for updating critic networks. + discount: discount factor for TD updates. + target_averaging: whether to use polyak averaging for target network + updates. + target_update_period: number of steps before target networks are + updated. + target_update_rate: update rate when using averaging. + dataset: training dataset. + observation_networks: network for feature + extraction from raw observation. + target_observation_networks: target observation + network. + bootstrap_n: number of timestepsto use for bootstrapping. + variable_client: The client used to manage the variables. + counts: step counter object. + agent_net_keys: specifies what network each agent uses. + max_gradient_norm: maximum allowed norm for gradients + before clipping is applied. + logger: logger object for logging trainer + statistics. + learning_rate_scheduler_fn: dict with two functions (one for the policy and + one for the critic optimizer), that takes in a trainer step t and + returns the current learning rate. """ super().__init__( diff --git a/mava/systems/tf/madqn/builder.py b/mava/systems/tf/madqn/builder.py index 7c9616a66..a4f10fdd1 100644 --- a/mava/systems/tf/madqn/builder.py +++ b/mava/systems/tf/madqn/builder.py @@ -48,7 +48,7 @@ @dataclasses.dataclass class MADQNConfig: - """Configuration options for the MADDPG system. + """Configuration options for the MADQN system. Args: environment_spec: description of the action and observation spaces etc. 
for @@ -145,7 +145,7 @@ def __init__( executor_fn: Type[core.Executor] = MADQNFeedForwardExecutor, extra_specs: Dict[str, Any] = {}, ): - """Initialise the system. + """Initialise the builder. Args: config: system configuration specifying hyperparameters and @@ -169,11 +169,11 @@ def covert_specs(self, spec: Dict[str, Any], num_networks: int) -> Dict[str, Any """Convert specs. Args: - spec: [description] - num_networks: [description] + spec: agent specs + num_networks: the number of networks Returns: - Dict[str, Any]: converted specs + converted specs """ if type(spec) is not dict: return spec From b77f08c8b0eaeb5f264132fb44dc030fb208efe0 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Mon, 7 Feb 2022 14:39:29 +0200 Subject: [PATCH 51/56] fix: examples README --- examples/README.md | 49 +++++++++++----------------------------------- 1 file changed, 11 insertions(+), 38 deletions(-) diff --git a/examples/README.md b/examples/README.md index ecb5f3199..67cdf910c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -62,29 +62,16 @@ We also include a number of systems running on discrete action space environment - *Feedforward* - [decentralised][debug_madqn_ff_dec], [decentralised lr scheduling][debug_madqn_ff_dec_lr_schedule] (***using lr schedule***), [decentralised custom lr scheduling][debug_madqn_ff_dec_custom_lr_schedule] (***using custom lr schedule***) and [decentralised custom epsilon decay scheduling][debug_madqn_ff_dec_custom_eps_schedule] (***using configurable epsilon scheduling***). - *Recurrent* - - [decentralised][debug_madqn_rec_dec] and [decentralised with coms][debug_madqn_rec_dec_coms] (***using a system with communication***). - - - **QMIX**: - a QMIX system running on the discrete action space simple_spread MPE environment. - - *Feedforward* [decentralised][debug_qmix_ff_dec]. + - [decentralised][debug_madqn_rec_dec]. - **VDN**: a VDN system running on the discrete action space simple_spread MPE environment. - - *Feedforward* [decentralised][debug_vdn_ff_dec]. - - - **DIAL**: - a DIAL system running on the discrete action space simple_spread MPE environment. - - *Recurrent* [decentralised][debug_dial_rec_dec]. - -### Debugging Environment - Switch -- **DIAL**: - a DIAL system running on the discrete custom SwitchGame environment. - - *Recurrent* [decentralised][debug_switch_dial_rec_dec]. + - *Recurrent* [centralised][debug_vdn_rec_cen]. ### PettingZoo - Multi-Agent Atari - **MADQN**: a MADQN system running on the two-player competitive Atari Pong environment. - - *Feedforward* [decentralised][pz_madqn_pong_ff_dec]. + - *Recurrent* [decentralised][pz_madqn_pong_ff_dec]. ### PettingZoo - Multi-Agent Particle Environment - **MADDPG**: @@ -101,15 +88,15 @@ We also include a number of systems running on discrete action space environment - *Feedforward* - [decentralised][smac_madqn_ff_dec]. - *Recurrent* - - [decentralised with custom agent networks][smac_madqn_rec_dec_custom_agents] (***using custom agent networks***). + - [decentralised][smac_madqn_rec_dec]. - **QMIX**: a QMIX system running on the SMAC environment. - - *Feedforward* [decentralised][smac_qmix_ff_dec]. + - *Recurrent* [centralised][smac_qmix_rec_cen]. - **VDN**: a VDN system running on the SMAC environment. - - *Feedforward* [decentralised][smac_vdn_ff_dec] and [decentralised record agents][smac_vdn_ff_dec_record]. + - *Recurrent* [centralised][smac_vdn_rec_cen]. 
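A note on the target-update arguments documented in the MAD4PG trainer docstrings earlier in this patch series: `target_averaging`, `target_update_rate` and `target_update_period` describe two alternative ways of refreshing target networks. The sketch below only illustrates the rule as those docstrings word it, not Mava's actual trainer code; `online_variables`, `target_variables` and `step` are hypothetical names.

    # Sketch of the two target-update modes the docstrings describe.
    # `online_variables`, `target_variables` and `step` are placeholders.
    def update_targets(
        online_variables, target_variables, step,
        target_averaging, target_update_rate, target_update_period,
    ):
        if target_averaging:
            # Polyak averaging: nudge each target towards its online variable.
            for online, target in zip(online_variables, target_variables):
                target.assign(
                    (1.0 - target_update_rate) * target + target_update_rate * online
                )
        elif step % target_update_period == 0:
            # Hard copy every `target_update_period` trainer steps.
            for online, target in zip(online_variables, target_variables):
                target.assign(online)
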
### OpenSpiel - Tic Tac Toe - **MADQN**: @@ -159,33 +146,19 @@ We also include a number of systems running on discrete action space environment [debug_madqn_ff_dec_custom_lr_schedule]: https://github.com/instadeepai/Mava/blob/develop/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_custom_lr_schedule.py [debug_madqn_ff_dec_custom_eps_schedule]: https://github.com/instadeepai/Mava/blob/develop/examples/debugging/simple_spread/feedforward/decentralised/run_madqn_configurable_epsilon.py [debug_madqn_rec_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/debugging/simple_spread/recurrent/decentralised/run_madqn.py -[debug_madqn_rec_dec_coms]: https://github.com/instadeepai/Mava/blob/develop/examples/debugging/simple_spread/recurrent/decentralised/run_madqn_with_coms.py - -[debug_qmix_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/debugging/simple_spread/feedforward/decentralised/run_qmix.py - -[debug_vdn_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/debugging/simple_spread/feedforward/decentralised/run_vdn.py - -[debug_dial_rec_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/debugging/simple_spread/recurrent/decentralised/run_dial.py +[debug_vdn_rec_cen]: https://github.com/instadeepai/Mava/blob/develop/examples/debugging/simple_spread/recurrent/centralised/run_vdn.py -[debug_switch_dial_rec_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/debugging/switch/recurrent/decentralised/run_dial.py - - -[pz_madqn_pong_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/petting_zoo/atari/pong/feedforward/decentralised/run_madqn.py +[pz_madqn_pong_rec_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/petting_zoo/atari/pong/recurrent/centralised/run_madqn.py [pz_maddpg_mpe_ssl_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/petting_zoo/mpe/simple_speaker_listener/feedforward/decentralised/run_maddpg.py [pz_maddpg_mpe_ss_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/petting_zoo/mpe/simple_spread/feedforward/decentralised/run_maddpg.py +[smac_madqn_rec_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/smac/recurrent/decentralised/run_madqn.py -[smac_madqn_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/smac/feedforward/decentralised/run_madqn.py - -[smac_madqn_rec_dec_custom_agents]: https://github.com/instadeepai/Mava/blob/develop/examples/smac/recurrent/decentralised/run_madqn.py - -[smac_qmix_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/smac/feedforward/decentralised/run_qmix.py - -[smac_vdn_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/smac/feedforward/decentralised/run_vdn.py +[smac_qmix_rec_cen]: https://github.com/instadeepai/Mava/blob/develop/examples/smac/recurrent/centralised/run_qmix.py -[smac_vdn_ff_dec_record]: https://github.com/instadeepai/Mava/blob/develop/examples/smac/feedforward/decentralised/run_vdn_record.py +[smac_vdn_rec_cen]: https://github.com/instadeepai/Mava/blob/develop/examples/smac/recurrent/centralised/run_vdn.py [openspiel_madqn_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/openspiel/tic_tac_toe/feedforward/decentralised/run_madqn.py From 9db9baf964655a14d978616ef02d03ef756d7336 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Mon, 7 Feb 2022 14:48:06 +0200 Subject: [PATCH 52/56] fix: evaluator interval 2 -> 2000 on SMAC --- 
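Note on intent (a reading of this change, not a statement of Mava's internals): `evaluator_interval` gates how often the evaluator runs relative to executor progress, keyed on an executor count such as `executor_episodes`, so bumping the value from 2 to 2000 makes evaluation far less frequent on SMAC. A rough sketch of that kind of gating, where `counts`, `last_eval_counts` and `should_evaluate` are hypothetical names:

    # Hypothetical sketch of interval gating keyed on executor counts.
    def should_evaluate(counts, last_eval_counts, evaluator_interval):
        # e.g. evaluator_interval = {"executor_episodes": 2000}
        for key, interval in evaluator_interval.items():
            if counts.get(key, 0) - last_eval_counts.get(key, 0) >= interval:
                return True
        return False
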
examples/smac/recurrent/centralised/run_qmix.py | 2 +- examples/smac/recurrent/centralised/run_vdn.py | 2 +- examples/smac/recurrent/decentralised/run_madqn.py | 2 +- .../smac/recurrent/decentralised/run_madqn_scale_trainers.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/smac/recurrent/centralised/run_qmix.py b/examples/smac/recurrent/centralised/run_qmix.py index 124184023..b114b7e31 100644 --- a/examples/smac/recurrent/centralised/run_qmix.py +++ b/examples/smac/recurrent/centralised/run_qmix.py @@ -95,7 +95,7 @@ def main(_: Any) -> None: samples_per_insert=4, sequence_length=20, period=10, - evaluator_interval={"executor_episodes": 2}, + evaluator_interval={"executor_episodes": 2000}, ).build() # Only the trainer should use the GPU (if available) diff --git a/examples/smac/recurrent/centralised/run_vdn.py b/examples/smac/recurrent/centralised/run_vdn.py index f79a5700d..b53683cde 100644 --- a/examples/smac/recurrent/centralised/run_vdn.py +++ b/examples/smac/recurrent/centralised/run_vdn.py @@ -92,7 +92,7 @@ def main(_: Any) -> None: samples_per_insert=4, sequence_length=20, period=10, - evaluator_interval={"executor_episodes": 2}, + evaluator_interval={"executor_episodes": 2000}, ).build() # Only the trainer should use the GPU (if available) diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index d57d21610..abedcfb03 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -94,7 +94,7 @@ def main(_: Any) -> None: samples_per_insert=4, sequence_length=20, period=10, - evaluator_interval={"executor_episodes": 2}, + evaluator_interval={"executor_episodes": 2000}, trainer_fn=madqn.MADQNRecurrentTrainer, executor_fn=madqn.MADQNRecurrentExecutor, ).build() diff --git a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py index 66e86caf4..ddc23626e 100644 --- a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py +++ b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py @@ -92,7 +92,7 @@ def main(_: Any) -> None: min_replay_size=32, batch_size=32, samples_per_insert=4, - evaluator_interval={"executor_episodes": 2}, + evaluator_interval={"executor_episodes": 2000}, checkpoint_subpath=checkpoint_dir, ).build() From e3991c76e0f3e84ff990ae46c68e8eff0408706c Mon Sep 17 00:00:00 2001 From: RuanJohn Date: Tue, 8 Feb 2022 09:45:03 +0200 Subject: [PATCH 53/56] Fixed import issues in env_preprocess_wrappers --- mava/wrappers/env_preprocess_wrappers.py | 298 +++++++++++++---------- 1 file changed, 169 insertions(+), 129 deletions(-) diff --git a/mava/wrappers/env_preprocess_wrappers.py b/mava/wrappers/env_preprocess_wrappers.py index 5c4cb45af..474aaacce 100644 --- a/mava/wrappers/env_preprocess_wrappers.py +++ b/mava/wrappers/env_preprocess_wrappers.py @@ -18,36 +18,59 @@ import dm_env import gym import numpy as np -from pettingzoo.utils import BaseParallelWraper -from supersuit.utils.base_aec_wrapper import BaseWrapper from mava.types import OLT, Action, Observation, Reward from mava.utils.wrapper_utils import RunningMeanStd from mava.wrappers.env_wrappers import ParallelEnvWrapper, SequentialEnvWrapper -# Prevent circular import issue. 
-if TYPE_CHECKING: - from mava.wrappers.pettingzoo import ( - PettingZooAECEnvWrapper, - PettingZooParallelEnvWrapper, - ) +try: + import supersuit + + _has_supersuit = True +except ModuleNotFoundError: + _has_supersuit = False + pass + + +try: + import pettingzoo # noqa: F401 + + _has_petting_zoo = True +except ModuleNotFoundError: + _has_petting_zoo = False + pass + +if _has_petting_zoo: + from pettingzoo.utils import BaseParallelWraper + + # Prevent circular import issue. + if TYPE_CHECKING: + from mava.wrappers.pettingzoo import ( + PettingZooAECEnvWrapper, + PettingZooParallelEnvWrapper, + ) PettingZooEnv = Union["PettingZooAECEnvWrapper", "PettingZooParallelEnvWrapper"] +if _has_supersuit: + from supersuit.utils.base_aec_wrapper import BaseWrapper + + # TODO(Kale-ab): Make wrapper more general # Should Works across any SequentialEnvWrapper or ParallelEnvWrapper. """ GYM Preprocess Wrappers. Other gym preprocess wrappers: - https://github.com/PettingZoo-Team/SuperSuit/blob/1f02289e8f51082aa50a413b34700b67042410c6/supersuit/gym_wrappers.py - https://github.com/openai/gym/tree/master/gym/wrappers + https://github.com/PettingZoo-Team/SuperSuit/blob/1f02289e8f51082aa50a413b34700b67042410c6/supersuit/gym_wrappers.py # noqa: E501 + https://github.com/openai/gym/tree/master/gym/wrappers # noqa: E501 """ class StandardizeObservationGym(gym.ObservationWrapper): """ - Standardize observations + Standardize observations. + Ensures mean of 0 and standard deviation of 1 (unit variance) for obs. From https://github.com/ikostrikov/pytorch-a3c/blob/e898f7514a03de73a2bf01e7b0f17a6f93963389/envs.py # noqa: E501 """ @@ -144,55 +167,65 @@ def _get_updated_observation( return (observation - unbiased_mean) / (unbiased_std + 1e-8) -class StandardizeObservationSequential(BaseWrapper, StandardizeObservation): - """Standardize Obs in Sequential Env""" +if _has_supersuit: - def __init__( - self, env: PettingZooEnv, load_params: Dict = None, alpha: float = 0.999 - ) -> None: - BaseWrapper.__init__(self, env) - StandardizeObservation.__init__(self, env, load_params, alpha) - - def _modify_observation(self, agent: str, observation: Observation) -> Observation: - self._internal_state[agent]["num_steps"] = ( - int(self._internal_state[agent]["num_steps"]) + 1 - ) - return self._get_updated_observation(agent, observation) + class StandardizeObservationSequential(BaseWrapper, StandardizeObservation): + """Standardize Obs in Sequential Env""" - def _modify_action(self, agent: str, action: Action) -> Action: - return action + def __init__( + self, env: PettingZooEnv, load_params: Dict = None, alpha: float = 0.999 + ) -> None: + BaseWrapper.__init__(self, env) + StandardizeObservation.__init__(self, env, load_params, alpha) + def _modify_observation( + self, agent: str, observation: Observation + ) -> Observation: + self._internal_state[agent]["num_steps"] = ( + int(self._internal_state[agent]["num_steps"]) + 1 + ) + return self._get_updated_observation(agent, observation) -class StandardizeObservationParallel(BaseParallelWraper, StandardizeObservation): - """Standardize Obs in Parallel Env""" + def _modify_action(self, agent: str, action: Action) -> Action: + return action - def __init__( - self, env: PettingZooEnv, load_params: Dict = None, alpha: float = 0.999 - ) -> None: - BaseParallelWraper.__init__(self, env) - StandardizeObservation.__init__(self, env, load_params, alpha) - - def _modify_observation(self, agent: str, observation: Observation) -> Observation: - 
self._internal_state[agent]["num_steps"] = ( - int(self._internal_state[agent]["num_steps"]) + 1 - ) - return self._get_updated_observation(agent, observation) - def _modify_action(self, action: Action) -> Action: - return action +if _has_petting_zoo: + + class StandardizeObservationParallel(BaseParallelWraper, StandardizeObservation): + """Standardize Obs in Parallel Env""" + + def __init__( + self, env: PettingZooEnv, load_params: Dict = None, alpha: float = 0.999 + ) -> None: + BaseParallelWraper.__init__(self, env) + StandardizeObservation.__init__(self, env, load_params, alpha) + + def _modify_observation( + self, agent: str, observation: Observation + ) -> Observation: + self._internal_state[agent]["num_steps"] = ( + int(self._internal_state[agent]["num_steps"]) + 1 + ) + return self._get_updated_observation(agent, observation) + + def _modify_action(self, action: Action) -> Action: + return action - def reset(self) -> Dict: - obss = super().reset() - return { - agent: self._modify_observation(agent, obs) for agent, obs in obss.items() - } + def reset(self) -> Dict: + obss = super().reset() + return { + agent: self._modify_observation(agent, obs) + for agent, obs in obss.items() + } - def step(self, actions: Action) -> Any: - obss, rew, done, info = super().step(actions) - obss = { - agent: self._modify_observation(agent, obs) for agent, obs in obss.items() - } - return obss, rew, done, info + def step(self, actions: Action) -> Any: + obss, rew, done, info = super().step(actions) + obss = { + agent: self._modify_observation(agent, obs) + for agent, obs in obss.items() + } + return obss, rew, done, info class StandardizeReward: @@ -278,89 +311,96 @@ def _get_updated_reward( return reward -class StandardizeRewardSequential(BaseWrapper, StandardizeReward): - def __init__( - self, - env: SequentialEnvWrapper, - load_params: Dict = None, - lower_bound: float = -10.0, - upper_bound: float = 10.0, - alpha: float = 0.999, - ) -> None: - BaseWrapper.__init__(self, env) - StandardizeReward.__init__( - self, env, load_params, lower_bound, upper_bound, alpha - ) +if _has_supersuit: + + class StandardizeRewardSequential(BaseWrapper, StandardizeReward): + def __init__( + self, + env: SequentialEnvWrapper, + load_params: Dict = None, + lower_bound: float = -10.0, + upper_bound: float = 10.0, + alpha: float = 0.999, + ) -> None: + BaseWrapper.__init__(self, env) + StandardizeReward.__init__( + self, env, load_params, lower_bound, upper_bound, alpha + ) - def reset(self) -> None: - # Reset returns, but not running scores. 
- for stats in self._internal_state.values(): - stats["return"] = 0 - - super().reset() - self.rewards = { - agent: self._get_updated_reward(agent, reward) - for agent, reward in self.rewards.items() # type: ignore - } - self.__cumulative_rewards = {a: 0 for a in self.agents} - self._accumulate_rewards() - - def step(self, action: np.ndarray) -> None: - agent = self.env.agent_selection # type: ignore - super().step(action) - self.rewards = { - agent: self._get_updated_reward(agent, reward) - for agent, reward in self.rewards.items() - } - self.__cumulative_rewards[agent] = 0 - self._cumulative_rewards = self.__cumulative_rewards - self._accumulate_rewards() - - def _modify_observation(self, agent: str, observation: Observation) -> Observation: - return observation - - def _modify_action(self, agent: str, action: Action) -> Action: - return action - - -class StandardizeRewardParallel( - BaseParallelWraper, - StandardizeReward, -): - def __init__( - self, - env: ParallelEnvWrapper, - load_params: Dict = None, - lower_bound: float = -10.0, - upper_bound: float = 10.0, - alpha: float = 0.999, - ) -> None: - BaseParallelWraper.__init__(self, env) - StandardizeReward.__init__( - self, env, load_params, lower_bound, upper_bound, alpha - ) + def reset(self) -> None: + # Reset returns, but not running scores. + for stats in self._internal_state.values(): + stats["return"] = 0 + + super().reset() + self.rewards = { + agent: self._get_updated_reward(agent, reward) + for agent, reward in self.rewards.items() # type: ignore + } + self.__cumulative_rewards = {a: 0 for a in self.agents} + self._accumulate_rewards() + + def step(self, action: np.ndarray) -> None: + agent = self.env.agent_selection # type: ignore + super().step(action) + self.rewards = { + agent: self._get_updated_reward(agent, reward) + for agent, reward in self.rewards.items() + } + self.__cumulative_rewards[agent] = 0 + self._cumulative_rewards = self.__cumulative_rewards + self._accumulate_rewards() + + def _modify_observation( + self, agent: str, observation: Observation + ) -> Observation: + return observation + + def _modify_action(self, agent: str, action: Action) -> Action: + return action - def reset(self) -> Observation: - # Reset returns, but not running scores. - for stats in self._internal_state.values(): - stats["return"] = 0 - obs = self.env.reset() # type: ignore - self.agents = self.env.agents # type: ignore - return obs +if _has_petting_zoo: - def step(self, actions: Dict) -> Any: - obs, rew, done, info = super().step(actions) - rew = { - agent: self._get_updated_reward(agent, rew) for agent, rew in rew.items() - } - return obs, rew, done, info + class StandardizeRewardParallel( + BaseParallelWraper, + StandardizeReward, + ): + def __init__( + self, + env: ParallelEnvWrapper, + load_params: Dict = None, + lower_bound: float = -10.0, + upper_bound: float = 10.0, + alpha: float = 0.999, + ) -> None: + BaseParallelWraper.__init__(self, env) + StandardizeReward.__init__( + self, env, load_params, lower_bound, upper_bound, alpha + ) + + def reset(self) -> Observation: + # Reset returns, but not running scores. 
+ for stats in self._internal_state.values(): + stats["return"] = 0 + + obs = self.env.reset() # type: ignore + self.agents = self.env.agents # type: ignore + return obs + + def step(self, actions: Dict) -> Any: + obs, rew, done, info = super().step(actions) + rew = { + agent: self._get_updated_reward(agent, rew) + for agent, rew in rew.items() + } + return obs, rew, done, info - def _modify_observation(self, observation: Observation) -> Observation: - return observation + def _modify_observation(self, observation: Observation) -> Observation: + return observation - def _modify_action(self, action: Action) -> Action: - return action + def _modify_action(self, action: Action) -> Action: + return action class ConcatAgentIdToObservation: From c9d8a0ea86cbb2ffe04fcb8ea8deb8018f8261e1 Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 9 Feb 2022 14:17:25 +0200 Subject: [PATCH 54/56] fix: EpsilonTimestepSchedulers in scaling MADQN --- .../smac/recurrent/centralised/run_qmix.py | 6 +- .../smac/recurrent/centralised/run_vdn.py | 6 +- .../smac/recurrent/decentralised/run_madqn.py | 6 +- .../decentralised/run_madqn_scale_trainers.py | 8 +- mava/environment_loop.py | 46 ++- mava/systems/tf/executors.py | 312 ++---------------- 6 files changed, 78 insertions(+), 306 deletions(-) diff --git a/examples/smac/recurrent/centralised/run_qmix.py b/examples/smac/recurrent/centralised/run_qmix.py index b114b7e31..96262bcf9 100644 --- a/examples/smac/recurrent/centralised/run_qmix.py +++ b/examples/smac/recurrent/centralised/run_qmix.py @@ -24,7 +24,7 @@ from absl import app, flags from mava.components.tf.modules.exploration.exploration_scheduling import ( - LinearExplorationScheduler, + LinearExplorationTimestepScheduler, ) from mava.systems.tf import value_decomposition from mava.utils import lp_utils @@ -79,8 +79,8 @@ def main(_: Any) -> None: mixer="qmix", logger_factory=logger_factory, num_executors=1, - exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=4e-6 + exploration_scheduler_fn=LinearExplorationTimestepScheduler( + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=50000 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 diff --git a/examples/smac/recurrent/centralised/run_vdn.py b/examples/smac/recurrent/centralised/run_vdn.py index b53683cde..4e284f651 100644 --- a/examples/smac/recurrent/centralised/run_vdn.py +++ b/examples/smac/recurrent/centralised/run_vdn.py @@ -23,7 +23,7 @@ from absl import app, flags from mava.components.tf.modules.exploration.exploration_scheduling import ( - LinearExplorationScheduler, + LinearExplorationTimestepScheduler, ) from mava.systems.tf import value_decomposition from mava.utils import lp_utils @@ -78,8 +78,8 @@ def main(_: Any) -> None: mixer="vdn", logger_factory=logger_factory, num_executors=1, - exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=5e-6 + exploration_scheduler_fn=LinearExplorationTimestepScheduler( + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=50000 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index abedcfb03..840bd1025 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -24,7 +24,7 @@ from absl import app, flags from 
mava.components.tf.modules.exploration.exploration_scheduling import ( - LinearExplorationScheduler, + LinearExplorationTimestepScheduler, ) from mava.systems.tf import madqn from mava.utils import lp_utils @@ -78,8 +78,8 @@ def main(_: Any) -> None: network_factory=network_factory, logger_factory=logger_factory, num_executors=1, - exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, epsilon_min=0.05, epsilon_decay=5e-6 + exploration_scheduler_fn=LinearExplorationTimestepScheduler( + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=50000 ), optimizer=snt.optimizers.RMSProp( learning_rate=0.0005, epsilon=0.00001, decay=0.99 diff --git a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py index ddc23626e..fb8df27f1 100644 --- a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py +++ b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py @@ -21,7 +21,7 @@ import launchpad as lp from absl import app, flags -from mava.components.tf.modules.exploration import LinearExplorationScheduler +from mava.components.tf.modules.exploration import LinearExplorationTimestepScheduler from mava.systems.tf import madqn from mava.utils import enums, lp_utils from mava.utils.enums import ArchitectureType @@ -78,10 +78,8 @@ def main(_: Any) -> None: network_factory=network_factory, logger_factory=logger_factory, num_executors=1, - exploration_scheduler_fn=LinearExplorationScheduler( - epsilon_start=1.0, - epsilon_min=0.05, - epsilon_decay=5e-6, + exploration_scheduler_fn=LinearExplorationTimestepScheduler( + epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=50000 ), shared_weights=False, trainer_networks=enums.Trainer.one_trainer_per_network, diff --git a/mava/environment_loop.py b/mava/environment_loop.py index 875890d50..dac852fb0 100644 --- a/mava/environment_loop.py +++ b/mava/environment_loop.py @@ -35,6 +35,7 @@ class SequentialEnvironmentLoop(acme.core.Worker): """A Sequential MARL environment loop. + This takes `Environment` and `Executor` instances and coordinates their interaction. Executors are updated if `should_update=True`. This can be used as: loop = EnvironmentLoop(environment, executor) @@ -58,6 +59,16 @@ def __init__( should_update: bool = True, label: str = "sequential_environment_loop", ): + """Sequential environment loop + + Args: + environment: an environment + executor: a Mava executor + counter: an optional counter. Defaults to None. + logger: an optional counter. Defaults to None. + should_update: should update. Defaults to True. + label: optional label. Defaults to "sequential_environment_loop". + """ # Internalize agent and environment. self._environment = environment self._executor = executor @@ -196,9 +207,11 @@ def _collect_last_timesteps(self, timestep: dm_env.TimeStep) -> None: def run_episode(self) -> loggers.LoggingData: """Run one episode. + Each episode is a loop which interacts first with the environment to get an observation and then give that observation to the agent in order to retrieve an action. + Returns: An instance of `loggers.LoggingData`. """ @@ -264,6 +277,7 @@ def run( self, num_episodes: Optional[int] = None, num_steps: Optional[int] = None ) -> None: """Perform the run loop. + Run the environment loop either for `num_episodes` episodes or for at least `num_steps` steps (the last episode is always run until completion, so the total number of steps may be slightly more than `num_steps`). 
@@ -271,9 +285,11 @@ def run(
         Upon termination of an episode a new episode will be started. If the
         number of episodes and the number of steps are not given then this will
         interact with the environment infinitely.
+
         Args:
             num_episodes: number of episodes to run the loop for.
             num_steps: minimal number of steps to run the loop for.
+
         Raises:
             ValueError: If both 'num_episodes' and 'num_steps' are not None.
         """
@@ -297,6 +313,7 @@ def should_terminate(episode_count: int, step_count: int) -> bool:
 
 class ParallelEnvironmentLoop(acme.core.Worker):
     """A parallel MARL environment loop.
+
     This takes `Environment` and `Executor` instances and coordinates their
     interaction. Executors are updated if `should_update=True`. This can be used as:
         loop = EnvironmentLoop(environment, executor)
@@ -320,6 +337,16 @@ def __init__(
         should_update: bool = True,
         label: str = "parallel_environment_loop",
     ):
+        """Parallel environment loop init
+
+        Args:
+            environment: an environment
+            executor: a Mava executor
+            counter: an optional counter. Defaults to None.
+            logger: an optional logger. Defaults to None.
+            should_update: whether the executor's variables should be updated.
+                Defaults to True.
+            label: optional label. Defaults to "parallel_environment_loop".
+        """
         # Internalize agent and environment.
         self._environment = environment
         self._executor = executor
@@ -349,6 +376,7 @@ def _compute_episode_statistics(
         pass
 
     def get_counts(self) -> counting.Counter:
+        """Get latest counts"""
         if hasattr(self._executor, "_counts"):
             counts = self._executor._counts
         else:
@@ -356,6 +384,7 @@ def get_counts(self) -> counting.Counter:
         return counts
 
     def record_counts(self, episode_steps: int) -> counting.Counter:
+        """Record latest counts"""
         # Record counts.
         if hasattr(self._executor, "_counts"):
             loop_type = "evaluator" if self._executor._evaluator else "executor"
@@ -380,9 +409,11 @@ def record_counts(self, episode_steps: int) -> counting.Counter:
 
     def run_episode(self) -> loggers.LoggingData:
         """Run one episode.
+
         Each episode is a loop which interacts first with the environment to get a
         dictionary of observations and then give those observations to the executor
         in order to retrieve an action for each agent in the system.
+
         Returns:
             An instance of `loggers.LoggingData`.
         """
@@ -443,9 +474,15 @@ def run_episode(self) -> loggers.LoggingData:
             episode_steps += 1
 
             if hasattr(self._executor, "after_action_selection"):
-                total_steps_before_current_episode = self._counter.get_counts().get(
-                    "executor_steps", 0
-                )
+                if hasattr(self._executor, "_counts"):
+                    loop_type = "evaluator" if self._executor._evaluator else "executor"
+                    total_steps_before_current_episode = self._executor._counts[
+                        f"{loop_type}_steps"
+                    ].numpy()
+                else:
+                    total_steps_before_current_episode = self._counter.get_counts().get(
+                        "executor_steps", 0
+                    )
                 current_step_t = total_steps_before_current_episode + episode_steps
                 self._executor.after_action_selection(current_step_t)
 
@@ -479,6 +516,7 @@ def run(
         self, num_episodes: Optional[int] = None, num_steps: Optional[int] = None
     ) -> None:
         """Perform the run loop.
+
         Run the environment loop either for `num_episodes` episodes or for at
         least `num_steps` steps (the last episode is always run until completion,
         so the total number of steps may be slightly more than `num_steps`).
@@ -486,9 +524,11 @@ def run(
         Upon termination of an episode a new episode will be started. If the
         number of episodes and the number of steps are not given then this will
         interact with the environment infinitely.
+
         Args:
             num_episodes: number of episodes to run the loop for.
num_steps: minimal number of steps to run the loop for. + Raises: ValueError: If both 'num_episodes' and 'num_steps' are not None. """ diff --git a/mava/systems/tf/executors.py b/mava/systems/tf/executors.py index 1608541f9..c0ffe5127 100644 --- a/mava/systems/tf/executors.py +++ b/mava/systems/tf/executors.py @@ -26,13 +26,13 @@ from acme.tf import variable_utils as tf2_variable_utils from mava import adders, core -from mava.components.tf.modules.communication import BaseCommunicationModule tfd = tfp.distributions class FeedForwardExecutor(core.Executor): """A generic feed-forward executor. + An executor based on a feed-forward policy for each agent in the system. """ @@ -46,14 +46,13 @@ def __init__( """Initialise the system executor Args: - policy_networks (Dict[str, snt.Module]): policy networks for each agent in + policy_networks: policy networks for each agent in the system. - agent_net_keys: (dict, optional): specifies what network each agent uses. + agent_net_keys: specifies what network each agent uses. Defaults to {}. - adder (Optional[adders.ReverbParallelAdder], optional): adder which sends + adder: adder which sends data to a replay buffer. Defaults to None. - variable_client (Optional[tf2_variable_utils.VariableClient], optional): - client to copy weights from the trainer. Defaults to None. + variable_client: client to copy weights from the trainer. Defaults to None. """ # Store these for later use. @@ -94,7 +93,7 @@ def _policy( def select_action( self, agent: str, observation: types.NestedArray ) -> Union[types.NestedArray, Tuple[types.NestedArray, types.NestedArray]]: - """select an action for a single agent in the system + """Select an action for a single agent in the system Args: agent (str): agent id. @@ -120,7 +119,7 @@ def observe_first( timestep: dm_env.TimeStep, extras: Dict[str, types.NestedArray] = {}, ) -> None: - """record first observed timestep from the environment + """Record first observed timestep from the environment Args: timestep (dm_env.TimeStep): data emitted by an environment at first step of @@ -138,7 +137,7 @@ def observe( next_timestep: dm_env.TimeStep, next_extras: Dict[str, types.NestedArray] = {}, ) -> None: - """record observed timestep from the environment + """Record observed timestep from the environment Args: actions (Dict[str, types.NestedArray]): system agents' actions. @@ -157,7 +156,7 @@ def select_actions( Dict[str, types.NestedArray], Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]], ]: - """select the actions for all agents in the system + """Select the actions for all agents in the system Args: observations (Dict[str, types.NestedArray]): agent observations from the @@ -180,7 +179,7 @@ def select_actions( return actions def update(self, wait: bool = False) -> None: - """update executor variables + """Update executor variables Args: wait (bool, optional): whether to stall the executor's request for new @@ -193,6 +192,7 @@ def update(self, wait: bool = False) -> None: class RecurrentExecutor(core.Executor): """A generic recurrent Executor. + An executor based on a recurrent policy for each agent in the system. """ @@ -207,15 +207,15 @@ def __init__( """Initialise the system executor Args: - policy_networks (Dict[str, snt.RNNCore]): policy networks for each agent in + policy_networks: policy networks for each agent in the system. - agent_net_keys: (dict, optional): specifies what network each agent uses. + agent_net_keys: specifies what network each agent uses. Defaults to {}. 
- adder (Optional[adders.ReverbParallelAdder], optional): adder which sends + adder: adder which sends data to a replay buffer. Defaults to None. - variable_client (Optional[tf2_variable_utils.VariableClient], optional): - client to copy weights from the trainer. Defaults to None. - store_recurrent_state (bool, optional): boolean to store the recurrent + variable_client: client to copy weights from the trainer. + Defaults to None. + store_recurrent_state: boolean to store the recurrent network hidden state. Defaults to True. """ @@ -266,7 +266,7 @@ def _policy( return action, new_state def _update_state(self, agent: str, new_state: types.NestedArray) -> None: - """update recurrent hidden state + """Update recurrent hidden state Args: agent (str): agent id. @@ -279,7 +279,7 @@ def _update_state(self, agent: str, new_state: types.NestedArray) -> None: def select_action( self, agent: str, observation: types.NestedArray ) -> types.NestedArray: - """select an action for a single agent in the system + """Select an action for a single agent in the system Args: agent (str): agent id @@ -315,7 +315,7 @@ def observe_first( timestep: dm_env.TimeStep, extras: Dict[str, types.NestedArray] = {}, ) -> None: - """record first observed timestep from the environment + """Record first observed timestep from the environment Args: timestep (dm_env.TimeStep): data emitted by an environment at first step of @@ -348,7 +348,7 @@ def observe( next_timestep: dm_env.TimeStep, next_extras: Dict[str, types.NestedArray] = {}, ) -> None: - """record observed timestep from the environment + """Record observed timestep from the environment Args: actions (Dict[str, types.NestedArray]): system agents' actions. @@ -384,7 +384,7 @@ def select_actions( Dict[str, types.NestedArray], Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]], ]: - """select the actions for all agents in the system + """Select the actions for all agents in the system Args: observations (Dict[str, types.NestedArray]): agent observations from the @@ -402,7 +402,7 @@ def select_actions( return actions def update(self, wait: bool = False) -> None: - """update executor variables + """Update executor variables Args: wait (bool, optional): whether to stall the executor's request for new @@ -411,269 +411,3 @@ def update(self, wait: bool = False) -> None: if self._variable_client: self._variable_client.update(wait) - - -class RecurrentCommExecutor(RecurrentExecutor): - """Generic recurrent executor with communicate.""" - - def __init__( - self, - policy_networks: Dict[str, snt.RNNCore], - communication_module: BaseCommunicationModule, - agent_net_keys: Dict[str, str], - adder: Optional[adders.ReverbParallelAdder] = None, - variable_client: Optional[tf2_variable_utils.VariableClient] = None, - store_recurrent_state: bool = True, - ): - """Initialise the system executor - - Args: - policy_networks (Dict[str, snt.RNNCore]): policy networks for each agent in - the system. - communication_module (BaseCommunicationModule): module for enabling - communication protocols between agents. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - adder (Optional[adders.ReverbParallelAdder], optional): adder which sends - data to a replay buffer. Defaults to None. - variable_client (Optional[tf2_variable_utils.VariableClient], optional): - client to copy weights from the trainer. Defaults to None. - store_recurrent_state (bool, optional): boolean to store the recurrent - network hidden state. Defaults to True. 
- """ - - # Store these for later use. - - self._adder = adder - self._variable_client = variable_client - self._policy_networks = policy_networks - self._states: Dict[str, Any] = {} - self._messages: Dict[str, Any] = {} - self._store_recurrent_state = store_recurrent_state - self._communication_module = communication_module - self._agent_net_keys = agent_net_keys - - def _sample_action( - self, action_policy: types.NestedTensor, agent: str - ) -> types.NestedTensor: - """sample an action from an agent's policy distribution - - Args: - action_policy (types.NestedTensor): agent policy. - agent (str): agent id. - - Returns: - types.NestedTensor: agent action - """ - - # Sample from the policy if it is stochastic. - return ( - action_policy.sample() - if isinstance(action_policy, tfd.Distribution) - else action_policy - ) - - def _process_message( - self, observation: types.NestedTensor, message_policy: types.NestedTensor - ) -> types.NestedTensor: - """process agent messages - - Args: - observation (types.NestedTensor): observation tensor received from the - environment. - message_policy (types.NestedTensor): message policy. - - Returns: - types.NestedTensor: processed message. - """ - - # Only one agent can message at each timestep - return message_policy - - @tf.function - def _policy( - self, - agent: str, - observation: types.NestedTensor, - state: types.NestedTensor, - message: types.NestedTensor, - ) -> Tuple[types.NestedTensor, types.NestedTensor, types.NestedTensor]: - """Agent specific policy function - - Args: - agent (str): agent id - observation (types.NestedTensor): observation tensor received from the - environment. - state (types.NestedTensor): recurrent network state. - message (types.NestedTensor): agent message. - - Returns: - Tuple[types.NestedTensor, types.NestedTensor, types.NestedTensor]: action, - message and new recurrent hidden state - """ - - # Add a dummy batch dimension and as a side effect convert numpy to TF. - batched_observation = tf2_utils.add_batch_dim(observation) - - # index network either on agent type or on agent id - agent_key = self._agent_net_keys[agent] - - # Compute the policy, conditioned on the observation. - (action_policy, message_policy), new_state = self._policy_networks[agent_key]( - batched_observation, state, message - ) - - action = self._sample_action(action_policy, agent) - - message = self._process_message(observation, message_policy) - - return action, message, new_state - - def observe_first( - self, - timestep: dm_env.TimeStep, - extras: Optional[Dict[str, types.NestedArray]] = {}, - ) -> None: - """record first observed timestep from the environment - - Args: - timestep (dm_env.TimeStep): data emitted by an environment at first step of - interaction. - extras (Dict[str, types.NestedArray], optional): possible extra information - to record during the first step. Defaults to {}. - - Raises: - NotImplementedError: check for extras that are 'None'. - """ - - # Re-initialize the RNN state. 
- for agent, _ in timestep.observation.items(): - # index network either on agent type or on agent id - agent_key = self._agent_net_keys[agent] - self._states[agent] = self._policy_networks[agent_key].initial_state(1) - self._messages[agent] = self._policy_networks[agent_key].initial_message(1) - - if self._adder is not None: - numpy_states = { - agent: tf2_utils.to_numpy_squeeze(_state) - for agent, _state in self._states.items() - } - numpy_messages = { - agent: tf2_utils.to_numpy_squeeze(_message) - for agent, _message in self._messages.items() - } - if extras is not None: - extras.update( - { - "core_states": numpy_states, - "core_messages": numpy_messages, - } - ) - self._adder.add_first(timestep, extras) - else: - raise NotImplementedError("Why is extras None?") - - def observe( - self, - actions: Dict[str, types.NestedArray], - next_timestep: dm_env.TimeStep, - next_extras: Optional[Dict[str, types.NestedArray]] = {}, - ) -> None: - """record observed timestep from the environment - - Args: - actions (Dict[str, types.NestedArray]): system agents' actions. - next_timestep (dm_env.TimeStep): data emitted by an environment during - interaction. - next_extras (Dict[str, types.NestedArray], optional): possible extra - information to record during the transition. Defaults to {}. - """ - - if not self._adder: - return - - if not self._store_recurrent_state: - if next_extras: - self._adder.add(actions, next_timestep, next_extras) - else: - self._adder.add(actions, next_timestep) - return - - numpy_states = { - agent: tf2_utils.to_numpy_squeeze(_state) - for agent, _state in self._states.items() - } - numpy_messages = { - agent: tf2_utils.to_numpy_squeeze(_message) - for agent, _message in self._messages.items() - } - - if next_extras: - next_extras.update( - { - "core_states": numpy_states, - "core_messages": numpy_messages, - } - ) - self._adder.add(actions, next_timestep, next_extras) - else: - next_extras = { - "core_states": numpy_states, - "core_messages": numpy_messages, - } - self._adder.add(actions, next_timestep, next_extras) - - def select_action( - self, - agent: str, - observation: types.NestedArray, - ) -> Tuple[types.NestedArray, types.NestedArray]: - """select an action for a single agent in the system - - Args: - agent (str): agent id - observation (types.NestedArray): observation tensor received from the - environment. - - Returns: - Tuple[types.NestedArray, types.NestedArray]: agent action. - """ - - # Initialize the RNN state if necessary. - if self._states[agent] is None: - self._states[agent] = self._networks[agent].initial_state(1) - - message_inputs = self._communication_module.process_messages(self._messages) - # Step the recurrent policy forward given the current observation and state. - policy_output, message, new_state = self._policy( - agent, - observation.observation, - self._states[agent], - message_inputs[agent], - ) - - # Bookkeeping of recurrent states for the observe method. - self._states[agent] = new_state - self._messages[agent] = message - - # Return a numpy array with squeezed out batch dimension. - return tf2_utils.to_numpy_squeeze(policy_output) - - def select_actions( - self, observations: Dict[str, types.NestedArray] - ) -> Dict[str, types.NestedArray]: - """select the actions for all agents in the system - - Args: - observations (Dict[str, types.NestedArray]): agent observations from the - environment. - - Returns: - Dict[str, types.NestedArray]: actions for all agents in the system. 
- """ - - actions = {} - for agent, observation in observations.items(): - actions[agent] = self._select_action(agent, observation) - # Return a numpy array with squeezed out batch dimension. - return actions From ce614b8e29fcd4adb8fe86d6f6294dda6a3163da Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Wed, 9 Feb 2022 14:25:25 +0200 Subject: [PATCH 55/56] fix: lower evaluator interval --- examples/smac/recurrent/centralised/run_qmix.py | 2 +- examples/smac/recurrent/centralised/run_vdn.py | 2 +- examples/smac/recurrent/decentralised/run_madqn.py | 2 +- .../smac/recurrent/decentralised/run_madqn_scale_trainers.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/smac/recurrent/centralised/run_qmix.py b/examples/smac/recurrent/centralised/run_qmix.py index 96262bcf9..414dd6e3b 100644 --- a/examples/smac/recurrent/centralised/run_qmix.py +++ b/examples/smac/recurrent/centralised/run_qmix.py @@ -95,7 +95,7 @@ def main(_: Any) -> None: samples_per_insert=4, sequence_length=20, period=10, - evaluator_interval={"executor_episodes": 2000}, + evaluator_interval={"executor_episodes": 2}, ).build() # Only the trainer should use the GPU (if available) diff --git a/examples/smac/recurrent/centralised/run_vdn.py b/examples/smac/recurrent/centralised/run_vdn.py index 4e284f651..e0850bcfd 100644 --- a/examples/smac/recurrent/centralised/run_vdn.py +++ b/examples/smac/recurrent/centralised/run_vdn.py @@ -92,7 +92,7 @@ def main(_: Any) -> None: samples_per_insert=4, sequence_length=20, period=10, - evaluator_interval={"executor_episodes": 2000}, + evaluator_interval={"executor_episodes": 2}, ).build() # Only the trainer should use the GPU (if available) diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index 840bd1025..6dc3bf72d 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -94,7 +94,7 @@ def main(_: Any) -> None: samples_per_insert=4, sequence_length=20, period=10, - evaluator_interval={"executor_episodes": 2000}, + evaluator_interval={"executor_episodes": 2}, trainer_fn=madqn.MADQNRecurrentTrainer, executor_fn=madqn.MADQNRecurrentExecutor, ).build() diff --git a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py index fb8df27f1..d285c31f0 100644 --- a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py +++ b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py @@ -90,7 +90,7 @@ def main(_: Any) -> None: min_replay_size=32, batch_size=32, samples_per_insert=4, - evaluator_interval={"executor_episodes": 2000}, + evaluator_interval={"executor_episodes": 2}, checkpoint_subpath=checkpoint_dir, ).build() From b678c4865ca1014af6cc246d046c0af896debc3f Mon Sep 17 00:00:00 2001 From: Claude Formanek Date: Fri, 11 Feb 2022 10:56:58 +0200 Subject: [PATCH 56/56] fix: evaluator interval on SMAC examples should be ever 2000 steps --- examples/smac/recurrent/centralised/run_qmix.py | 2 +- examples/smac/recurrent/centralised/run_vdn.py | 2 +- examples/smac/recurrent/decentralised/run_madqn.py | 2 +- .../smac/recurrent/decentralised/run_madqn_scale_trainers.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/smac/recurrent/centralised/run_qmix.py b/examples/smac/recurrent/centralised/run_qmix.py index 414dd6e3b..8231aa745 100644 --- a/examples/smac/recurrent/centralised/run_qmix.py +++ 
b/examples/smac/recurrent/centralised/run_qmix.py @@ -95,7 +95,7 @@ def main(_: Any) -> None: samples_per_insert=4, sequence_length=20, period=10, - evaluator_interval={"executor_episodes": 2}, + evaluator_interval={"executor_steps": 2000}, ).build() # Only the trainer should use the GPU (if available) diff --git a/examples/smac/recurrent/centralised/run_vdn.py b/examples/smac/recurrent/centralised/run_vdn.py index e0850bcfd..3d8f4a3fe 100644 --- a/examples/smac/recurrent/centralised/run_vdn.py +++ b/examples/smac/recurrent/centralised/run_vdn.py @@ -92,7 +92,7 @@ def main(_: Any) -> None: samples_per_insert=4, sequence_length=20, period=10, - evaluator_interval={"executor_episodes": 2}, + evaluator_interval={"executor_steps": 2000}, ).build() # Only the trainer should use the GPU (if available) diff --git a/examples/smac/recurrent/decentralised/run_madqn.py b/examples/smac/recurrent/decentralised/run_madqn.py index 6dc3bf72d..1dd4aadce 100644 --- a/examples/smac/recurrent/decentralised/run_madqn.py +++ b/examples/smac/recurrent/decentralised/run_madqn.py @@ -94,7 +94,7 @@ def main(_: Any) -> None: samples_per_insert=4, sequence_length=20, period=10, - evaluator_interval={"executor_episodes": 2}, + evaluator_interval={"executor_steps": 2000}, trainer_fn=madqn.MADQNRecurrentTrainer, executor_fn=madqn.MADQNRecurrentExecutor, ).build() diff --git a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py index d285c31f0..f76eddac5 100644 --- a/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py +++ b/examples/smac/recurrent/decentralised/run_madqn_scale_trainers.py @@ -90,7 +90,7 @@ def main(_: Any) -> None: min_replay_size=32, batch_size=32, samples_per_insert=4, - evaluator_interval={"executor_episodes": 2}, + evaluator_interval={"executor_steps": 2000}, checkpoint_subpath=checkpoint_dir, ).build()
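Closing note on the exploration settings these SMAC examples converge on: `LinearExplorationTimestepScheduler(epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=50000)`. Assuming the scheduler does what its arguments suggest, namely a straight-line decay of epsilon over a fixed number of environment timesteps (this is a sketch of that rule, not the class's actual code), the value at timestep `t` would be:

    # Sketch of a linear epsilon schedule over a fixed number of timesteps.
    def epsilon_at(t, epsilon_start=1.0, epsilon_min=0.05, epsilon_decay_steps=50000):
        fraction = min(t / epsilon_decay_steps, 1.0)
        return epsilon_start + fraction * (epsilon_min - epsilon_start)

    # For example: epsilon_at(0) == 1.0, epsilon_at(25000) == 0.525,
    # and epsilon_at(50000) and beyond stay at 0.05.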