Commit

Abstract out optimizer params and update foreach calling convention (pytorch#386)

# Summary
Updates build_optimizer to construct the optimizer with foreach enabled whenever the fused implementation is not in use.
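In effect, the two flags are now forwarded together, so turning the fused kernel off opts into PyTorch's multi-tensor foreach update path instead of the per-parameter loop. A minimal sketch of the resulting call, using a stand-in linear model and a hard-coded fused flag in place of the real JobConfig:

import torch

# Sketch only (not from the commit): stand-in model and flag to show the
# fused/foreach hand-off; torchtitan reads `fused` from JobConfig instead.
model = torch.nn.Linear(8, 8)
fused = False  # e.g. job_config.optimizer.fused

# When fused is off, foreach=True selects the multi-tensor update path.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),
    weight_decay=0.1,
    fused=fused,
    foreach=not fused,
)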
drisspg authored Jun 7, 2024
1 parent 3bbe3d9 commit d953107
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions train.py
@@ -95,16 +95,20 @@ def build_optimizer(model, job_config: JobConfig):
     name = job_config.optimizer.name
     lr = job_config.optimizer.lr
     fused = job_config.optimizer.fused
+    # when fused = False, foreach = True by default.
+
+    # Common parameters for both optimizers
+    optimizer_kwargs = {
+        "lr": lr,
+        "betas": (0.9, 0.95),
+        "weight_decay": 0.1,
+        "fused": fused,
+        "foreach": not fused,
+    }
     if name == "Adam":
         # TODO: make the optimizer options configurable by toml/cmd args
-        optimizer = torch.optim.Adam(
-            model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1, fused=fused
-        )
+        optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
     elif name == "AdamW":
-        optimizer = torch.optim.AdamW(
-            model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1, fused=fused
-        )
+        optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
     else:
         raise NotImplementedError(f"Optimizer {name} not added.")

