Skip to content

Commit

Permalink
Add rank to checkpoint save/load log lines for the FSDP and Megatron strategies
Browse files Browse the repository at this point in the history
Signed-off-by: Ananth Subramaniam <[email protected]>
  • Loading branch information
ananthsub committed Jan 28, 2025
1 parent 2127dd3 commit c2f960f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
4 changes: 2 additions & 2 deletions nemo/lightning/pytorch/strategies/fsdp_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def save_checkpoint(
self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
end_time = time.monotonic()
logging.info(
- f'Global Checkpoint Save: Start time : {start_time} s : Time spent in save_checkpoint: {end_time - start_time} s'
+ f'Global Checkpoint Save: Rank : {torch.distributed.get_rank()} : Start time : {start_time} s : Time spent in save_checkpoint: {end_time - start_time} s'
)

@override
Expand Down Expand Up @@ -268,7 +268,7 @@ def load_checkpoint(self, checkpoint_path: str | Path) -> Dict[str, Any]:
checkpoint = self.checkpoint_io.load_checkpoint(path, sharded_state_dict=sharded_state_dict)
end_time = time.monotonic()
logging.info(
- f'Global Checkpoint Load: Start time : {start_time} s : Time spent in load_checkpoint: {end_time - start_time} s'
+ f'Global Checkpoint Load: Rank : {torch.distributed.get_rank()} : Start time : {start_time} s : Time spent in load_checkpoint: {end_time - start_time} s'
)
mcore_to_pyt_sharded_state_dict(checkpoint['sharded_state_dict'], msd)

Expand Down
11 changes: 9 additions & 2 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,11 +750,14 @@ def save_checkpoint(
if self.ckpt_save_optimizer:
checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()]

+ from nemo.utils import AppState
+
+ app_state = AppState()
start_time = time.monotonic()
self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
end_time = time.monotonic()
logging.info(
- f'Global Checkpoint Save: Start time : {start_time} s : Time spent in save_checkpoint: {end_time - start_time} s'
+ f'Global Checkpoint Save: Rank : {app_state.global_rank} : Start time : {start_time} s : Time spent in save_checkpoint: {end_time - start_time} s'
)

def should_restore_optimizer_states(self, selective_restore: bool = False) -> bool:
Expand Down Expand Up @@ -788,13 +791,17 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], selective_restore:
self.lightning_module.strict_loading if self.ckpt_load_strictness is None else self.ckpt_load_strictness
)

+ from nemo.utils import AppState
+
+ app_state = AppState()
+
start_time = time.monotonic()
checkpoint = self.checkpoint_io.load_checkpoint(
checkpoint_path, sharded_state_dict=sharded_state_dict, strict=strict
)
end_time = time.monotonic()
logging.info(
- f'Global Checkpoint Load: Start time : {start_time} s : Time spent in load_checkpoint: {end_time - start_time} s'
+ f'Global Checkpoint Load: Rank : {app_state.global_rank} : Start time : {start_time} s : Time spent in load_checkpoint: {end_time - start_time} s'
)

if selective_restore:
Expand Down

0 comments on commit c2f960f

Please sign in to comment.