diff --git a/train.py b/train.py
index 69cea633b..3f37a61b0 100644
--- a/train.py
+++ b/train.py
@@ -76,7 +76,8 @@ def build_optimizer(model, job_config: JobConfig):
 
 def build_grad_scaler(model):
     # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
-    if model.mixed_precision.param_dtype == torch.float16:
+    # NOTE: currently mixed precision training is supported only when FSDP is used
+    if isinstance(model, FSDP) and model.mixed_precision.param_dtype == torch.float16:
         enable_grad_scaling = True
         logger.info("Enabling gradient scaling for mixed precision training")
     else:
@@ -160,13 +161,12 @@ def main(job_config: JobConfig):
         model, world_mesh, parallel_dims, job_config
     )
 
-    # to use FSDP-customized gradient scaler and gradient clipping solutions
-    assert isinstance(model, FSDP)
-
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
 
+    # build grad scaler which is effective only when mixed precision training
+    # is enabled with fp16 param dtype under FSDP
     scaler = build_grad_scaler(model)
 
     metric_logger = build_metric_logger(job_config)
@@ -240,7 +240,12 @@ def main(job_config: JobConfig):
 
             # clip gradients (after unscaling gradients of the optimizer's params)
             scaler.unscale_(optimizer)
-            model.clip_grad_norm_(job_config.training.max_norm)
+            if isinstance(model, FSDP):
+                model.clip_grad_norm_(job_config.training.max_norm)
+            else:
+                torch.nn.utils.clip_grad_norm_(
+                    model.parameters(), job_config.training.max_norm
+                )
 
             # optimizer step
             # If gradients don't contain infs/NaNs, optimizer.step() is then called;
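
For context, below is a minimal, self-contained sketch of the scaler/clipping pattern this diff implements: gradient scaling is enabled only for FSDP models with an fp16 param dtype, gradients are unscaled before clipping, and clipping dispatches to FSDP's clip_grad_norm_ or to torch.nn.utils.clip_grad_norm_. The isinstance-based branching mirrors the patch; the use of ShardedGradScaler and the train_step helper are assumptions for illustration, not part of the diff.

# Illustrative sketch only; ShardedGradScaler and train_step are assumptions,
# the FSDP-vs-plain clipping branch follows the patch above.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def build_grad_scaler(model):
    # Gradient scaling is only needed when parameters are fp16, which the
    # patch assumes can happen only under FSDP mixed precision.
    enabled = (
        isinstance(model, FSDP)
        and model.mixed_precision is not None
        and model.mixed_precision.param_dtype == torch.float16
    )
    # With enabled=False the scaler degrades to a pass-through, so the same
    # training-loop code works for non-fp16 and non-FSDP models.
    return ShardedGradScaler(enabled=enabled)


def train_step(model, optimizer, scaler, loss, max_norm):
    # backward on the (possibly) scaled loss
    scaler.scale(loss).backward()

    # unscale first so max_norm is applied to the true gradient magnitudes
    scaler.unscale_(optimizer)
    if isinstance(model, FSDP):
        # FSDP's own clip_grad_norm_ accounts for gradients sharded across ranks
        model.clip_grad_norm_(max_norm)
    else:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    # step is skipped if any gradient contains infs/NaNs; update() adjusts the scale
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()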