
Commit

improved readability + typos (#895)
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
3 people authored Apr 7, 2021
1 parent a128f34 commit 5ca86ae
Showing 4 changed files with 15 additions and 29 deletions.
32 changes: 12 additions & 20 deletions deepspeed/runtime/engine.py
@@ -1662,27 +1662,19 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):
         # then instead just returns None.
         self._curr_ckpt_path = os.path.join(save_dir, tag)
 
-        state = {
-            'module':
-            self.module_state_dict(),
-            'optimizer':
-            self.optimizer.state_dict()
+        state = dict(
+            module=self.module_state_dict(),
+            optimizer=self.optimizer.state_dict()
             if self.optimizer and not self.zero_optimization() else None,
-            'lr_scheduler':
-            self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
-            'csr_tensor_module_names':
-            self.csr_tensor_module_names,
-            'skipped_steps':
-            self.skipped_steps,
-            'global_steps':
-            self.global_steps,
-            'global_samples':
-            self.global_samples,
-            'dp_world_size':
-            self.dp_world_size,
-            'mp_world_size':
-            self.mp_world_size
-        }
+            lr_scheduler=self.lr_scheduler.state_dict()
+            if self.lr_scheduler is not None else None,
+            csr_tensor_module_names=self.csr_tensor_module_names,
+            skipped_steps=self.skipped_steps,
+            global_steps=self.global_steps,
+            global_samples=self.global_samples,
+            dp_world_size=self.dp_world_size,
+            mp_world_size=self.mp_world_size,
+        )
         state.update(client_state)
 
         log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0])
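Note: the change here swaps a line-wrapped dict literal for the dict() keyword constructor, which lets each checkpoint entry sit on a single line. A minimal sketch of the resulting pattern, with simplified names that are illustrative rather than the real engine attributes:

    # Illustrative sketch of the checkpoint-state pattern above, not the DeepSpeed source.
    def build_checkpoint_state(module_sd, optimizer_sd=None, client_state=None):
        state = dict(
            module=module_sd,
            # None when the optimizer state is saved elsewhere (e.g. by ZeRO)
            optimizer=optimizer_sd,
            skipped_steps=0,
            global_steps=0,
        )
        # caller-supplied entries are merged in and may override the defaults
        state.update(client_state or {})
        return state

    state = build_checkpoint_state({"weight": [1.0]}, client_state={"epoch": 3})
    assert state["epoch"] == 3 and state["optimizer"] is None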
6 changes: 0 additions & 6 deletions deepspeed/runtime/zero/config.py
@@ -36,12 +36,6 @@ def __init__(self, param_dict):
         self.max_reuse_distance = None
         self.gather_fp16_weights_on_model_save = None
 
-        #Stage3 Specific Parameters
-        self.prefetch_bucket_size = None
-        self.param_persistence_threshold = None
-        self.max_live_parameters = None
-        self.max_reuse_distance = None
-
         if ZERO_OPTIMIZATION in param_dict.keys():
             zero_config_dict = param_dict[ZERO_OPTIMIZATION]
             if type(zero_config_dict) is bool:
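Note: the deleted block re-initialized stage 3 attributes that are already set to None a few lines earlier in the same constructor, so removing it changes no behavior. A hedged sketch of the default-then-override pattern this constructor follows (the key and attribute names below are assumptions for illustration):

    ZERO_OPTIMIZATION = "zero_optimization"  # assumed key name for this sketch

    class ZeroConfigSketch:
        def __init__(self, param_dict):
            # defaults, each assigned exactly once
            self.prefetch_bucket_size = None
            self.max_live_parameters = None
            self.max_reuse_distance = None
            if ZERO_OPTIMIZATION in param_dict:
                # user-provided values override the None defaults
                for name, value in param_dict[ZERO_OPTIMIZATION].items():
                    setattr(self, name, value)

    cfg = ZeroConfigSketch({"zero_optimization": {"max_live_parameters": 10**9}})
    assert cfg.max_live_parameters == 10**9 and cfg.max_reuse_distance is None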
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -399,10 +399,10 @@ def _convert_to_deepspeed_param(self, param):
         # Stores the shape of the original tensor
         param.ds_shape = param.shape
 
-        # Stores the number of elements in the original parmaeter without padding
+        # Stores the number of elements in the original parameter without padding
         param.ds_numel = param.numel()
 
-        # Stores the paritioned copy of the tensor
+        # Stores the partitioned copy of the tensor
         param.ds_tensor = None
 
         # Keeps track of how many active sub-modules need this param at any given point in time
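Note: beyond the two typo fixes, these comments document the ds_* bookkeeping that ZeRO-3 attaches to each parameter. A small illustration of how the fields relate; the tensor size and world size here are made up:

    import torch

    param = torch.nn.Parameter(torch.empty(4, 8))
    param.ds_shape = param.shape    # original shape, needed to re-assemble the full tensor
    param.ds_numel = param.numel()  # element count of the unpartitioned parameter: 32
    param.ds_tensor = None          # later holds this rank's partition of the data

    # with e.g. 4 data-parallel ranks, each partition covers ds_numel // 4 elements
    world_size = 4
    assert param.ds_numel // world_size == 8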
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -580,7 +580,7 @@ def __init__(self,
                  gradient_accumulation_steps=1,
                  elastic_checkpoint=False):
 
-        see_memory_usage("Stage 3 intialize beginning", force=True)
+        see_memory_usage("Stage 3 initialize beginning", force=True)
 
         if dist.get_rank() == 0:
             logger.info(f"Reduce bucket size {reduce_bucket_size}")
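Note: the corrected string is a log tag passed to see_memory_usage. A rough stand-in for what such a helper reports, assuming only standard torch.cuda counters; this is not DeepSpeed's actual implementation:

    import torch

    def see_memory_usage_sketch(message, force=False):
        # prints CUDA memory at a named point, gated by `force` like the call above
        if not force or not torch.cuda.is_available():
            return
        allocated = torch.cuda.memory_allocated() / 2**30
        reserved = torch.cuda.memory_reserved() / 2**30
        print(f"{message} | allocated: {allocated:.2f} GB | reserved: {reserved:.2f} GB")

    see_memory_usage_sketch("Stage 3 initialize beginning", force=True)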
