generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Unify Policy Trainers #1586
Closed
Closed
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. |
Awaiting unslothai/unsloth#533 |
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
WIP: Unify Policy Trainers
Overview / Problem
Many trainers within trl follow the same paradigm:
Trainers following this workflow include
PPOTrainer
,DPOTrainer
,KTOTrainer
, and the newRLOOTrainer
(in PR).Despite sharing these features, each trainer has repetitive and sometimes inconsistent implementations of core components including reference model management, generation of policy output, and even model saving.
This has resulted in a number of bugs, confusion, and unnecessary redundant work when implementing new policy trainers.
PolicyTrainerBase
The goal for this PR is to introduce an abstract
PolicyTrainerBase
with theRLOOTrainer
adapted from #1540The adapted
RLOOTrainer
only implementstraining_step()
which is provided a batch of inputs, calculates loss, applies backprop, and logs metrics.PolicyTrainerBase
takes care of everything else, primarily preparation and management of the reference model, along with preparation of the generation config, and a utility function for generation of output sequences and logits.I'll have to consider the generation function carefully, as that is one of the most complex components of the different policy trainers (see
PPOTrainer
s implementationtrl/trl/trainer/ppo_trainer.py
Lines 431 to 565 in b32656f
Remaining Work: