diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 2925dad83607..6a065453aac6 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -119,7 +119,7 @@ def init_model_parallel(
     sharp: bool, nccl_communicator_config_path: str = None, distributed_timeout_minutes: int = 30
 ) -> None:
-    """ Initializes Megatron-LM model parallel if using model parallelism.
+    """Initializes Megatron-LM model parallel if using model parallelism.

     Args:
         sharp: Apply SHARP to NCCL data-parallel communication.
@@ -163,7 +163,7 @@ def init_model_parallel(


 class NLPDDPStrategy(DDPStrategy):
-    """ DDP plugin for Pytorch Lightning. Needed to customize DDP for model parallel models.
+    """DDP plugin for Pytorch Lightning. Needed to customize DDP for model parallel models.

     Args:
         no_ddp_communication_hook: Disable DDP communication hook when using AMP-O2
@@ -230,8 +230,8 @@ def setup_distributed(self, global_rank: int = None, world_size: int = None) ->
         )

     def configure_ddp(self):
-        """ Override LightningModule ddp if using model parallel.
-            Sets find_unused_parameters to False to use activation-checkpoint-recomputation.
+        """Override LightningModule ddp if using model parallel.
+        Sets find_unused_parameters to False to use activation-checkpoint-recomputation.
         """

         if (hasattr(self.model, 'megatron_amp_O2') and self.model.megatron_amp_O2) or (
@@ -405,7 +405,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
             self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=strict)

     def _fix_tensors_device(self, ckpt: Dict) -> Dict:
-        """ Ensure checkpoint tensors are on the correct device."""
+        """Ensure checkpoint tensors are on the correct device."""
         assert torch.cuda.is_initialized(), (torch.cuda.is_available(), torch.cuda.is_initialized())
         cur_dev = torch.device("cuda", index=torch.cuda.current_device())

@@ -417,10 +417,10 @@ def _fix_device(t):
         return dict_list_map_outplace(_fix_device, ckpt)

     def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
-        """ PTL method which we override to integrate distributed checkpoints for model parallel models.
-            In order to load distributed checkpoints we need to provide the sharded_state_dict to
-            the distributed load function. We get the sharded_state_dict from self.lightning_module
-            which makes it convenient to have the loading logic happen at the strategy level.
+        """PTL method which we override to integrate distributed checkpoints for model parallel models.
+        In order to load distributed checkpoints we need to provide the sharded_state_dict to
+        the distributed load function. We get the sharded_state_dict from self.lightning_module
+        which makes it convenient to have the loading logic happen at the strategy level.
         """
         fs = get_filesystem(checkpoint_path)
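Note on the load_checkpoint override above: the docstring describes strategy-level loading of distributed (sharded) checkpoints, where each rank builds a sharded_state_dict from the LightningModule and hands it to the distributed load function, which resolves only the shards that rank owns. A minimal sketch of that flow, assuming a module that exposes sharded_state_dict() and using distributed_load as a stand-in for the actual loader (illustrative, not NeMo's implementation):

from typing import Any, Callable, Dict


def strategy_load_checkpoint(
    lightning_module: Any,
    checkpoint_dir: str,
    distributed_load: Callable[[Dict[str, Any], str], Dict[str, Any]],
) -> Dict[str, Any]:
    # Each rank describes the shards it owns via the module's sharded_state_dict ...
    sharded_state_dict = lightning_module.sharded_state_dict()
    # ... and the distributed loader fills in exactly those shards from the checkpoint dir.
    return distributed_load(sharded_state_dict, checkpoint_dir)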
""" return True class NLPDDPStrategyNotebook(NLPDDPStrategy): - """ Version of NLPDDPStrategy to be used in a Jupyter Notebook + """Version of NLPDDPStrategy to be used in a Jupyter Notebook A large portion of Megatron code has DDP dependency, so it has been necessary to use NLPDDPStrategy even for single-GPU training (e.g. in a Jupyter notebook) A PTL 2.0 changes has prevented DDPStrategy to be used in a notebook. @@ -545,7 +545,7 @@ def _get_full_state_dict_context(module: torch.nn.Module, rank0_only: bool = Fal class NLPFSDPStrategy(FSDPStrategy): - """ FSDP plugin for Pytorch Lightning with the support for tensor-parallelism. + """FSDP plugin for Pytorch Lightning with the support for tensor-parallelism. Args: sharding_strategy: FSDP parameter sharding strategy. @@ -641,7 +641,11 @@ def _set_mixed_precision_recipe( reduce_dtype = utils_funcs.torch_dtype_from_precision(grad_reduce_dtype, None) if set_buffer_dtype is not None: buffer_dtype = utils_funcs.torch_dtype_from_precision(buffer_dtype, None) - return MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype,) + return MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype, + ) def setup_environment(self) -> None: """ @@ -752,7 +756,9 @@ def _get_osd(opt_state): with FSDP.summon_full_params(self.model, writeback=True, rank0_only=False): # rekey the osd stored from non-FSDP model rekeyed_osd = FSDP.rekey_optim_state_dict( - temp_osd, OptimStateKeyType.PARAM_NAME, self.model, + temp_osd, + OptimStateKeyType.PARAM_NAME, + self.model, ) temp_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, self.model) except Exception as e: @@ -760,7 +766,9 @@ def _get_osd(opt_state): exit(1) # Shard optimizer state dict sharded_osd = FSDP.optim_state_dict_to_load( - optim_state_dict=temp_osd, model=self.model, optim=optimizer, + optim_state_dict=temp_osd, + model=self.model, + optim=optimizer, ) optimizer.load_state_dict(sharded_osd) @@ -769,9 +777,9 @@ def _get_osd(opt_state): def save_checkpoint( self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None ) -> None: - """ Store checkpoints - 1. In case of sharded checkpoint, all ranks store unique checkpoints. - 2. In case of non-sharded checkpoint, all data-parallel rank 0 store checkpoints. + """Store checkpoints + 1. In case of sharded checkpoint, all ranks store unique checkpoints. + 2. In case of non-sharded checkpoint, all data-parallel rank 0 store checkpoints. """ app_state = AppState() filepath = inject_model_parallel_rank(filepath, fsdp_sharded_ckpt=self.sharded_checkpoint) @@ -782,8 +790,7 @@ def save_checkpoint( self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: - """ Load checkpoints - """ + """Load checkpoints""" # 1. Load normal or FSDP-sharded checkpoints. 
@@ -782,8 +790,7 @@ def save_checkpoint(
             self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)

     def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
-        """ Load checkpoints
-        """
+        """Load checkpoints"""
         # 1. Load normal or FSDP-sharded checkpoints.
         fs = get_filesystem(checkpoint_path)
@@ -800,8 +807,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
         return checkpoint

     def remove_checkpoint(self, filepath: Union[str, Path]) -> None:
-        """ Remove checkpoints
-        """
+        """Remove checkpoints"""
         # legacy checkpoint logic, does not use megatron core
         app_state = AppState()
         # PTL override to accomodate model parallel checkpoints
@@ -816,9 +822,9 @@ def remove_checkpoint(self, filepath: Union[str, Path]) -> None:

     @property
     def restore_checkpoint_after_setup(self) -> bool:
-        """ When loading FSDP-sharded checkpoint, need to restore checkpoint after configuring
-            FSDP sharding to match FSDP-sharded format between the checkpoint and the current
-            model and optimizer.
+        """When loading FSDP-sharded checkpoint, need to restore checkpoint after configuring
+        FSDP sharding to match FSDP-sharded format between the checkpoint and the current
+        model and optimizer.
         """
         return True

@@ -917,7 +923,8 @@ def dummy():
             else:
                 # move weights to the tmpdir
                 for tp_rank, pp_rank in itertools.product(
-                    range(app_state.tensor_model_parallel_size), range(app_state.pipeline_model_parallel_size),
+                    range(app_state.tensor_model_parallel_size),
+                    range(app_state.pipeline_model_parallel_size),
                 ):
                     os.makedirs(os.path.join(tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}'))
                     mp_model_weights = os.path.join(
@@ -1002,6 +1009,7 @@ def modify_state_dict(self, conf, state_dict):
         loaded_keys = state_dict.keys()
         if 'model.model.diffusion_model.input_blocks.1.0.in_layers.2.weight' in loaded_keys:
             new_state_dict = {}
+            # GroupNormOpt fuses activation function to one layer, thus the indexing of weights are shifted for following

             def should_process(key):
                 base_str = "model.model.diffusion_model."
@@ -1112,7 +1120,13 @@ def restore_from(
        # Get path where the command is executed - the artifacts will be "retrieved" there
        # (original .nemo behavior)
        loaded_params = super().load_config_and_state_dict(
-            calling_cls, restore_path, override_config_path, map_location, strict, return_config, trainer,
+            calling_cls,
+            restore_path,
+            override_config_path,
+            map_location,
+            strict,
+            return_config,
+            trainer,
        )
        if not isinstance(loaded_params, tuple) or return_config is True:
            return loaded_params
@@ -1167,12 +1181,12 @@ def dummy():


 class PipelineMixedPrecisionPlugin(MixedPrecisionPlugin):
-    """ Overrides PTL autocasting to not wrap training/val/test_step.
-        We do this because we have the megatron-core fwd/bwd functions in training_step.
-        This means .backward is being called in training_step so we do not want the whole
-        step wrapped in autocast.
+    """Overrides PTL autocasting to not wrap training/val/test_step.
+    We do this because we have the megatron-core fwd/bwd functions in training_step.
+    This means .backward is being called in training_step so we do not want the whole
+    step wrapped in autocast.

-        We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions.
+    We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions.
     """

     def __init__(
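Note on PipelineMixedPrecisionPlugin above: because the megatron-core fwd/bwd functions call .backward inside training_step, autocast is applied only to the forward/loss callable rather than to the whole step. A sketch of that wrapping pattern using plain torch.autocast (illustrative; the plugin's actual mechanics live elsewhere in this file):

from typing import Callable

import torch


def wrap_forward_in_autocast(fwd_output_and_loss_func: Callable, dtype: torch.dtype) -> Callable:
    def wrapped(*args, **kwargs):
        # Autocast covers only the forward pass and loss computation; the backward
        # pass triggered later by the fwd/bwd functions runs outside autocast.
        with torch.autocast(device_type="cuda", dtype=dtype):
            return fwd_output_and_loss_func(*args, **kwargs)

    return wrapped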
+ """Overrides PTL autocasting to not wrap training/val/test_step. + We do this because we have the megatron-core fwd/bwd functions in training_step. + This means .backward is being called in training_step so we do not want the whole + step wrapped in autocast. - We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions. + We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions. """ def __init__( @@ -1248,7 +1262,7 @@ class GradScaler(torch.cuda.amp.GradScaler): def __init__( self, - init_scale=2.0 ** 16, + init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, @@ -1502,7 +1516,7 @@ def optimizer_step( @contextmanager def forward_context(self) -> Generator[None, None, None]: - """ No explicit precision casting. Inputs are supposed to be manually casted """ + """No explicit precision casting. Inputs are supposed to be manually casted""" try: yield finally: @@ -1510,7 +1524,7 @@ def forward_context(self) -> Generator[None, None, None]: class GlobalBatchDataFetcher(_DataFetcher): - """ Overrides PTL DataFetcher. Used to fetch global batches.""" + """Overrides PTL DataFetcher. Used to fetch global batches.""" def __init__(self, prefetch_batches: int = 0, store_on_device: bool = False) -> None: