diff --git a/mava/utils/wrapper_utils.py b/mava/utils/wrapper_utils.py index b0d91ebe7..e43c06b5d 100644 --- a/mava/utils/wrapper_utils.py +++ b/mava/utils/wrapper_utils.py @@ -172,6 +172,7 @@ class RunningStatistics: a specific quantity. """ + # The queue_size is used to estimate a moving mean and variance value. def __init__(self, label: str, queue_size: int = 100) -> None: self.queue: collections.deque = collections.deque(maxlen=queue_size) diff --git a/mava/wrappers/environment_loop_wrappers.py b/mava/wrappers/environment_loop_wrappers.py index 4fed09957..0e5e46b75 100644 --- a/mava/wrappers/environment_loop_wrappers.py +++ b/mava/wrappers/environment_loop_wrappers.py @@ -174,18 +174,12 @@ def __init__( f"{agent}_episode_return" ) self._agents_stats[agent]["reward"] = RunningStatistics( - f"{agent}_episode_reward" + f"{agent}_step_reward" ) def _compute_step_statistics(self, rewards: Dict[str, float]) -> None: for agent, reward in rewards.items(): - agent_running_statistics: Dict[str, float] = {} self._agents_stats[agent]["reward"].push(reward) - for stat in self._summary_stats: - agent_running_statistics[ - f"{agent}_{stat}_step_reward" - ] = self._agents_stats[agent]["reward"].__getattribute__(stat)() - self._agent_loggers[agent].write(agent_running_statistics) def _compute_episode_statistics( self, @@ -211,17 +205,24 @@ def _compute_episode_statistics( f"_{metric}_stats" ).__getattribute__(stat)() + self._running_statistics.update({"episode_length": episode_steps}) + self._running_statistics.update(counts) + + # Write per agent statistics for agent, agent_return in episode_returns.items(): agent_running_statistics: Dict[str, float] = {} self._agents_stats[agent]["return"].push(agent_return) for stat in self._summary_stats: + # Episode return agent_running_statistics[f"{agent}_{stat}_return"] = self._agents_stats[ agent ]["return"].__getattribute__(stat)() - self._agent_loggers[agent].write(agent_running_statistics) - self._running_statistics.update({"episode_length": episode_steps}) - self._running_statistics.update(counts) + # Step rewards + agent_running_statistics[ + f"{agent}_{stat}_step_reward" + ] = self._agents_stats[agent]["reward"].__getattribute__(stat)() + self._agent_loggers[agent].write(agent_running_statistics) class MonitorParallelEnvironmentLoop(ParallelEnvironmentLoop):