diff --git a/Dockerfile b/Dockerfile
index d796ef055558..d68c4a885886 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -34,11 +34,11 @@ RUN apt-get update && \
 
 WORKDIR /tmp/
 
-# TODO: Remove once this Apex commit (1/19/23) is included in PyTorch
+# TODO: Remove once this Apex commit is included in PyTorch
 # container
-RUN git clone https://github.com/NVIDIA/apex.git && \
+RUN git clone https://github.com/timmoon10/apex.git && \
     cd apex && \
-    git checkout c0a0b0f69d2d5a98bd141be12ee8e5eebd3ec7ca && \
+    git checkout dist-adam-bf16-grad-reductions && \
     pip3 install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./
 
 # uninstall stuff from base container
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index 7f2c0befce6b..9088b115ab1c 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -330,20 +330,29 @@ def setup_optimization(
             optim_kwargs['contiguous_grad_buffer'] = True
             optim_kwargs['contiguous_param_buffer'] = True
 
-            if self.megatron_amp_o2:
-                # Match param allgather with model dtype
-                if hasattr(self, 'autocast_dtype'):
-                    optim_kwargs['param_sync_dtype'] = self.autocast_dtype
-                    if self.autocast_dtype == torch.float:
-                        optim_kwargs['store_params'] = False
-                    elif self.autocast_dtype == torch.float16:
-                        optim_kwargs['store_params'] = True
-                    elif self.autocast_dtype == torch.bfloat16:
-                        optim_kwargs['store_params'] = False
-                        optim_kwargs['store_param_remainders'] = True
-            else:
-                # Assume FP32 params, so no need to store main params
+            # Make sure optimizer state is in FP32
+            optim_dtype = torch.float32
+            optim_kwargs['dtype'] = optim_dtype
+
+            # Make sure embedding grad reductions are in FP32
+            for name, param in self.named_parameters():
+                if 'word_embedding' in name or 'position_embedding' in name:
+                    param._with_fp32_optimizer = True
+
+            # Match param allgather with model dtype
+            model_dtype = torch.float32
+            if self.megatron_amp_o2 and hasattr(self, 'autocast_dtype'):
+                model_dtype = self.autocast_dtype
+            optim_kwargs['param_sync_dtype'] = model_dtype
+
+            # Determine whether to store master params in optimizer
+            if optim_dtype == model_dtype:
                 optim_kwargs['store_params'] = False
+            elif optim_dtype == torch.float32 and model_dtype == torch.bfloat16:
+                optim_kwargs['store_params'] = False
+                optim_kwargs['store_param_remainders'] = True
+            else:
+                optim_kwargs['store_params'] = True
 
         return super().setup_optimization(optim_config=optim_config, optim_kwargs=optim_kwargs)
 
diff --git a/nemo/collections/nlp/modules/common/megatron/clip_grads.py b/nemo/collections/nlp/modules/common/megatron/clip_grads.py
index cdcb8b03e810..c1c4bb68e80a 100644
--- a/nemo/collections/nlp/modules/common/megatron/clip_grads.py
+++ b/nemo/collections/nlp/modules/common/megatron/clip_grads.py
@@ -173,7 +173,6 @@ def clip_grad_norm_distributed_optimizer(optimizer, max_norm, norm_type=2):
         Total norm of the parameters (viewed as a single vector).
""" - assert norm_type == 2 assert isinstance(optimizer, DistributedFusedAdam) # Filter parameters based on: @@ -188,20 +187,8 @@ def clip_grad_norm_distributed_optimizer(optimizer, max_norm, norm_type=2): params_for_norm.append(param) # Compute grad norm - # Note: Compute norm of local grads and sum over all procs - grad_norm_sq = optimizer._local_grad_norm(parameters=params_for_norm, norm_type=norm_type) - if optimizer.redundant_size > 1: - grad_norm_sq /= optimizer.redundant_size - torch.distributed.all_reduce( - grad_norm_sq, op=torch.distributed.ReduceOp.SUM, - ) - grad_norm = grad_norm_sq.sqrt() - - # Apply gradient clipping - # Note: DistributedFusedAdam is only aware of the data-parallel - # process group, so we cannot directly apply its gradient clipping - # function. However, it caches the grad norm to avoid redundant - # communication, so it suffices to overwrite the cache with the - # grad norm computed over the world parallel group. - optimizer._grad_norm = grad_norm + # Note: DistributedFusedAdam caches grad norm to avoid redundant + # communication. + optimizer.grad_norm(parameters=params_for_norm, norm_type=norm_type) + return optimizer.clip_grad_norm(max_norm, norm_type=norm_type) diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index ae10fc51823a..d1fddf7b046d 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -16,11 +16,31 @@ from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam, _disable_pre_forward_hook from apex.transformer import parallel_state +def _str_to_dtype(dtype): + if isinstance(dtype, torch.dtype): + return dtype + name = str(dtype).strip().lower() + if name in ('', 'none'): + return torch.float32 + elif name in ('torch.float32', 'float32', 'float', 'fp32', '32'): + return torch.float32 + elif name in ('torch.float16', 'float16', 'half', 'fp16', '16'): + return torch.float16 + elif name in ('torch.bfloat16', 'bfloat16', 'bf16'): + return torch.bfloat16 + else: + raise ValueError(f'unsupported dtype ({dtype})') -# Wrapper class that supports main_grad buffer -# Note: main_grad buffer is used for O2-style optimizations class MegatronDistributedFusedAdam(DistributedFusedAdam): - def __init__(self, *args, disable_distributed_parameters=False, **kwargs): + """Wrapper class that supports NeMo-Megatron optimizations + + When O2-style optimizations are enabled, gradients are accumulated + into the main_grad buffer instead of the grad buffer. 
+ + """ + def __init__(self, params, disable_distributed_parameters=False, **kwargs): + + # Initialize process groups if 'process_group' not in kwargs and not parallel_state.is_unitialized(): kwargs['process_group'] = parallel_state.get_data_parallel_group() if disable_distributed_parameters: @@ -29,7 +49,46 @@ def __init__(self, *args, disable_distributed_parameters=False, **kwargs): self_groups = [torch.distributed.new_group(ranks=[i]) for i in range(world_size)] kwargs['distributed_process_group'] = self_groups[rank] kwargs['redundant_process_group'] = kwargs['process_group'] - super().__init__(*args, **kwargs) + + # Make sure dtypes are in right type + for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'): + if keyword in kwargs: + kwargs[keyword] = _str_to_dtype(kwargs[keyword]) + + # Check if any parameters require an explicit FP32 optimizer + self._fp32_optim = None + distopt_params = params + dtype = kwargs['dtype'] if 'dtype' in kwargs else torch.float32 + grad_sync_dtype = kwargs['grad_sync_dtype'] if 'grad_sync_dtype' in kwargs else dtype + if ( + (dtype != torch.float32 or grad_sync_dtype != torch.float32) + and any(getattr(param, '_with_fp32_optimizer', False) for param in params) + ): + + # Find params that require explicit FP32 optimizer + self._fp32_optim_model_params = [] + self._fp32_optim_main_params = [] + distopt_params = [] + for model_param in params: + if getattr(param, '_with_fp32_optimizer', False): + main_param = param.detach().clone().float() + self._fp32_optim_model_params.append(model_param) + self._fp32_optim_main_params.append(main_param) + else: + distopt_params.append(model_param) + + # Construct explicit FP32 optimizer + adamw_kwargs = {} + for name in ('lr', 'betas', 'eps', 'weight_decay', 'amsgrad'): + if name in kwargs: + adamw_kwargs[name] = kwargs[name] + self.fp32_optim = torch.optim.AdamW( + self._fp32_optim_main_params, + **adamw_kwargs, + ) + + # Construct distributed optimizer + super().__init__(distopt_params, **kwargs) def _make_post_backward_hook(self, param, param_group_id, param_id): def hook(*unused): @@ -60,9 +119,104 @@ def try_grad_sync(self, params): self._grad_copy(p) self._try_start_bucket_grad_sync(params=params) + def _fp32_optim_grad_sync(self): + if self._fp32_optim is None: + return + for model_param, main_param in zip(self._fp32_optim_model_params, self._fp32_optim_main_params): + if main_param.grad is None: + main_param.grad = model_param.grad.detach().clone().float() + torch.distributed.all_reduce(main_param.grad, group=self.process_group) + def zero_grad(self, *args, **kwargs): super().zero_grad(*args, **kwargs) + + # Reset grads for explicit FP32 optimizer + if self._fp32_optim is not None: + self._fp32_optim.zero_grad(set_to_none=True) + for param in self._fp32_optim_model_params: + param.grad = None + + # Reset main grads if self.contiguous_grad_buffer: for param in self.parameters(): with _disable_pre_forward_hook(param): param.main_grad = self.grad_buffer_view(param) + + def grad_norm(self, parameters=None, norm_type=2.0, force=False): + assert norm_type == 2 + + # Compute grad norm + if force or self._grad_norm is None: + + # Identify params for explicit FP32 optimizer + if self._fp32_optim is not None: + if parameters is None: + fp32_optim_params = self._fp32_optim_model_params + else: + fp32_optim_params = [ + param for param in parameters + if param in self._fp32_optim_model_params + ] + parameters = [ + param for param in parameters + if param not in self._fp32_optim_model_params + ] + + # Compute norm of 
local gradients for distributed optimizer + grad_norm_sq = self._local_grad_norm(parameters=parameters, norm_type=norm_type) + if self.redundant_size > 1: + grad_norm_sq /= self.redundant_size + + # Compute norm of local gradients for explicit FP32 optimizer + if self._fp32_optim is not None: + _fp32_optim_grad_sync() + for model_param in fp32_optim_params: + i = self._fp32_optim_model_params.index(model_param) + main_param = self._fp32_optim_main_params[i] + grad_norm_sq += torch.linalg.norm(main_param.grad) ** 2 / self.process_group_size + + # Sum over all procs to get grad norm + torch.distributed.all_reduce( + grad_norm_sq, op=torch.distributed.ReduceOp.SUM, + ) + self._grad_norm = grad_norm_sq.sqrt() + + # Use cached grad norm + return super().grad_norm() + + def step(self, closure=None, *, grad_scaler=None): + + # Apply distributed optimizer + loss = super().step(closure=closure, grad_scaler=grad_scaler) + + if self._fp32_optim is not None: + + # Handle grad scaling + if grad_scaler is not None: + scaler_state = grad_scaler._per_optimizer_states[id(self)] + for _, found_inf in scaler_state['found_inf_per_device'].items(): + if found_inf.item(): + return + + # Apply explicit FP32 optimizer + self._fp32_optim_grad_sync() + for main_param in self._fp32_optim_main_params: + main_param.grad *= self._grad_scale + self._fp32_optim.step() + for model_param, main_param in zip(self._fp32_optim_model_params, self._fp32_optim_main_params): + main_param.grad = None + model_param.copy_(main_param.detach()) + + return loss + + def state_dict(self, *args, **kwargs): + state_dict = super().state_dict(*args, **kwargs) + if self._fp32_optim is not None and state_dict is not None: + state_dict['fp32_optim'] = self._fp32_optim.state_dict() + return state_dict + + def load_state_dict(self, state_dict): + if self._fp32_optim is not None and 'fp32_optim' in state_dict: + self._fp32_optim.load_state_dict(state_dict['fp32_optim']) + del state_dict['fp32_optim'] + return super().load_state_dict(state_dict)
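
Note (not part of the patch): a minimal sketch, using only the public PyTorch API, of how the store_params / store_param_remainders selection added to setup_optimization behaves for the common dtype combinations. The helper name choose_param_store_flags is hypothetical and exists only for illustration; the real code builds these flags directly into optim_kwargs.

    import torch

    def choose_param_store_flags(optim_dtype, model_dtype):
        # Hypothetical helper mirroring the decision logic added to
        # setup_optimization above; optimizer state is assumed to be FP32.
        kwargs = {'dtype': optim_dtype, 'param_sync_dtype': model_dtype}
        if optim_dtype == model_dtype:
            # Optimizer state and model params share a dtype, so the
            # optimizer does not need to keep separate master params
            kwargs['store_params'] = False
        elif optim_dtype == torch.float32 and model_dtype == torch.bfloat16:
            # BF16 is the upper 16 bits of FP32, so storing only the lower
            # 16 "remainder" bits per element is enough to rebuild FP32 masters
            kwargs['store_params'] = False
            kwargs['store_param_remainders'] = True
        else:
            # e.g. FP16 model params with FP32 optimizer state: keep full master copies
            kwargs['store_params'] = True
        return kwargs

    # BF16 O2 training: FP32 optimizer state, BF16 param all-gather
    print(choose_param_store_flags(torch.float32, torch.bfloat16))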