From 3bdcab68d49b74411144c61df8e64e7f291f92e2 Mon Sep 17 00:00:00 2001
From: simonsays1980
Date: Thu, 4 Jul 2024 11:00:36 +0200
Subject: [PATCH] [RLlib] Moving sampling coordination for
 `batch_mode=complete_episodes` to `synchronous_parallel_sample`. (#46321)

---
 .../tests/test_callbacks_on_env_runner.py |  3 ---
 rllib/env/single_agent_env_runner.py      | 22 +++++++++----------
 2 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/rllib/algorithms/tests/test_callbacks_on_env_runner.py b/rllib/algorithms/tests/test_callbacks_on_env_runner.py
index 898717c6c2b25..6afa874509e04 100644
--- a/rllib/algorithms/tests/test_callbacks_on_env_runner.py
+++ b/rllib/algorithms/tests/test_callbacks_on_env_runner.py
@@ -179,9 +179,6 @@ def test_episode_and_sample_callbacks_batch_mode_complete_episodes(self):
 
             # Train one iteration.
             algo.train()
-            # We must have had exactly one `sample()` call on our EnvRunner.
-            if not multi_agent:
-                self.assertEqual(callback_obj.counts["sample"], 1)
             # We should have had at least one episode start.
             self.assertGreater(callback_obj.counts["start"], 0)
             # Episode starts must be exact same as episode ends (b/c we always complete
diff --git a/rllib/env/single_agent_env_runner.py b/rllib/env/single_agent_env_runner.py
index 86f86d6d32c5d..3700385e91821 100644
--- a/rllib/env/single_agent_env_runner.py
+++ b/rllib/env/single_agent_env_runner.py
@@ -197,19 +197,17 @@ def sample(
                 explore=explore,
                 random_actions=random_actions,
             )
-        # For complete episodes mode, sample as long as the number of timesteps
-        # done is smaller than the `train_batch_size`.
+        # For complete episodes mode, sample a single episode and
+        # leave coordination of sampling to `synchronous_parallel_sample`.
+        # TODO (simon, sven): The coordination will eventually move
+        # to `EnvRunnerGroup` in the future. So from the algorithm one
+        # would do `EnvRunnerGroup.sample()`.
         else:
-            total = 0
-            samples = []
-            while total < self.config.train_batch_size:
-                episodes = self._sample_episodes(
-                    num_episodes=self.num_envs,
-                    explore=explore,
-                    random_actions=random_actions,
-                )
-                total += sum(len(e) for e in episodes)
-                samples.extend(episodes)
+            samples = self._sample_episodes(
+                num_episodes=1,
+                explore=explore,
+                random_actions=random_actions,
+            )
 
         # Make the `on_sample_end` callback.
         self._callbacks.on_sample_end(
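
Note on the new sampling flow: after this change, each `EnvRunner.sample()` call in
`batch_mode="complete_episodes"` returns a single finished episode, and the caller
(`synchronous_parallel_sample`, eventually `EnvRunnerGroup`) is responsible for
requesting episodes until `train_batch_size` timesteps have been collected. Below is
a minimal sketch of such a coordination loop; the function name
`coordinate_complete_episodes` and the `env_runner_sample_fns` parameter are
hypothetical stand-ins, not RLlib's actual API.

from typing import Callable, List, Sequence


def coordinate_complete_episodes(
    env_runner_sample_fns: Sequence[Callable[[], List]],
    train_batch_size: int,
) -> List:
    """Collects complete episodes until `train_batch_size` timesteps are gathered.

    Each callable in `env_runner_sample_fns` stands in for one EnvRunner's
    `sample()` and returns a list containing one complete episode (hypothetical
    interface for illustration only).
    """
    episodes: List = []
    total_timesteps = 0
    while total_timesteps < train_batch_size:
        for sample_fn in env_runner_sample_fns:
            # One call -> one complete episode (per the new EnvRunner behavior).
            new_episodes = sample_fn()
            episodes.extend(new_episodes)
            # Episodes report their length in timesteps via `len()`, as in the
            # removed while-loop above.
            total_timesteps += sum(len(e) for e in new_episodes)
            if total_timesteps >= train_batch_size:
                break
    return episodes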