From dc4cfab7006727f0fc76d09b69fdf04993e9f73a Mon Sep 17 00:00:00 2001 From: lewtun Date: Tue, 20 Aug 2024 16:42:51 +0200 Subject: [PATCH] Log WandB tables on main process (#1951) --- trl/trainer/online_dpo_trainer.py | 11 ++++++----- trl/trainer/ppov2_trainer.py | 11 ++++++----- trl/trainer/rloo_trainer.py | 11 ++++++----- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index b02c83b8c7..de628a525d 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -568,13 +568,14 @@ def generate_completions(self, sampling: bool = False): if sampling: break df = pd.DataFrame(table) - if self.accelerator.process_index == 0: + + if self.accelerator.is_main_process: print_rich_table(df.iloc[0 : 0 + 5]) - if "wandb" in args.report_to: - import wandb + if "wandb" in args.report_to: + import wandb - if wandb.run is not None: - wandb.log({"completions": wandb.Table(dataframe=df)}) + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) @wraps(Trainer.push_to_hub) def push_to_hub( diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index 3d565a17e3..d2cbfba452 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -604,13 +604,14 @@ def generate_completions(self, sampling: bool = False): if sampling: break df = pd.DataFrame(table) - if self.accelerator.process_index == 0: + + if self.accelerator.is_main_process: print_rich_table(df.iloc[0 : 0 + 5]) - if "wandb" in args.report_to: - import wandb + if "wandb" in args.report_to: + import wandb - if wandb.run is not None: - wandb.log({"completions": wandb.Table(dataframe=df)}) + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) @wraps(Trainer.push_to_hub) def push_to_hub( diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 319597f2c7..ffd77e0159 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -502,13 +502,14 @@ def generate_completions(self, sampling: bool = False): if sampling: break df = pd.DataFrame(table) - if self.accelerator.process_index == 0: + + if self.accelerator.is_main_process: print_rich_table(df.iloc[0 : 0 + 5]) - if "wandb" in args.report_to: - import wandb + if "wandb" in args.report_to: + import wandb - if wandb.run is not None: - wandb.log({"completions": wandb.Table(dataframe=df)}) + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) @wraps(Trainer.push_to_hub) def push_to_hub(