Skip to content

Commit

Permalink
Log WandB tables on main process (#1951)
Browse files Browse the repository at this point in the history
  • Loading branch information
lewtun authored Aug 20, 2024
1 parent 66d3a82 commit dc4cfab
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 15 deletions.
11 changes: 6 additions & 5 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,13 +568,14 @@ def generate_completions(self, sampling: bool = False):
if sampling:
break
df = pd.DataFrame(table)
if self.accelerator.process_index == 0:

if self.accelerator.is_main_process:
print_rich_table(df.iloc[0 : 0 + 5])
if "wandb" in args.report_to:
import wandb
if "wandb" in args.report_to:
import wandb

if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})
if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})

@wraps(Trainer.push_to_hub)
def push_to_hub(
Expand Down
11 changes: 6 additions & 5 deletions trl/trainer/ppov2_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,13 +604,14 @@ def generate_completions(self, sampling: bool = False):
if sampling:
break
df = pd.DataFrame(table)
if self.accelerator.process_index == 0:

if self.accelerator.is_main_process:
print_rich_table(df.iloc[0 : 0 + 5])
if "wandb" in args.report_to:
import wandb
if "wandb" in args.report_to:
import wandb

if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})
if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})

@wraps(Trainer.push_to_hub)
def push_to_hub(
Expand Down
11 changes: 6 additions & 5 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,13 +502,14 @@ def generate_completions(self, sampling: bool = False):
if sampling:
break
df = pd.DataFrame(table)
if self.accelerator.process_index == 0:

if self.accelerator.is_main_process:
print_rich_table(df.iloc[0 : 0 + 5])
if "wandb" in args.report_to:
import wandb
if "wandb" in args.report_to:
import wandb

if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})
if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})

@wraps(Trainer.push_to_hub)
def push_to_hub(
Expand Down

0 comments on commit dc4cfab

Please sign in to comment.