[RLlib] - Add "burn-in" period to the training of stateful RLModules #49680 (base: master)
Conversation
…ries which we need for states. Signed-off-by: simonsays1980 <[email protected]>
…tatesFromEpisodesToBatch' connector adjusts the sequence length to the burn-in period and the 'EpisodeReplayBuffer' allows for a burn-in period to be added in front of the 'batch_length_T'. The 'GeneralAdvantageEstimation' connector needed to be adjusted to the added burn-in period. Signed-off-by: simonsays1980 <[email protected]>
rllib/algorithms/algorithm_config.py
Outdated
@@ -397,6 +397,7 @@ def __init__(self, algo_class: Optional[type] = None):
        self._learner_connector = None
        self.add_default_connectors_to_learner_pipeline = True
        self.learner_config_dict = {}
        self.burnin = 0
Nit: Let's rename to `burn_in_len` for more clarity.

High-level considerations:

- Let's make burn-in a DQN (or other off-policy/offline) only feature for now.
- Let's make the meaning of burn-in: "burn into the already configured `max_seq_len`". Meaning, instead of adding `max_seq_len` + `burn_in_len`, we would be able to continue with simply `max_seq_len` (e.g. in the buffer), b/c this already includes the burn-in length.
- We can add to config.validate() that `max_seq_len` must be at least `burn_in_len` + 1.
Also, having burn-in be a part of `max_seq_len` provides a better guarantee that `max_seq_len` - from a model's perspective - is really the max(!) seq len. Just in case the model has some internal limit on processing sequences. We don't want to accidentally breach it b/c we are adding the burn-in to the `max_seq_len` (instead of making it a part of it).
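A minimal sketch of the validation suggested here (the names `burn_in_len` and `max_seq_len` follow this review thread; the merged RLlib code may implement the check differently):

```python
# Hedged sketch: `max_seq_len` already contains the burn-in, so it must leave
# at least one trainable step after the burn-in window.
def validate_burn_in(burn_in_len: int, max_seq_len: int) -> None:
    if burn_in_len < 0:
        raise ValueError("`burn_in_len` must be non-negative.")
    if burn_in_len > 0 and max_seq_len < burn_in_len + 1:
        raise ValueError(
            "`max_seq_len` must be at least `burn_in_len` + 1, because the "
            "burn-in is part of the sequence rather than added on top of it."
        )

validate_burn_in(burn_in_len=10, max_seq_len=40)    # OK
# validate_burn_in(burn_in_len=10, max_seq_len=10)  # would raise ValueError
```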
rllib/algorithms/dqn/dqn.py
Outdated
@@ -655,6 +655,7 @@ def _training_step_new_api_stack(self):
            batch_length_T=self.env_runner.module.is_stateful()
            * self.config.model_config.get("max_seq_len", 0),
            lookback=int(self.env_runner.module.is_stateful()),
            burnin=self.config.burnin,
so I think with the suggestions above, we won't even need to change the buffer anymore, correct?
Theoretically, we would not need to. My concerns are more about the cases where the sampled sequence length would be smaller than the `burn_in_len` - in this case the mask would be `False` for all time steps in that sequence in training.
        mask = batch[Columns.LOSS_MASK]
        mask = batch[Columns.LOSS_MASK].clone()
        # Check, if a burn-in should be used to recover from a poor state.
        if self.config.burnin > 0:
cool!
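For context, a minimal sketch of how such burn-in masking can work on a loss mask (the `[B, T]` layout, names, and shapes are assumptions for illustration, not the PR's exact code):

```python
import torch

def apply_burn_in(loss_mask: torch.Tensor, burn_in_len: int) -> torch.Tensor:
    """Zero out the first `burn_in_len` steps of a [B, T] boolean loss mask.

    Illustrative sketch only: the real learner operates on RLlib's batch
    format, which may differ in shape and naming.
    """
    mask = loss_mask.clone()  # don't mutate the batch's own mask
    if burn_in_len > 0:
        mask[:, :burn_in_len] = False
    return mask

# Burn-in of 5 on sequences of length 20: the first 5 steps only warm up the
# RNN state and contribute nothing to the loss.
mask = apply_burn_in(torch.ones(4, 20, dtype=torch.bool), burn_in_len=5)
assert not mask[:, :5].any() and mask[:, 5:].all()
```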
@@ -221,7 +221,8 @@ def sample(
        sample_episodes: Optional[bool] = False,
        finalize: bool = False,
        # TODO (simon): Check, if we need here 1 as default.
        lookback: Optional[int] = 0,
        lookback: int = 0,
        burnin: int = 0,
not needed?
Eventually. See my comment above.
@@ -268,6 +269,11 @@ def sample(
                timestep at which the action is computed).
            finalize: If episodes should be finalized.
            lookback: A desired lookback. Any non-negative integer is valid.
            burnin: An optional burn-in length added to the `batch_length_T` when
not needed?
@@ -287,6 +293,7 @@ def sample(
            include_extra_model_outputs=include_extra_model_outputs,
            finalize=finalize,
            lookback=lookback,
            burnin=burnin,
not needed?
@@ -434,6 +441,7 @@ def _sample_episodes(
        include_extra_model_outputs: bool = False,
        finalize: bool = False,
        lookback: Optional[int] = 1,
        burnin: Optional[int] = 0,
not needed?
@@ -480,6 +488,11 @@ def _sample_episodes(
                timestep at which the action is computed).
            finalize: If episodes should be finalized.
            lookback: A desired lookback. Any non-negative integer is valid.
            burnin: An optional burn-in length added to the `batch_length_T` when
not needed?
@@ -537,6 +550,8 @@ def _sample_episodes(

            # Skip, if we are too far to the end and `episode_ts` + n_step would go
            # beyond the episode's end.
            if burnin > 0 and episode_ts + burnin >= len(episode):
not needed?
This is the place where my comment above would point to. We check here if we have a burn-in and guarantee that the sampled sequence is at least as long as the required burn-in, b/c otherwise the full mask would be `False` and we would have sampled for no update.
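A small numeric illustration of that guard (the numbers are made up; variable names mirror the diff above):

```python
# With a burn-in of 10, an episode of length 14, and a sampled start index of
# 8, only 6 steps remain -- all of them fall inside the burn-in window, so the
# entire sequence would be loss-masked. The guard rejects such draws up front.
burnin, episode_len, episode_ts = 10, 14, 8
skip = burnin > 0 and episode_ts + burnin >= episode_len
assert skip  # this sample is skipped instead of producing an empty update
```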
… moved it to 'DQNConfig' from 'AlgorithmConfig'. Added a validation for 'max_seq_len' to be larger than 'burn_in_len' + 1 to include the burn-in directly into the 'max_seq_len' and remove all changes from connectors. Signed-off-by: simonsays1980 <[email protected]>
LGTM! Thanks for making these changes.
Signed-off-by: Sven Mika <[email protected]>
…' and adjusted docstring. Added safeguard for algorithms like SAC that do not have a 'burn_in_len', yet. Signed-off-by: simonsays1980 <[email protected]>
…n' is in the 'model_config' was not covered. Signed-off-by: simonsays1980 <[email protected]>
… 'burn_in_len' B/c not all inherited algorithms have this attribute. Signed-off-by: simonsays1980 <[email protected]>
Why are these changes needed?
Stateful models compute predictions based on (a) hidden state(s). A major requirement for good predictions is the quality of this/these hidden state(s). Specifically, in off-policy training states suffer from staleness (this happens as well in on-policy training when running multiple epochs on a train batch). Because the input state to a model for a certain observation - when predicting on sequences - is basically the last output state, we can improve the quality of the state by using a certain "burn-in" period at the start of each sequence before training on it starts.
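As a rough usage sketch (hedged: `burn_in_len` is the attribute name proposed in this PR's review; the final public API may expose it differently, e.g. through `config.training(...)`):

```python
from ray.rllib.algorithms.dqn import DQNConfig

# Sketch only: a DQN setup where the first `burn_in_len` steps of each sampled
# sequence warm up the hidden state and are excluded from the loss. Assumes a
# stateful RLModule (e.g. an LSTM) configured elsewhere with
# max_seq_len >= burn_in_len + 1.
config = DQNConfig().environment("CartPole-v1")
config.burn_in_len = 10
```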
This PR adds "burn-in" to stateful `RLModule`s for DQN and PPO. The following changes are proposed:

- A `burnin` parameter in the `AlgorithmConfig` to configure the "burn-in" period and to have this parameter available in the `Learner`.
- The `burnin` parameter is passed via the `AlgorithmConfig._model_config_auto_includes` to the `RLModule`'s `model_config` to make the parameter available in `Connector`s.
- The `burnin` period is added to the `max_seq_len` in the `AddStatesFromEpisodesToBatch` connector to extend the sequences in the batch.
- The `burnin` period is also added to the `max_seq_len` in the `GeneralAdvantageEstimation` to extend postprocessed data.
- A `burnin` argument for the `EpisodeReplayBuffer` that is added to `batch_sequence_length` when sampling sequences.
- Sequences in the `EpisodeReplayBuffer` that do not extend over the `burnin` period are skipped to avoid training on empty batches.
- The `DQNTorchLearner` and `PPOTorchLearner` train only on the sequence part after burn-in.
- The `DQNTorchLearner` `max` and `min` calculations are safeguarded against potentially empty tensors.

Related issue number
Checks

- I've signed off every commit (`git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.