Fix typo: LocalNonpersitentObject -> LocalNonpersistentObject
Signed-off-by: Ananth Subramaniam <[email protected]>
ananthsub committed Dec 11, 2024
1 parent 0500d6b commit 328f0db
Showing 4 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/source/checkpoints/dist_ckpt.rst
@@ -213,7 +213,7 @@ A sharded state dict is a (possibly nested) Python dictionary or list with the f
    a. ShardedTensor
    b. ShardedObject
    c. ShardedTensorFactory
-2. LocalNonpersitentObject
+2. LocalNonpersistentObject
 3. Arbitrary object


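For readers who don't know the renamed class, here is a minimal sketch, not taken from the changed files, of a sharded state dict that mixes the entry types listed above; the key names, shapes, and values are illustrative assumptions.

# Illustrative sketch only -- keys, shapes, and values are assumptions,
# not content from this commit.
import torch

from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import LocalNonpersistentObject

sharded_state_dict = {
    # 1a. ShardedTensor: saved to / loaded from the distributed checkpoint
    'decoder.weight': ShardedTensor.from_rank_offsets('decoder.weight', torch.zeros(16, 16)),
    # 2. LocalNonpersistentObject: kept with its local value, not read back from the checkpoint
    'rng_tracker_state': LocalNonpersistentObject({'seed': 1234}),
    # 3. Arbitrary object: stored as-is in the common, non-sharded part of the checkpoint
    'iteration': 0,
}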
@@ -91,7 +91,7 @@
 try:
     from megatron.core import InferenceParams, dist_checkpointing, parallel_state, tensor_parallel
     from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
-    from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject, ShardedObject
+    from megatron.core.dist_checkpointing.mapping import LocalNonpersistentObject, ShardedObject
     from megatron.core.models.gpt import GPTModel as MCoreGPTModel
     from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
     from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -112,7 +112,7 @@

 def skip_fp8_load(x):
     if isinstance(x, ShardedObject) and 'fused_attention' in x.key and '_extra_state' in x.key:
-        x = LocalNonpersitentObject(x.data)  # use the FP8 state from initialization, not from ckpt
+        x = LocalNonpersistentObject(x.data)  # use the FP8 state from initialization, not from ckpt
     return x


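As a rough sketch of how a hook like skip_fp8_load is typically applied (assumed usage; cfg and sharded_state_dict are placeholder names, and the config key mirrors the one visible in the next file's hunk):

# Assumed usage sketch: rewrite matching leaves of a sharded state dict in place
# so that fused-attention _extra_state entries become LocalNonpersistentObject
# wrappers and the FP8 state is taken from initialization, not the checkpoint.
# 'cfg' and 'sharded_state_dict' are placeholder names for this sketch.
from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace

if cfg.get('skip_fp8_attention_checkpoint_load', True):
    dict_list_map_inplace(skip_fp8_load, sharded_state_dict)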
@@ -84,7 +84,7 @@
     from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
     from megatron.core.datasets.utils import get_blend_from_list
     from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
-    from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject, ShardedObject
+    from megatron.core.dist_checkpointing.mapping import LocalNonpersistentObject, ShardedObject
     from megatron.core.distributed import DistributedDataParallel as McoreDDP
     from megatron.core.distributed import DistributedDataParallelConfig, finalize_model_grads

@@ -2026,7 +2026,7 @@ def sharded_state_dict(self, prefix: str = '') -> Dict[str, Any]:
         # WAR: This is a temporary fix to skip loading FP8 parameters for Dot Product Attention
         def skip_fp8_load(x):
             if isinstance(x, ShardedObject) and 'fused_attention' in x.key and '_extra_state' in x.key:
-                x = LocalNonpersitentObject(x.data)  # use the FP8 state from initialization, not from ckpt
+                x = LocalNonpersistentObject(x.data)  # use the FP8 state from initialization, not from ckpt
             return x

         if self.cfg.get('skip_fp8_attention_checkpoint_load', True):
4 changes: 2 additions & 2 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -97,7 +97,7 @@
 try:
     from megatron.core import dist_checkpointing, parallel_state
     from megatron.core.dist_checkpointing.dict_utils import dict_list_map_outplace
-    from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject
+    from megatron.core.dist_checkpointing.mapping import LocalNonpersistentObject
     from megatron.core.dist_checkpointing.optimizer import (
         get_param_id_to_sharded_param_map,
         make_sharded_optimizer_tensor,
@@ -515,7 +515,7 @@ def _fix_param_groups(
     )
     if expert_index:
         # Temporary empty params so that loading doesn't fail
-        model_param_groups.insert(expert_index, {'params': LocalNonpersitentObject([]), 'is_expert': True})
+        model_param_groups.insert(expert_index, {'params': LocalNonpersistentObject([]), 'is_expert': True})
     if 'optimizer' in sharded_state_dict['optimizer_states'][0]:
         sharded_state_dict['optimizer_states'][0]['optimizer']['param_groups'] = model_param_groups
     else:
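To make the "loading doesn't fail" comment concrete, here is a hedged load-time sketch (assumed behavior of megatron.core distributed checkpointing; the checkpoint path and variable names are placeholders): LocalNonpersistentObject leaves are not looked up in the checkpoint, so the inserted expert group comes back with its local empty list.

# Assumed load-time sketch; 'sharded_state_dict' and the checkpoint path are
# placeholders. LocalNonpersistentObject leaves keep their local values, so the
# placeholder {'params': [], 'is_expert': True} group reappears as an empty
# list instead of causing the load to fail on a group the checkpoint never stored.
from megatron.core import dist_checkpointing

loaded_state_dict = dist_checkpointing.load(sharded_state_dict, '/path/to/checkpoint')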
