Add support for BF16 grad reductions with distopt
Signed-off-by: Tim Moon <[email protected]>
timmoon10 committed Feb 9, 2023
1 parent ce1bf22 commit 59429ce
Showing 4 changed files with 187 additions and 37 deletions.
6 changes: 3 additions & 3 deletions Dockerfile
@@ -34,11 +34,11 @@ RUN apt-get update && \

WORKDIR /tmp/

# TODO: Remove once this Apex commit (1/19/23) is included in PyTorch
# TODO: Remove once this Apex commit is included in PyTorch
# container
RUN git clone https://github.com/NVIDIA/apex.git && \
RUN git clone https://github.com/timmoon10/apex.git && \
cd apex && \
git checkout c0a0b0f69d2d5a98bd141be12ee8e5eebd3ec7ca && \
git checkout dist-adam-bf16-grad-reductions && \
pip3 install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./

# uninstall stuff from base container
@@ -330,20 +330,29 @@ def setup_optimization(
            optim_kwargs['contiguous_grad_buffer'] = True
            optim_kwargs['contiguous_param_buffer'] = True

            if self.megatron_amp_o2:
                # Match param allgather with model dtype
                if hasattr(self, 'autocast_dtype'):
                    optim_kwargs['param_sync_dtype'] = self.autocast_dtype
                    if self.autocast_dtype == torch.float:
                        optim_kwargs['store_params'] = False
                    elif self.autocast_dtype == torch.float16:
                        optim_kwargs['store_params'] = True
                    elif self.autocast_dtype == torch.bfloat16:
                        optim_kwargs['store_params'] = False
                        optim_kwargs['store_param_remainders'] = True
            else:
                # Assume FP32 params, so no need to store main params
                optim_kwargs['store_params'] = False
            # Make sure optimizer state is in FP32
            optim_dtype = torch.float32
            optim_kwargs['dtype'] = optim_dtype

            # Make sure embedding grad reductions are in FP32
            for name, param in self.named_parameters():
                if 'word_embedding' in name or 'position_embedding' in name:
                    param._with_fp32_optimizer = True

            # Match param allgather with model dtype
            model_dtype = torch.float32
            if self.megatron_amp_o2 and hasattr(self, 'autocast_dtype'):
                model_dtype = self.autocast_dtype
            optim_kwargs['param_sync_dtype'] = model_dtype

            # Determine whether to store master params in optimizer
            if optim_dtype == model_dtype:
                optim_kwargs['store_params'] = False
            elif optim_dtype == torch.float32 and model_dtype == torch.bfloat16:
                optim_kwargs['store_params'] = False
                optim_kwargs['store_param_remainders'] = True
            else:
                optim_kwargs['store_params'] = True

        return super().setup_optimization(optim_config=optim_config, optim_kwargs=optim_kwargs)

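For illustration only (not part of the commit; the function name is made up), the store_params / store_param_remainders decision above can be read as a small standalone helper: with FP32 optimizer state, a BF16 model only needs the low 16 bits of each FP32 master value stored, since BF16 is the upper half of FP32, while an FP16 model still needs full FP32 master params.

import torch

def choose_distopt_flags(optim_dtype, model_dtype):
    # Mirrors the store_params / store_param_remainders logic above
    flags = {'dtype': optim_dtype, 'param_sync_dtype': model_dtype}
    if optim_dtype == model_dtype:
        # Optimizer state already matches the model; no master copy needed
        flags['store_params'] = False
    elif optim_dtype == torch.float32 and model_dtype == torch.bfloat16:
        # BF16 is the upper 16 bits of FP32, so storing the lower 16 bits
        # ("remainders") is enough to reconstruct FP32 master params
        flags['store_params'] = False
        flags['store_param_remainders'] = True
    else:
        # e.g. FP16 model with FP32 optimizer state: keep full master params
        flags['store_params'] = True
    return flags

print(choose_distopt_flags(torch.float32, torch.bfloat16))
# {'dtype': torch.float32, 'param_sync_dtype': torch.bfloat16,
#  'store_params': False, 'store_param_remainders': True}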
21 changes: 4 additions & 17 deletions nemo/collections/nlp/modules/common/megatron/clip_grads.py
@@ -173,7 +173,6 @@ def clip_grad_norm_distributed_optimizer(optimizer, max_norm, norm_type=2):
        Total norm of the parameters (viewed as a single vector).
    """
    assert norm_type == 2
    assert isinstance(optimizer, DistributedFusedAdam)

    # Filter parameters based on:
@@ -188,20 +187,8 @@ def clip_grad_norm_distributed_optimizer(optimizer, max_norm, norm_type=2):
            params_for_norm.append(param)

    # Compute grad norm
    # Note: Compute norm of local grads and sum over all procs
    grad_norm_sq = optimizer._local_grad_norm(parameters=params_for_norm, norm_type=norm_type)
    if optimizer.redundant_size > 1:
        grad_norm_sq /= optimizer.redundant_size
    torch.distributed.all_reduce(
        grad_norm_sq, op=torch.distributed.ReduceOp.SUM,
    )
    grad_norm = grad_norm_sq.sqrt()

    # Apply gradient clipping
    # Note: DistributedFusedAdam is only aware of the data-parallel
    # process group, so we cannot directly apply its gradient clipping
    # function. However, it caches the grad norm to avoid redundant
    # communication, so it suffices to overwrite the cache with the
    # grad norm computed over the world parallel group.
    optimizer._grad_norm = grad_norm
    # Note: DistributedFusedAdam caches grad norm to avoid redundant
    # communication.
    optimizer.grad_norm(parameters=params_for_norm, norm_type=norm_type)

    return optimizer.clip_grad_norm(max_norm, norm_type=norm_type)
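As a hedged usage sketch (not from the commit; `optimizer` and `params_for_norm` are the objects used in the function above), the caching works like this: the explicit grad_norm() call computes and caches the norm for the chosen parameter subset, and clip_grad_norm() then reuses the cached value rather than recomputing it.

# Compute and cache the grad norm over the filtered parameter subset
optimizer.grad_norm(parameters=params_for_norm, norm_type=2)
# Clipping reuses the cached norm, so no extra norm computation is needed
total_norm = optimizer.clip_grad_norm(1.0, norm_type=2)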
162 changes: 158 additions & 4 deletions nemo/core/optim/distributed_adam.py
@@ -16,11 +16,31 @@
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam, _disable_pre_forward_hook
from apex.transformer import parallel_state

def _str_to_dtype(dtype):
    if isinstance(dtype, torch.dtype):
        return dtype
    name = str(dtype).strip().lower()
    if name in ('', 'none'):
        return torch.float32
    elif name in ('torch.float32', 'float32', 'float', 'fp32', '32'):
        return torch.float32
    elif name in ('torch.float16', 'float16', 'half', 'fp16', '16'):
        return torch.float16
    elif name in ('torch.bfloat16', 'bfloat16', 'bf16'):
        return torch.bfloat16
    else:
        raise ValueError(f'unsupported dtype ({dtype})')
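# Illustrative examples (not part of the diff): string-valued dtype configs
# normalize to torch dtypes before being handed to DistributedFusedAdam, e.g.
#   _str_to_dtype('bf16') == torch.bfloat16
#   _str_to_dtype('torch.float16') == torch.float16
#   _str_to_dtype(None) == torch.float32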

# Wrapper class that supports main_grad buffer
# Note: main_grad buffer is used for O2-style optimizations
class MegatronDistributedFusedAdam(DistributedFusedAdam):
    def __init__(self, *args, disable_distributed_parameters=False, **kwargs):
    """Wrapper class that supports NeMo-Megatron optimizations

    When O2-style optimizations are enabled, gradients are accumulated
    into the main_grad buffer instead of the grad buffer.
    """
    def __init__(self, params, disable_distributed_parameters=False, **kwargs):

        # Initialize process groups
        if 'process_group' not in kwargs and not parallel_state.is_unitialized():
            kwargs['process_group'] = parallel_state.get_data_parallel_group()
        if disable_distributed_parameters:
Expand All @@ -29,7 +49,46 @@ def __init__(self, *args, disable_distributed_parameters=False, **kwargs):
            self_groups = [torch.distributed.new_group(ranks=[i]) for i in range(world_size)]
            kwargs['distributed_process_group'] = self_groups[rank]
            kwargs['redundant_process_group'] = kwargs['process_group']
        super().__init__(*args, **kwargs)

        # Make sure dtypes are in right type
        for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'):
            if keyword in kwargs:
                kwargs[keyword] = _str_to_dtype(kwargs[keyword])

        # Check if any parameters require an explicit FP32 optimizer
        self._fp32_optim = None
        distopt_params = params
        dtype = kwargs['dtype'] if 'dtype' in kwargs else torch.float32
        grad_sync_dtype = kwargs['grad_sync_dtype'] if 'grad_sync_dtype' in kwargs else dtype
        if (
            (dtype != torch.float32 or grad_sync_dtype != torch.float32)
            and any(getattr(param, '_with_fp32_optimizer', False) for param in params)
        ):

            # Find params that require explicit FP32 optimizer
            self._fp32_optim_model_params = []
            self._fp32_optim_main_params = []
            distopt_params = []
            for model_param in params:
                if getattr(model_param, '_with_fp32_optimizer', False):
                    main_param = model_param.detach().clone().float()
                    self._fp32_optim_model_params.append(model_param)
                    self._fp32_optim_main_params.append(main_param)
                else:
                    distopt_params.append(model_param)

            # Construct explicit FP32 optimizer
            adamw_kwargs = {}
            for name in ('lr', 'betas', 'eps', 'weight_decay', 'amsgrad'):
                if name in kwargs:
                    adamw_kwargs[name] = kwargs[name]
            self._fp32_optim = torch.optim.AdamW(
                self._fp32_optim_main_params,
                **adamw_kwargs,
            )

        # Construct distributed optimizer
        super().__init__(distopt_params, **kwargs)

    def _make_post_backward_hook(self, param, param_group_id, param_id):
        def hook(*unused):
@@ -60,9 +119,104 @@ def try_grad_sync(self, params):
            self._grad_copy(p)
        self._try_start_bucket_grad_sync(params=params)

    def _fp32_optim_grad_sync(self):
        if self._fp32_optim is None:
            return
        for model_param, main_param in zip(self._fp32_optim_model_params, self._fp32_optim_main_params):
            if main_param.grad is None:
                main_param.grad = model_param.grad.detach().clone().float()
                torch.distributed.all_reduce(main_param.grad, group=self.process_group)

    def zero_grad(self, *args, **kwargs):
        super().zero_grad(*args, **kwargs)

        # Reset grads for explicit FP32 optimizer
        if self._fp32_optim is not None:
            self._fp32_optim.zero_grad(set_to_none=True)
            for param in self._fp32_optim_model_params:
                param.grad = None

        # Reset main grads
        if self.contiguous_grad_buffer:
            for param in self.parameters():
                with _disable_pre_forward_hook(param):
                    param.main_grad = self.grad_buffer_view(param)

    def grad_norm(self, parameters=None, norm_type=2.0, force=False):
        assert norm_type == 2

        # Compute grad norm
        if force or self._grad_norm is None:

            # Identify params for explicit FP32 optimizer
            if self._fp32_optim is not None:
                if parameters is None:
                    fp32_optim_params = self._fp32_optim_model_params
                else:
                    fp32_optim_params = [
                        param for param in parameters
                        if param in self._fp32_optim_model_params
                    ]
                    parameters = [
                        param for param in parameters
                        if param not in self._fp32_optim_model_params
                    ]

            # Compute norm of local gradients for distributed optimizer
            grad_norm_sq = self._local_grad_norm(parameters=parameters, norm_type=norm_type)
            if self.redundant_size > 1:
                grad_norm_sq /= self.redundant_size

            # Compute norm of local gradients for explicit FP32 optimizer
            if self._fp32_optim is not None:
                self._fp32_optim_grad_sync()
                for model_param in fp32_optim_params:
                    i = self._fp32_optim_model_params.index(model_param)
                    main_param = self._fp32_optim_main_params[i]
                    grad_norm_sq += torch.linalg.norm(main_param.grad) ** 2 / self.process_group_size

            # Sum over all procs to get grad norm
            torch.distributed.all_reduce(
                grad_norm_sq, op=torch.distributed.ReduceOp.SUM,
            )
            self._grad_norm = grad_norm_sq.sqrt()

        # Use cached grad norm
        return super().grad_norm()

    def step(self, closure=None, *, grad_scaler=None):

        # Apply distributed optimizer
        loss = super().step(closure=closure, grad_scaler=grad_scaler)

        if self._fp32_optim is not None:

            # Handle grad scaling
            if grad_scaler is not None:
                scaler_state = grad_scaler._per_optimizer_states[id(self)]
                for _, found_inf in scaler_state['found_inf_per_device'].items():
                    if found_inf.item():
                        return

            # Apply explicit FP32 optimizer
            self._fp32_optim_grad_sync()
            for main_param in self._fp32_optim_main_params:
                main_param.grad *= self._grad_scale
            self._fp32_optim.step()
            for model_param, main_param in zip(self._fp32_optim_model_params, self._fp32_optim_main_params):
                main_param.grad = None
                model_param.copy_(main_param.detach())

        return loss

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        if self._fp32_optim is not None and state_dict is not None:
            state_dict['fp32_optim'] = self._fp32_optim.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        if self._fp32_optim is not None and 'fp32_optim' in state_dict:
            self._fp32_optim.load_state_dict(state_dict['fp32_optim'])
            del state_dict['fp32_optim']
        return super().load_state_dict(state_dict)
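Putting it together, a minimal sketch of how this wrapper might be configured for BF16 grad reductions. The kwargs are the ones appearing in this diff; the hyperparameter values and the assumption that grad_sync_dtype arrives as a string from the optimizer config are illustrative, not taken from the commit.

import torch

# Optimizer kwargs as NeMo's setup_optimization() above would assemble them
# for a BF16 model, plus a config-style grad_sync_dtype string that the
# constructor normalizes via _str_to_dtype (values are illustrative).
optim_kwargs = {
    'lr': 1e-4,
    'weight_decay': 0.01,
    'contiguous_grad_buffer': True,
    'contiguous_param_buffer': True,
    'dtype': torch.float32,              # optimizer state stays in FP32
    'grad_sync_dtype': 'bf16',           # gradient reductions run in BF16
    'param_sync_dtype': torch.bfloat16,  # param all-gather matches model dtype
    'store_params': False,
    'store_param_remainders': True,      # FP32 master = BF16 param + remainder
}

# Parameters tagged with `_with_fp32_optimizer` (e.g. word/position embeddings
# in setup_optimization() above) are peeled off into a plain torch.optim.AdamW
# whose gradients stay in FP32, while everything else goes to the distributed
# optimizer with BF16 grad reductions:
#
#   optimizer = MegatronDistributedFusedAdam(model.parameters(), **optim_kwargs)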
