diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index 5417188466..e3b091c071 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -32,7 +32,7 @@ def get_model_state_dict( Args: model: The model to get the state dict from. sharded_state_dict: Whether the model state dict should be sharded or not. If True, every rank returns the state dict of its shards. - If False, then rank 0 returns the state dict of the entire model and the other ranks return a dict of their shards. Default is False. + If False, then rank 0 returns the state dict of the entire model and the other ranks return an empty dict. Default is False. precision: The precision of the model. Can be specified as a string ('fp32', 'fp16', 'bf16') or a torch.dtype. include_keys: The list of keys to exclusively include in the state dict. If None, all keys are included. Both include_keys and ignore_keys cannot be non-None. ignore_keys: The list of keys to ignore in the state dict. If None, no keys are ignored. Both include_keys and ignore_keys cannot be non-None.