NeVA Minor Fixes (#9608)

* fix NeVA resume with empty params loaded for some PP stages

Signed-off-by: yaoyu-33 <[email protected]>

* fix crop size check

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

---------

Signed-off-by: yaoyu-33 <[email protected]>
Signed-off-by: yaoyu-33 <[email protected]>
Co-authored-by: yaoyu-33 <[email protected]>
Signed-off-by: Tugrul Konuk <[email protected]>
2 people authored and ertkonuk committed Jul 19, 2024
1 parent 82c529f commit e79f049
Showing 2 changed files with 13 additions and 6 deletions.
4 changes: 2 additions & 2 deletions nemo/collections/multimodal/parts/utils.py
@@ -525,8 +525,8 @@ def create_image_processor(mm_cfg):
     else:
         raise (ValueError("Currently only support CLIPImageProcessor and SiglipImageProcessor from Huggingface"))
 
-    crop_size = mm_cfg.vision_encoder.get("crop_size", (224, 224))
-    if hasattr(image_processor, 'crop_size'):
+    crop_size = mm_cfg.vision_encoder.get("crop_size")
+    if hasattr(image_processor, 'crop_size') and crop_size is not None:
         assert crop_size == (
             image_processor.crop_size['height'],
             image_processor.crop_size['width'],
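
For context, a minimal runnable sketch of the guarded check above. The _VisionEncoderCfg and _ImageProcessor classes and the 336x336 size are stand-ins assumed for illustration, not the real NeMo config or Huggingface processor types.

# Illustrative sketch of the new crop-size check; the classes below are assumed stand-ins.
class _VisionEncoderCfg(dict):
    """Stand-in for mm_cfg.vision_encoder; supports .get() like the real config."""


class _ImageProcessor:
    crop_size = {'height': 336, 'width': 336}


def check_crop_size(vision_encoder_cfg, image_processor):
    # No implicit (224, 224) default any more: only an explicitly configured value is read.
    crop_size = vision_encoder_cfg.get("crop_size")
    # The assertion is skipped when crop_size is not set, so the processor's own size is accepted as-is.
    if hasattr(image_processor, 'crop_size') and crop_size is not None:
        assert crop_size == (
            image_processor.crop_size['height'],
            image_processor.crop_size['width'],
        ), "configured crop_size must match the image processor's crop_size"


check_crop_size(_VisionEncoderCfg(), _ImageProcessor())                       # passes: nothing configured
check_crop_size(_VisionEncoderCfg(crop_size=(336, 336)), _ImageProcessor())  # passes: sizes match
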
15 changes: 11 additions & 4 deletions nemo/core/optim/optimizer_with_main_params.py
Expand Up @@ -119,7 +119,7 @@ def zero(self):
         self.data.zero_()
 
     def allreduce_buffer(self):
-        """Synchronous buffer data allreduce """
+        """Synchronous buffer data allreduce"""
         self.data.div_(get_data_parallel_world_size())
         torch.distributed.all_reduce(self.data, group=self._data_group)

@@ -175,7 +175,7 @@ class MainParamsOptimizerWrapper(torch.optim.Optimizer):
     Arguments:
         optimizer: base optimizer such as Adam or SGD.
         fp32_grad_accum: to enable the use of fp32 in gradient accumulation and allreduce.
-        contiguous_grad_bucket: to enable allocating the master gradients in the 
+        contiguous_grad_bucket: to enable allocating the master gradients in the
             contiguous memory space to reduce memory fragmentation.
         async_grad_allreduce: enable asynchronous gradient allreduce that is executed
             along with the training step backprop.
@@ -339,6 +339,7 @@ def __init__(

     def _make_param_hook(self, param, main_param, i, grad_chunk_info, is_expert_group):
         """Create the grad accumulation and all-reduce hook for backprop."""
+
         # Hook used for back-prop.
         def param_hook(*unused):
             # Accumulates gradients on main gradients
@@ -361,7 +362,9 @@ def allreduce_grads(use_fused_div, tensor, data_group, grad_mult):
             else:
                 tensor.div_(grad_mult)
                 torch.distributed.all_reduce(
-                    tensor, group=data_group, async_op=True,
+                    tensor,
+                    group=data_group,
+                    async_op=True,
                 )
 
         # Asynchronous gradients allreduce accross data_parallel ranks
@@ -473,12 +476,16 @@ def load_state_dict(self, state_dict):
         if optimizer_key not in state_dict:
             optimizer_key = 'optimizer_state_dict'
             logging.info('***WARNING*** loading optimizer from ' 'an old checkpoint ...')
+        if 'state' not in state_dict[optimizer_key]:
+            state_dict[optimizer_key]['state'] = {}
         self.optimizer.load_state_dict(state_dict[optimizer_key])
 
         # Copy data for the main params.
         fp32_from_float16_params_key = 'fp32_from_fp16_params'
         if fp32_from_float16_params_key not in state_dict:
             fp32_from_float16_params_key = 'fp32_from_fp16'
+        if fp32_from_float16_params_key not in state_dict:
+            state_dict[fp32_from_float16_params_key] = []
         for current_group, saved_group in zip(self.fp32_from_float16_groups, state_dict[fp32_from_float16_params_key]):
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)
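
The two guards added here are what make resume work for the case in the commit message: with pipeline parallelism, a stage may own no fp16 parameters, so its saved optimizer state can lack the 'state' entry and the fp32-from-fp16 groups. A hedged sketch with a made-up checkpoint dict (assumed for illustration, not output from a real run):

# Made-up optimizer checkpoint for a pipeline-parallel stage that owns no fp16 params.
state_dict = {
    'optimizer': {'param_groups': []},  # no 'state' key was saved for this stage
    # 'fp32_from_fp16_params' is missing entirely
}

optimizer_key = 'optimizer'
if 'state' not in state_dict[optimizer_key]:
    # torch.optim.Optimizer.load_state_dict() expects a 'state' entry, even an empty one.
    state_dict[optimizer_key]['state'] = {}

fp32_from_float16_params_key = 'fp32_from_fp16_params'
if fp32_from_float16_params_key not in state_dict:
    # With no saved groups, the copy loop over zipped groups simply has nothing to iterate.
    state_dict[fp32_from_float16_params_key] = []

print(state_dict)  # {'optimizer': {'param_groups': [], 'state': {}}, 'fp32_from_fp16_params': []}
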
@@ -489,7 +496,7 @@ def allreduce_main_grads(self):

     @contextmanager
     def no_sync(self):
-        """ A context manager to disable gradient synchronizations across
+        """A context manager to disable gradient synchronizations across
         data-parallel ranks."""
         old_require_backward_grad_sync = self._require_backward_grad_sync
         self._require_backward_grad_sync = False
