diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 05825fc90688..e6cb9199899a 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -807,8 +807,12 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False): if start < param.ds_numel: elements = min(param.ds_numel - start, partition_size) - dest_tensor = partition_buffer.view(-1).narrow(0, 0, elements) + dest_tensor_full_buffer = partition_buffer.view(-1).narrow( + 0, + 0, + partition_size) + dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements) src_tensor = param.grad.view(-1).narrow(0, start, elements) # just copy the grad partition to the buffer @@ -841,7 +845,7 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False): # elements)) #print("after partition gradients") - param.grad.data = dest_tensor.data + param.grad.data = dest_tensor_full_buffer.data see_memory_usage("After partitioning gradients", force=False) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index f840de15c57d..99b4916aef3c 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -961,10 +961,9 @@ def _create_fp16_partitions_with_defragmentation(self): #create flat buffer in CPU and move to GPU self.fp16_partitioned_groups_flat.append( - flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - dist.get_world_size(group=self.dp_process_group)).cuda( - torch.cuda.current_device())) + flatten_dense_tensors_aligned(self.fp16_partitioned_groups[i], + 1).cuda( + torch.cuda.current_device())) see_memory_usage( f"After flattening and moving param group {i} to GPU", force=False) @@ -976,10 +975,12 @@ def _create_fp16_partitions_with_defragmentation(self): flat_offset, total_elements) self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) - self._move_to_flat_buffer(self.fp16_partitioned_groups[i], - self.fp16_partitioned_groups_flat[i]) flat_offset += total_elements + # move param to flat buffer for both param offload on/off + self._move_to_flat_buffer(self.fp16_partitioned_groups[i], + self.fp16_partitioned_groups_flat[i]) + see_memory_usage(f"After Flattening param group {i}", force=False) def _create_fp32_partitions(self): @@ -1036,6 +1037,14 @@ def setup_zero_stage3_hooks(self): self.hierarchy = 0 self._register_hooks_recursively(self.module) + #reset step if in inference mode + def _end_of_forward_hook(module, *args): + + if not torch._C.is_grad_enabled(): + self.param_coordinator.reset_step() + + self.module.register_forward_hook(_end_of_forward_hook) + def persistent_parameters(self): persistent_params = [] total_persistent_parameters = 0 diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py index 5012614f97b0..dbd40c322be9 100755 --- a/tests/unit/test_fp16.py +++ b/tests/unit/test_fp16.py @@ -347,9 +347,6 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload): if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: pytest.skip("cpu-adam is not compatible") - if zero_stage == 3: - pytest.skip("skip for now") - config_dict = { "train_batch_size": 4, "steps_per_print": 1, @@ -371,8 +368,9 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload): args = args_from_dict(tmpdir, config_dict) @distributed_test(world_size=2) - def _test_zero_static_scale(args, zero_stage): - hidden_dim = 10 + def _test_zero_static_scale(args, zero_stage, hidden_dim): + #making hidden size not divisible by DP for covering this scenario + hidden_dim = hidden_dim model = SimpleModel(hidden_dim) model, optim, _, _ = deepspeed.initialize(args=args, @@ -393,7 +391,10 @@ def _test_zero_static_scale(args, zero_stage): model.backward(loss) model.step() - _test_zero_static_scale(args=args, zero_stage=zero_stage) + #test when hidden_dim is not aligned with world size + _test_zero_static_scale(args=args, zero_stage=zero_stage, hidden_dim=9) + #test when hidden_dim is aligned with world size + _test_zero_static_scale(args=args, zero_stage=zero_stage, hidden_dim=10) def test_zero_static_scale_deprecated_format(tmpdir):