Update docstring for get_model_state_dict (mosaicml#3318)
For unsharded state dicts, nonzero ranks return an empty dict. With torch 2.1.2, this is because we set the `rank0_only` flag of `FullStateDictConfig` to `True`; with torch >2.1.2, the `dcp.get_model_state_dict` function always returns an empty dict on nonzero ranks for unsharded state dicts.
eracah authored May 24, 2024
1 parent 179d64f commit 47dbf9f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion composer/checkpoint/state_dict.py
@@ -32,7 +32,7 @@ def get_model_state_dict(
Args:
model: The model to get the state dict from.
sharded_state_dict: Whether the model state dict should be sharded or not. If True, every rank returns the state dict of its shards.
- If False, then rank 0 returns the state dict of the entire model and the other ranks return a dict of their shards. Default is False.
+ If False, then rank 0 returns the state dict of the entire model and the other ranks return an empty dict. Default is False.
precision: The precision of the model. Can be specified as a string ('fp32', 'fp16', 'bf16') or a torch.dtype.
include_keys: The list of keys to exclusively include in the state dict. If None, all keys are included. Both include_keys and ignore_keys cannot be non-None.
ignore_keys: The list of keys to ignore in the state dict. If None, no keys are ignored. Both include_keys and ignore_keys cannot be non-None.
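
The behavior described by the updated docstring can be illustrated with a small pure-Python sketch. It does not call Composer or torch: `unsharded_state_dict_for_rank` and `select_keys` are hypothetical helpers written only for illustration, plain dicts stand in for real state dicts, and exact key matching is an assumption (the actual implementation may match keys differently).

```python
from typing import Any, Dict, Optional, Sequence


def unsharded_state_dict_for_rank(full_state: Dict[str, Any], rank: int) -> Dict[str, Any]:
    """Mimic the documented unsharded case (sharded_state_dict=False).

    Rank 0 gets the state dict of the entire model; every other rank
    gets an empty dict. Hypothetical helper, not part of Composer.
    """
    return dict(full_state) if rank == 0 else {}


def select_keys(
    state_dict: Dict[str, Any],
    include_keys: Optional[Sequence[str]] = None,
    ignore_keys: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """Mimic the documented include_keys / ignore_keys contract.

    Assumes exact string matching on keys, which is a simplification.
    """
    # The docstring states that both arguments cannot be non-None.
    if include_keys is not None and ignore_keys is not None:
        raise ValueError('Both include_keys and ignore_keys cannot be non-None.')
    if include_keys is not None:
        return {k: v for k, v in state_dict.items() if k in include_keys}
    if ignore_keys is not None:
        return {k: v for k, v in state_dict.items() if k not in ignore_keys}
    # Neither filter given: all keys are included.
    return dict(state_dict)


# Stand-in for a model's full state dict.
full = {'layer.weight': [1.0, 2.0], 'layer.bias': [0.0]}

rank0_state = unsharded_state_dict_for_rank(full, rank=0)
rank1_state = unsharded_state_dict_for_rank(full, rank=1)
weights_only = select_keys(rank0_state, include_keys=['layer.weight'])
```

In this sketch, code that consumes an unsharded state dict on a nonzero rank should expect an empty dict rather than that rank's shards, which is the point the docstring fix makes.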