From d9253eef66acd1b145f0c036fab2e89e85bad6a3 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 13 Mar 2024 15:17:56 -0700
Subject: [PATCH] support TP-only parallelism

ghstack-source-id: c13ebb8de8e8e9203624b5dd710a046d17311b0f
Pull Request resolved: https://github.com/pytorch/torchtrain/pull/137
---
 train.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/train.py b/train.py
index 69cea633b..3f37a61b0 100644
--- a/train.py
+++ b/train.py
@@ -76,7 +76,8 @@ def build_optimizer(model, job_config: JobConfig):
 
 def build_grad_scaler(model):
     # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
-    if model.mixed_precision.param_dtype == torch.float16:
+    # NOTE: currently mixed precision training is supported only when FSDP is used
+    if isinstance(model, FSDP) and model.mixed_precision.param_dtype == torch.float16:
         enable_grad_scaling = True
         logger.info("Enabling gradient scaling for mixed precision training")
     else:
@@ -160,13 +161,12 @@ def main(job_config: JobConfig):
         model, world_mesh, parallel_dims, job_config
     )
 
-    # to use FSDP-customized gradient scaler and gradient clipping solutions
-    assert isinstance(model, FSDP)
-
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
 
+    # build grad scaler which is effective only when mixed precision training
+    # is enabled with fp16 param dtype under FSDP
     scaler = build_grad_scaler(model)
 
     metric_logger = build_metric_logger(job_config)
@@ -240,7 +240,12 @@ def main(job_config: JobConfig):
 
             # clip gradients (after unscaling gradients of the optimizer's params)
             scaler.unscale_(optimizer)
-            model.clip_grad_norm_(job_config.training.max_norm)
+            if isinstance(model, FSDP):
+                model.clip_grad_norm_(job_config.training.max_norm)
+            else:
+                torch.nn.utils.clip_grad_norm_(
+                    model.parameters(), job_config.training.max_norm
+                )
 
             # optimizer step
             # If gradients don't contain infs/NaNs, optimizer.step() is then called;
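
Note: below is a minimal, standalone sketch (not part of the patch) of the
gradient-clipping dispatch introduced in the last hunk, assuming only that the
model may or may not be wrapped in FSDP; the helper name clip_gradients is
hypothetical and is not added by this PR.

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


    def clip_gradients(model: torch.nn.Module, max_norm: float) -> None:
        # FSDP shards gradients across ranks, so clipping must go through the
        # FSDP-provided method, which computes the global norm collectively.
        if isinstance(model, FSDP):
            model.clip_grad_norm_(max_norm)
        else:
            # Non-FSDP case (e.g. TP-only): fall back to the standard PyTorch
            # utility, mirroring the else-branch added in the hunk above.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)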