
Samyamr/inference hook fix #851

Merged: 13 commits, Mar 15, 2021
deepspeed/runtime/zero/partition_parameters.py (8 changes: 6 additions & 2 deletions)
@@ -807,8 +807,12 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
         if start < param.ds_numel:
             elements = min(param.ds_numel - start, partition_size)

-            dest_tensor = partition_buffer.view(-1).narrow(0, 0, elements)
+            dest_tensor_full_buffer = partition_buffer.view(-1).narrow(
+                0,
+                0,
+                partition_size)
+
+            dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements)
             src_tensor = param.grad.view(-1).narrow(0, start, elements)

             # just copy the grad partition to the buffer
@@ -841,7 +845,7 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
             # elements))

             #print("after partition gradients")
-            param.grad.data = dest_tensor.data
+            param.grad.data = dest_tensor_full_buffer.data
             see_memory_usage("After partitioning gradients", force=False)


deepspeed/runtime/zero/stage3.py (21 changes: 15 additions & 6 deletions)
@@ -961,10 +961,9 @@ def _create_fp16_partitions_with_defragmentation(self):

                 #create flat buffer in CPU and move to GPU
                 self.fp16_partitioned_groups_flat.append(
-                    flatten_dense_tensors_aligned(
-                        self.fp16_partitioned_groups[i],
-                        dist.get_world_size(group=self.dp_process_group)).cuda(
-                            torch.cuda.current_device()))
+                    flatten_dense_tensors_aligned(self.fp16_partitioned_groups[i],
+                                                  1).cuda(
+                                                      torch.cuda.current_device()))
                 see_memory_usage(
                     f"After flattening and moving param group {i} to GPU",
                     force=False)
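Note on the alignment change above: flatten_dense_tensors_aligned is a DeepSpeed utility, and the helper below is only a hypothetical stand-in showing what an alignment argument means for a flattened buffer (flatten the list, then pad the result to a multiple of alignment). Under that assumption, passing 1, as the new code does, flattens without adding any padding.

import torch
from torch._utils import _flatten_dense_tensors

def flatten_aligned_sketch(tensor_list, alignment):
    # flatten the tensors into one 1-D buffer
    flat = _flatten_dense_tensors(tensor_list)
    remainder = flat.numel() % alignment
    if remainder == 0:
        return flat
    # pad the buffer so its length is a multiple of `alignment`
    pad = torch.zeros(alignment - remainder, dtype=flat.dtype, device=flat.device)
    return torch.cat([flat, pad])

tensors = [torch.randn(3), torch.randn(4)]            # 7 elements total
print(flatten_aligned_sketch(tensors, 1).numel())     # 7  -> no padding
print(flatten_aligned_sketch(tensors, 4).numel())     # 8  -> padded to a multiple of 4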
@@ -976,10 +975,12 @@ def _create_fp16_partitions_with_defragmentation(self):
                         flat_offset,
                         total_elements)
                     self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat)
-                    self._move_to_flat_buffer(self.fp16_partitioned_groups[i],
-                                              self.fp16_partitioned_groups_flat[i])
                     flat_offset += total_elements

+                # move param to flat buffer for both param offload on/off
+                self._move_to_flat_buffer(self.fp16_partitioned_groups[i],
+                                          self.fp16_partitioned_groups_flat[i])
+
                 see_memory_usage(f"After Flattening param group {i}", force=False)

     def _create_fp32_partitions(self):
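Note on relocating _move_to_flat_buffer above: the sketch below is a conceptual stand-in (not DeepSpeed's actual implementation) for what moving a param group into a flat buffer does: copy each tensor into a slice of one contiguous allocation and re-point its .data at that slice. The PR moves the call so this happens whether param offload is on or off.

import torch

def move_to_flat_buffer_sketch(tensor_list, flat_buffer):
    offset = 0
    for tensor in tensor_list:
        numel = tensor.numel()
        view = flat_buffer.narrow(0, offset, numel)
        view.copy_(tensor.view(-1))            # copy data into the flat slice
        tensor.data = view.view_as(tensor)     # re-point the tensor at the slice
        offset += numel

params = [torch.randn(3, 2), torch.randn(4)]
flat = torch.zeros(sum(p.numel() for p in params))
move_to_flat_buffer_sketch(params, flat)
assert params[0].data_ptr() == flat.data_ptr()  # first param now lives in the flat buffer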
@@ -1036,6 +1037,14 @@ def setup_zero_stage3_hooks(self):
         self.hierarchy = 0
         self._register_hooks_recursively(self.module)

+        #reset step if in inference mode
+        def _end_of_forward_hook(module, *args):
+
+            if not torch._C.is_grad_enabled():
+                self.param_coordinator.reset_step()
+
+        self.module.register_forward_hook(_end_of_forward_hook)
+
     def persistent_parameters(self):
         persistent_params = []
         total_persistent_parameters = 0
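Note on the new end-of-forward hook above: the snippet below is a small standalone check (toy nn.Linear, not DeepSpeed) of the mechanism it relies on: a forward hook on the root module fires after every forward pass, and grad mode is disabled inside torch.no_grad(), so the hook can tell inference-only forwards (where the parameter coordinator's step is reset) from training forwards. torch.is_grad_enabled() is the public equivalent of the torch._C.is_grad_enabled() call used in the diff.

import torch

model = torch.nn.Linear(4, 2)

def _end_of_forward_hook(module, *args):
    if not torch.is_grad_enabled():
        print("inference forward -> would call param_coordinator.reset_step()")
    else:
        print("training forward -> step bookkeeping continues")

model.register_forward_hook(_end_of_forward_hook)

model(torch.randn(1, 4))          # training forward: grad mode enabled
with torch.no_grad():
    model(torch.randn(1, 4))      # inference forward: grad mode disabled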
tests/unit/test_fp16.py (13 changes: 7 additions & 6 deletions)
Expand Up @@ -347,9 +347,6 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
pytest.skip("cpu-adam is not compatible")

if zero_stage == 3:
pytest.skip("skip for now")

config_dict = {
"train_batch_size": 4,
"steps_per_print": 1,
@@ -371,8 +368,9 @@ def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)

     @distributed_test(world_size=2)
-    def _test_zero_static_scale(args, zero_stage):
-        hidden_dim = 10
+    def _test_zero_static_scale(args, zero_stage, hidden_dim):
+        #making hidden size not divisible by DP for covering this scenario
+        hidden_dim = hidden_dim
         model = SimpleModel(hidden_dim)

         model, optim, _, _ = deepspeed.initialize(args=args,
@@ -393,7 +391,10 @@ def _test_zero_static_scale(args, zero_stage):
         model.backward(loss)
         model.step()

-    _test_zero_static_scale(args=args, zero_stage=zero_stage)
+    #test when hidden_dim is not aligned with world size
+    _test_zero_static_scale(args=args, zero_stage=zero_stage, hidden_dim=9)
+    #test when hidden_dim is aligned with world size
+    _test_zero_static_scale(args=args, zero_stage=zero_stage, hidden_dim=10)


 def test_zero_static_scale_deprecated_format(tmpdir):
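Note on the test change above: a back-of-the-envelope check (assumed layer sizes, not taken from SimpleModel itself) of why hidden_dim=9 exercises the unaligned path on world_size=2 while hidden_dim=10 stays aligned: a 9x9 weight has 81 elements, which does not split evenly across 2 ranks, so the last partition needs padding, whereas a 10x10 weight (100 elements) splits cleanly.

import math

def partition_sizes(numel, world_size):
    # uniform partition size per rank, plus how many real (unpadded) elements each rank owns
    partition = math.ceil(numel / world_size)
    real = [min(partition, max(0, numel - rank * partition)) for rank in range(world_size)]
    return partition, real

for hidden_dim in (9, 10):
    numel = hidden_dim * hidden_dim
    partition, real = partition_sizes(numel, world_size=2)
    print(f"hidden_dim={hidden_dim}: numel={numel}, partition={partition}, real per rank={real}")
# hidden_dim=9:  numel=81,  partition=41, real per rank=[41, 40]  -> last partition padded
# hidden_dim=10: numel=100, partition=50, real per rank=[50, 50]  -> no padding needed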