[docs] new zero functionality and install #909

Merged 3 commits on Apr 7, 2021 (diff shows changes from 1 commit)
2 changes: 2 additions & 0 deletions deepspeed/runtime/engine.py
@@ -1723,6 +1723,8 @@ def _zero3_consolidated_fp16_state_dict(self):

Get a full non-partitioned state_dict with fp16 weights on cpu.

Important: this function must be called on all ranks and not just rank 0.

This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but:

1. consolidates the weights from different partitions on gpu0
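As a rough usage sketch (not part of this PR) of the calling pattern the docstring above describes: every rank must enter the collective call, while typically only rank 0 keeps and saves the returned state dict. The helper name and the output path below are illustrative:

```python
import torch
import torch.distributed as dist

def save_zero3_fp16_state_dict(engine, path="pytorch_model_fp16.bin"):
    # All ranks must call this; it is a collective that gathers the
    # partitioned ZeRO-3 parameters onto rank 0.
    state_dict = engine._zero3_consolidated_fp16_state_dict()
    if dist.get_rank() == 0:
        torch.save(state_dict, path)
```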
8 changes: 7 additions & 1 deletion docs/_pages/config-json.md
@@ -256,7 +256,8 @@ Enabling and configuring ZeRO memory optimizations
"stage3_prefetch_bucket_size" : 5e8,
"stage3_param_persistence_threshold" : 1e6,
"sub_group_size" : 1e12,
"elastic_checkpoint" : [true|false]
"elastic_checkpoint" : [true|false],
"stage3_gather_fp16_weights_on_model_save": [true|false]
}
```

@@ -351,6 +352,11 @@ Enabling and configuring ZeRO memory optimizations
| Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages). | `1e6` |


***stage3_gather_fp16_weights_on_model_save***: [boolean]
| Description | Default |
| -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
| Consolidate the weights before saving the model with `save_fp16_model()`. Since the weights are partitioned across GPUs, they aren't part of `state_dict`, so when this option is enabled this function automatically gathers the weights and then saves the fp16 model weights. | `False` |
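For illustration only (not part of this PR), a sketch of how this flag might appear in a ZeRO stage 3 config and be exercised through `save_fp16_model`; `model_engine` and the paths are placeholders:

```python
# Relevant fragment of a DeepSpeed config, written as a Python dict mirroring the JSON above.
ds_config_fragment = {
    "zero_optimization": {
        "stage": 3,
        "stage3_gather_fp16_weights_on_model_save": True,
    },
}

# With such a config, a DeepSpeedEngine (here called `model_engine`) can gather
# the partitioned fp16 weights and save them to a single file:
model_engine.save_fp16_model("checkpoints", "pytorch_model_fp16.bin")
```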

### Logging

***steps\_per\_print***: [integer]
12 changes: 12 additions & 0 deletions docs/_tutorials/advanced-install.md
@@ -73,6 +73,18 @@ DS_BUILD_OPS=1 pip install deepspeed --global-option="build_ext" --global-option

This should complete the full build 2-3 times faster. You can adjust `-j` to specify how many CPU cores are to be used during the build. In the example it is set to 8 cores.

You can also build a binary wheel and install it on multiple machines that have the same type of GPUs and the same software environment (CUDA toolkit, PyTorch, Python, etc.):

```bash
DS_BUILD_OPS=1 python setup.py build_ext -j8 bdist_wheel
```

This will create a PyPI-style binary wheel under `dist/`, e.g., `dist/deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl`, and then you can install it directly on multiple machines; in our example:

```bash
pip install dist/deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl
```


## Install DeepSpeed from source

38 changes: 38 additions & 0 deletions docs/_tutorials/zero.md
Expand Up @@ -260,5 +260,43 @@ for more details.
self.init_method(self.position_embeddings.weight)
```

## Extracting weights

If you need to take the pretrained fp16 weights out of DeepSpeed, here is what you can do:

- Under ZeRO-2, `state_dict` contains the fp16 model weights, and these can be saved normally with `torch.save` (see the sketch further below).
- Under ZeRO-3, `state_dict` contains just placeholders, since the model weights are partitioned across multiple GPUs. If you want to access these weights, enable:

```json
"zero_optimization": {
"stage3_gather_fp16_weights_on_model_save": true
},
```
And then save the model using:

```python
if self.deepspeed:
self.deepspeed.save_fp16_model(output_dir, output_file)
```

Because it requires consolidating the weights on one GPU, it can be slow and memory-demanding, so use this feature only when needed.

Note that if `stage3_gather_fp16_weights_on_model_save` is `False`, no weights will be saved (again, because `state_dict` doesn't have them).
You can use this method to save ZeRO-2 weights as well.
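For the plain ZeRO-2 path mentioned in the first bullet above, a minimal sketch of saving the fp16 weights directly; `model_engine` (the engine returned by `deepspeed.initialize`) and the output filename are illustrative:

```python
import torch

# Under ZeRO-2 the fp16 weights are fully present in the module's state_dict,
# so a plain torch.save of the wrapped module works.
torch.save(model_engine.module.state_dict(), "pytorch_model_fp16.bin")
```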

If you'd like to get the fp32 weights, we supply a special script that can do offline consolidation. It requires no configuration files or GPUs. Here is an example of its usage:

```bash
$ cd /path/to/checkpoints_dir
$ ./zero_to_fp32.py global_step1 pytorch_model.bin
Processing zero checkpoint at global_step1
Detected checkpoint of type zero stage 3, world_size: 2
Saving fp32 state dict to pytorch_model.bin (total_numel=60506624)
```

The `zero_to_fp32.py` script gets created automatically when you save a checkpoint.

Note: currently this script requires roughly 2x the size of the final checkpoint in general RAM.
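Once the consolidated fp32 file exists, it can be loaded like any regular PyTorch checkpoint; a minimal sketch, with `MyModel` standing in for your model class:

```python
import torch

# `MyModel` is a placeholder for your model class; build it on CPU first,
# then load the consolidated fp32 weights produced by zero_to_fp32.py.
model = MyModel()
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
```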


Congratulations! You have completed the ZeRO tutorial.
8 changes: 8 additions & 0 deletions docs/code-docs/source/training.rst
@@ -31,3 +31,11 @@ Optimizer Step
Gradient Accumulation
---------------------
.. autofunction:: deepspeed.DeepSpeedEngine.is_gradient_accumulation_boundary


Model Saving
------------
.. autofunction:: deepspeed.DeepSpeedEngine.save_fp16_model


Additionally, when a DeepSpeed checkpoint is created, a script ``zero_to_fp32.py`` is added to it, which can be used to reconstruct the fp32 master weights into a single PyTorch ``state_dict`` file.