From 856ab3171dc32d72bfdbcd3a3607fe255f3ae6b5 Mon Sep 17 00:00:00 2001
From: Tunji Ruwase
Date: Fri, 19 Mar 2021 22:02:28 +0000
Subject: [PATCH 1/2] Save ZeRO3 (partitioned) fp16 weights

---
 deepspeed/runtime/engine.py        |  8 +++++++-
 deepspeed/runtime/zero/stage3.py   | 11 +++++++++++
 docs/_tutorials/getting-started.md |  4 ++--
 tests/unit/test_checkpointing.py   |  8 ++++++++
 4 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index f965eb688d16..d8ad4a59f00b 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1342,10 +1342,15 @@ def all_gather_scalar(self, value):
 
     def module_state_dict(self, destination=None, prefix='', keep_vars=False):
         sd = self.module.state_dict(destination, prefix, keep_vars)
+        if self.zero_optimization_partition_weights():
+            sd = self.optimizer.save_partitioned_weights(sd)
         return sd
 
     def load_module_state_dict(self, state_dict, strict=True):
-        self.module.load_state_dict(state_dict, strict=strict)
+        if self.zero_optimization_partition_weights():
+            self.optimizer.load_partitioned_weights(state_dict)
+        else:
+            self.module.load_state_dict(state_dict, strict=strict)
 
     def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank):
         filename = 'zero_pp_rank_{}'.format(dp_rank)
@@ -1445,6 +1450,7 @@ def _load_checkpoint(self,
 
         self.load_module_state_dict(state_dict=checkpoint['module'],
                                     strict=load_module_strict)
+
         if self.optimizer is not None and not self.zero_optimization():
             if self.fp16_enabled():
                 self.optimizer.load_state_dict(
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index ea4653578616..8506d8b6b8d0 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -2832,6 +2832,17 @@ def save_checkpoint_prologue(self):
     def save_checkpoint_epilogue(self):
         self.persistent_parameters[0].all_gather(self.persistent_parameters)
 
+    def save_partitioned_weights(self, state_dict):
+        for name, param in self.module.named_parameters():
+            if name in state_dict.keys():
+                state_dict[name] = param.ds_tensor
+        return state_dict
+
+    def load_partitioned_weights(self, state_dict):
+        for name, param in self.module.named_parameters():
+            if name in state_dict.keys():
+                param.ds_tensor.copy_(state_dict[name])
+
 
 def _handle_overflow(cpu_sum, x, i):
     import math
diff --git a/docs/_tutorials/getting-started.md b/docs/_tutorials/getting-started.md
index e12388aaf973..e9b9aa0e627e 100644
--- a/docs/_tutorials/getting-started.md
+++ b/docs/_tutorials/getting-started.md
@@ -265,8 +265,8 @@ local machine to discover the number of slots available.
 The `--include` and `--exclude` arguments work as normal, but the user should
 specify 'localhost' as the hostname.
 
-Also note that `CUDA_VISIBLE_DEVICES` can't be used with DeepSpeed to control
-which devices should be used. For example, to use only gpu1 of the current
+Also note that `CUDA_VISIBLE_DEVICES` can't be used with DeepSpeed to control
+which devices should be used. For example, to use only gpu1 of the current
 node, do:
 ```bash
 deepspeed --include localhost:1 ...
diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py
index 765c44c8e551..bf72a0875287 100755
--- a/tests/unit/test_checkpointing.py
+++ b/tests/unit/test_checkpointing.py
@@ -41,6 +41,14 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
         assert id(p0) != id(p1), f'Comparing fp16 model state tensor against itself : {id(p0)} <====> {id(p1)}'
         assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"
 
+    # Compare ds_tensor values for ZeRO stage3
+    for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()):
+        p0_has_ds_tensor = hasattr(p0, 'ds_tensor')
+        p1_has_ds_tensor = hasattr(p1, 'ds_tensor')
+        assert p0_has_ds_tensor == p1_has_ds_tensor, f'Mismatch has ds_tensor attribute p0:{p0_has_ds_tensor}, p1:{p1_has_ds_tensor}'
+        if p0_has_ds_tensor:
+            assert torch.allclose(p0, p1, atol=1e-07), f'FP16 model state {p0} is not equal to {p1}'
+
     if not compare_optimizer:
         return
 

From 8cd046faa3c8c3cd8f2601d52078029ad7cddbd1 Mon Sep 17 00:00:00 2001
From: Tunji Ruwase
Date: Fri, 19 Mar 2021 22:35:33 +0000
Subject: [PATCH 2/2] Compare ds_tensors

---
 tests/unit/test_checkpointing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py
index bf72a0875287..ddc8204fc0ac 100755
--- a/tests/unit/test_checkpointing.py
+++ b/tests/unit/test_checkpointing.py
@@ -47,7 +47,7 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
         p1_has_ds_tensor = hasattr(p1, 'ds_tensor')
         assert p0_has_ds_tensor == p1_has_ds_tensor, f'Mismatch has ds_tensor attribute p0:{p0_has_ds_tensor}, p1:{p1_has_ds_tensor}'
         if p0_has_ds_tensor:
-            assert torch.allclose(p0, p1, atol=1e-07), f'FP16 model state {p0} is not equal to {p1}'
+            assert torch.allclose(p0.ds_tensor, p1.ds_tensor, atol=1e-07), f'FP16 model state {p0} is not equal to {p1}'
 
     if not compare_optimizer:
         return
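
Note (appended for context, not part of either patch): a minimal usage sketch of the round trip these changes enable, assuming a DeepSpeed build whose `deepspeed.initialize()` accepts an in-memory config dict via `config=` and a job launched with the `deepspeed` launcher. The toy model, config values, checkpoint path, and tag below are placeholders.

```python
import torch
import deepspeed

# Placeholder model and config; "stage": 3 makes
# zero_optimization_partition_weights() true, so the partitioned-weight
# save/load paths added in this PR are exercised.
model = torch.nn.Linear(16, 16)
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 3},
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# save_checkpoint() reaches module_state_dict(), which now replaces each
# state_dict entry with the rank-local shard via save_partitioned_weights().
engine.save_checkpoint("/tmp/zero3_ckpt", tag="step0")

# load_checkpoint() routes through load_module_state_dict(), which copies the
# saved shards back into param.ds_tensor via load_partitioned_weights().
engine.load_checkpoint("/tmp/zero3_ckpt", tag="step0")
```

Because each rank stores only its own `ds_tensor` partition, every data-parallel rank must execute both calls, and the saved fp16 module state is sharded rather than a consolidated full state_dict.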