
Commit

Fix global step for consistent checkpointing with global updates (#1950)
lewtun authored Aug 21, 2024
1 parent dc4cfab commit e529579
Showing 1 changed file with 1 addition and 3 deletions.
trl/trainer/online_dpo_trainer.py: 4 changes (1 addition & 3 deletions)
@@ -338,7 +338,6 @@ def repeat_generator():
contain_eos_token = torch.any(postprocessed_responses == tokenizer.eos_token_id, dim=-1)
if args.non_eos_penalty:
scores = torch.where(contain_eos_token, scores, args.penalty_reward_value)
- # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")

# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
@@ -467,7 +466,6 @@ def repeat_generator():
] = rejected_logprobs_sum.mean()
gradient_accumulation_idx += 1
minibatch_idx += 1
- self.state.global_step += 1
# del everything and empty cache
# fmt: off
del (
@@ -505,11 +503,11 @@ def repeat_generator():
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
metrics["episode"] = self.state.episode
self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
+ self.state.global_step += 1
self.log(metrics)
del (kl, mean_kl, mean_entropy, scores, scores_margin)

self.lr_scheduler.step()
- self.state.global_step += 1
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(model, trial=None, metrics=metrics)
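For readers following the change: below is a minimal, self-contained sketch (plain Python, not TRL code; ToyState, num_inner_epochs, and save_steps are illustrative stand-ins) of the pattern this commit appears to enforce. The step counter is bumped exactly once per global update, just before logging, so any checkpoint trigger keyed on the step count fires in lockstep with optimizer updates instead of with the inner optimization epochs.

from dataclasses import dataclass


@dataclass
class ToyState:
    # Stand-in for the trainer state; only the counter that matters here.
    global_step: int = 0


def train(num_updates: int, num_inner_epochs: int, save_steps: int) -> None:
    state = ToyState()
    for update in range(1, num_updates + 1):
        # Inner optimization epochs: the step counter is NOT touched here.
        # (Incrementing it in this loop inflates the count relative to global updates.)
        for _ in range(num_inner_epochs):
            pass  # forward/backward/optimizer.step() would live here

        # One global update == one global step, counted before logging.
        state.global_step += 1
        print(f"log: update={update} global_step={state.global_step}")

        # A checkpoint trigger keyed on the step count now fires once per global update.
        if state.global_step % save_steps == 0:
            print(f"checkpoint at global_step={state.global_step}")


if __name__ == "__main__":
    train(num_updates=6, num_inner_epochs=4, save_steps=3)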
