[RLlib] - Add "burn-in" period to the training of stateful RLModules. #49680

Open
wants to merge 11 commits into base: master

Conversation

simonsays1980
Collaborator

Why are these changes needed?

Stateful models compute predictions based on one or more hidden states. A major requirement for good predictions is the quality of these hidden states. Especially in off-policy training, states suffer from staleness (this also happens in on-policy training when running multiple epochs on a train batch). Because the input state for a model at a given observation - when predicting on sequences - is essentially the output state of the previous step, we can improve the state's quality by running a "burn-in" period over the first part of each sequence before training on it starts.

This PR adds "burn-in" to stateful RLModules for DQN and PPO. The following changes are proposed:

  • Add a burnin parameter to the AlgorithmConfig to configure the "burn-in" period and to have this parameter available in the Learner.
  • Add the burnin parameter via the AlgorithmConfig._model_config_auto_includes to the RLModule's model_config to make the parameter available in Connectors.
  • Add the burnin period to the max_seq_len in the AddStatesFromEpisodesToBatch connector to extend the sequences in the batch.
  • Add the burnin period also to the max_seq_len in the GeneralAdvantageEstimation connector to extend the postprocessed data.
  • Make a burnin argument available for the EpisodeReplayBuffer that is added to the batch sequence length when sampling sequences.
  • Safeguard against sampling sequences from the EpisodeReplayBuffer that do not extend beyond the burnin period, to avoid training on fully masked (effectively empty) batches.
  • Adjust the DQNTorchLearner and PPOTorchLearner to train only on the part of the sequence after the burn-in (see the sketch after this list).
  • Safeguard the DQNTorchLearner max and min calculations against potentially empty tensors.
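
For illustration, here is a minimal sketch of the masking idea used by the learners. The helper name `apply_burn_in` and the tensor layout are assumptions made for this example, not the exact code in this PR:

import torch

def apply_burn_in(loss_mask: torch.Tensor, burnin: int) -> torch.Tensor:
    # `loss_mask` is a bool tensor of shape [num_sequences, max_seq_len] in which
    # True marks valid (non-padded) time steps. The burn-in steps still pass
    # through the model to warm up the hidden state, but they are excluded
    # from the loss.
    if burnin <= 0:
        return loss_mask
    mask = loss_mask.clone()
    mask[:, :burnin] = False
    return mask

# Example: sequences of length 8 with a burn-in of 3 leave 5 steps for the loss.
mask = torch.ones(4, 8, dtype=torch.bool)
print(apply_burn_in(mask, 3).sum(dim=1))  # tensor([5, 5, 5, 5])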

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

…ries which we need for states.

Signed-off-by: simonsays1980 <[email protected]>
…tatesFromEpisodesToBatch' connector adjusts the sequence length to the burn-in period and the 'EpisodeReplayBuffer' allows for a burn-in period to be added in front of the 'batch_length_T'. The 'GeneralAdvantageEstimation' connector needed to be adjusted to the added burn-in period.

Signed-off-by: simonsays1980 <[email protected]>
@simonsays1980 marked this pull request as ready for review January 7, 2025 11:48
@@ -397,6 +397,7 @@ def __init__(self, algo_class: Optional[type] = None):
self._learner_connector = None
self.add_default_connectors_to_learner_pipeline = True
self.learner_config_dict = {}
self.burnin = 0
Contributor

Nit: Let's rename to "burn_in_len" for more clarity.

High-level considerations:

  • Let's make burn-in a DQN (or other off-policy/offline) only feature for now.
  • Let's make the meaning of burn-in: "burn into the already configured max_seq_len". Meaning, instead of adding max_seq_len + burn_in_len, we would be able to continue with simply max_seq_len (e.g. in the buffer), b/c this already includes the burn-in length.
  • We can add a check to config.validate() that max_seq_len must be at least burn_in_len + 1 (see the sketch below).
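
A minimal sketch of that validation, assuming the parameter ends up as `burn_in_len` on the config (the attribute name and the exact placement in `validate()` are assumptions here, not the final implementation):

# Sketch only: a possible guard inside `DQNConfig.validate()`.
max_seq_len = self.model_config.get("max_seq_len", 0)
if self.burn_in_len > 0 and max_seq_len < self.burn_in_len + 1:
    raise ValueError(
        "`max_seq_len` must be at least `burn_in_len` + 1 so that at least "
        "one time step per sequence remains for the loss computation."
    )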

Contributor

Also, having burn-in be a part of max_seq_len provides a better guarantee that max_seq_len - from a model's perspective - really is the max(!) seq len, just in case the model has some internal limit on processing sequences. We don't want to accidentally breach that limit b/c we are adding the burn-in on top of max_seq_len (instead of making it a part of it).
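
For example (illustrative numbers only): with max_seq_len=20 and burn_in_len=5, the model would still only ever see sequences of length 20; the first 5 steps would warm up the hidden state and the loss would be computed on the remaining 15 steps, instead of silently stretching the sequences to 25 steps.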

@@ -655,6 +655,7 @@ def _training_step_new_api_stack(self):
batch_length_T=self.env_runner.module.is_stateful()
* self.config.model_config.get("max_seq_len", 0),
lookback=int(self.env_runner.module.is_stateful()),
burnin=self.config.burnin,
Contributor

So I think with the suggestions above, we won't even need to change the buffer anymore, correct?

Collaborator Author

Theoretically, we would not need to. My concern is more about the case where the sampled sequence length is smaller than burn_in_len - in that case the mask would be False for all time steps of that sequence during training.

mask = batch[Columns.LOSS_MASK]
mask = batch[Columns.LOSS_MASK].clone()
# Check, if a burn-in should be used to recover from a poor state.
if self.config.burnin > 0:
Contributor

cool!

@@ -221,7 +221,8 @@ def sample(
sample_episodes: Optional[bool] = False,
finalize: bool = False,
# TODO (simon): Check, if we need here 1 as default.
lookback: Optional[int] = 0,
lookback: int = 0,
burnin: int = 0,
Contributor

not needed?

Collaborator Author

Eventually. See my comment above.

@@ -268,6 +269,11 @@ def sample(
timestep at which the action is computed).
finalize: If episodes should be finalized.
lookback: A desired lookback. Any non-negative integer is valid.
burnin: An optional burn-in length added to the `batch_length_T` when
Contributor

not needed?

@@ -287,6 +293,7 @@ def sample(
include_extra_model_outputs=include_extra_model_outputs,
finalize=finalize,
lookback=lookback,
burnin=burnin,
Contributor

not needed?

@@ -434,6 +441,7 @@ def _sample_episodes(
include_extra_model_outputs: bool = False,
finalize: bool = False,
lookback: Optional[int] = 1,
burnin: Optional[int] = 0,
Contributor

not needed?

@@ -480,6 +488,11 @@ def _sample_episodes(
timestep at which the action is computed).
finalize: If episodes should be finalized.
lookback: A desired lookback. Any non-negative integer is valid.
burnin: An optional burn-in length added to the `batch_length_T` when
Contributor

not needed?

@@ -537,6 +550,8 @@ def _sample_episodes(

# Skip, if we are too far to the end and `episode_ts` + n_step would go
# beyond the episode's end.
if burnin > 0 and episode_ts + burnin >= len(episode):
Contributor

not needed?

Collaborator Author

This is the place my comment above points to. We check here whether we have a burn-in and guarantee that the sampled sequence extends beyond the required burn-in, b/c otherwise the full mask would be False and we would have sampled for no update.
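
To illustrate the failure mode this guards against (numbers are made up for the example), a sampled sequence shorter than the burn-in would end up with an all-False loss mask and contribute nothing to the update:

import torch

burnin = 10
seq_len = 8  # only 8 steps remain before the episode's end
mask = torch.ones(1, seq_len, dtype=torch.bool)
mask[:, :burnin] = False  # the burn-in covers the whole sequence
print(mask.any())  # tensor(False) -> nothing left to train on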

… moved it to 'DQNConfig' from 'AlgorithmConfig'. Added a validation for 'max_seq_len' to be larger than 'burn_in_len' + 1 to include the burn-in directly into the 'max_seq_len' and removed all changes from connectors.

Signed-off-by: simonsays1980 <[email protected]>
Contributor

@sven1977 left a comment

LGTM! Thanks for making these changes.

@sven1977 enabled auto-merge (squash) January 7, 2025 17:37
@github-actions bot added the "go" label (add ONLY when ready to merge, run all tests) Jan 7, 2025
…' and adjusted the docstring. Added a safeguard for algorithms like SAC that do not have a 'burn_in_len' yet.

Signed-off-by: simonsays1980 <[email protected]>
@github-actions bot disabled auto-merge January 7, 2025 22:47
…n' is in the 'model_config' was not covered.

Signed-off-by: simonsays1980 <[email protected]>