support TP-only parallelism
ghstack-source-id: c13ebb8de8e8e9203624b5dd710a046d17311b0f
Pull Request resolved: #137
tianyu-l committed Mar 14, 2024
1 parent 3262a8b commit d9253ee
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions train.py
@@ -76,7 +76,8 @@ def build_optimizer(model, job_config: JobConfig):
 
 def build_grad_scaler(model):
     # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
-    if model.mixed_precision.param_dtype == torch.float16:
+    # NOTE: currently mixed precision training is supported only when FSDP is used
+    if isinstance(model, FSDP) and model.mixed_precision.param_dtype == torch.float16:
         enable_grad_scaling = True
         logger.info("Enabling gradient scaling for mixed precision training")
     else:
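The else branch of this hunk is collapsed above. For orientation, here is a minimal sketch of how such a helper could look, assuming the scaler returned is FSDP's ShardedGradScaler, enabled only in the fp16-under-FSDP case; the imports and the disabled-scaler fallback are illustrative assumptions, not lines taken from this diff.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def build_grad_scaler(model):
    # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
    # NOTE: currently mixed precision training is supported only when FSDP is used
    if isinstance(model, FSDP) and model.mixed_precision.param_dtype == torch.float16:
        enable_grad_scaling = True
    else:
        enable_grad_scaling = False

    # With enabled=False, unscale_/step/update become no-ops, so the training
    # loop can call the scaler unconditionally even for TP-only runs.
    return ShardedGradScaler(enabled=enable_grad_scaling)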
@@ -160,13 +161,12 @@ def main(job_config: JobConfig):
         model, world_mesh, parallel_dims, job_config
     )
 
-    # to use FSDP-customized gradient scaler and gradient clipping solutions
-    assert isinstance(model, FSDP)
-
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
 
+    # build grad scaler which is effective only when mixed precision training
+    # is enabled with fp16 param dtype under FSDP
     scaler = build_grad_scaler(model)
 
     metric_logger = build_metric_logger(job_config)
@@ -240,7 +240,12 @@ def main(job_config: JobConfig):
 
             # clip gradients (after unscaling gradients of the optimizer's params)
             scaler.unscale_(optimizer)
-            model.clip_grad_norm_(job_config.training.max_norm)
+            if isinstance(model, FSDP):
+                model.clip_grad_norm_(job_config.training.max_norm)
+            else:
+                torch.nn.utils.clip_grad_norm_(
+                    model.parameters(), job_config.training.max_norm
+                )
 
             # optimizer step
             # If gradients don't contain infs/NaNs, optimizer.step() is then called;
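For context, the surrounding step sequence with the scaler would read roughly as below. This is a sketch assuming the standard torch.cuda.amp.GradScaler contract; the two scaler calls at the end follow that contract and are not lines shown in this diff, while scaler, optimizer, model, and job_config are the objects already used in this hunk.

# clip gradients (after unscaling gradients of the optimizer's params)
scaler.unscale_(optimizer)
if isinstance(model, FSDP):
    # FSDP shards gradients across ranks, so its clip_grad_norm_
    # computes the global norm with a collective before clipping
    model.clip_grad_norm_(job_config.training.max_norm)
else:
    # non-FSDP case (e.g. TP-only): fall back to the standard clipping utility
    torch.nn.utils.clip_grad_norm_(
        model.parameters(), job_config.training.max_norm
    )

# scaler.step runs optimizer.step only if no inf/NaN gradients were found;
# scaler.update then adjusts the loss scale for the next iteration
scaler.step(optimizer)
scaler.update()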
