From 4fd79205c6b85b47e00810143c69d342ce874ede Mon Sep 17 00:00:00 2001 From: Nir Sonnenschein Date: Thu, 19 Dec 2024 19:26:50 +0200 Subject: [PATCH 01/13] Allow to compile collective for PT>2.3 (#6899) Allow to compile collective for PT>2.3 commit re-uploaded due to github CI issue originally uploaded by @nelyahu --- deepspeed/comm/torch.py | 50 +++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py index 988b74232bb9..5461ae18d1f0 100755 --- a/deepspeed/comm/torch.py +++ b/deepspeed/comm/torch.py @@ -20,6 +20,12 @@ DS_COMM_REDUCE_OFF = False +def disable_compiler_collective(func): + if required_torch_version(min_version=2.3): + return func + return compiler.disable(func) + + def build_shm_op(): builder = get_accelerator().create_op_builder("ShareMemCommBuilder") if builder is None or not deepspeed.ops.__compatible_ops__[builder.NAME]: @@ -114,7 +120,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name=' self.shm_comm_op.initialize(self.get_world_size(), self.get_rank()) @classmethod - @compiler.disable + @disable_compiler_collective def get_all_gather_function(self): if hasattr(torch.distributed, "all_gather_into_tensor"): return torch.distributed.all_gather_into_tensor @@ -123,7 +129,7 @@ def get_all_gather_function(self): return None @classmethod - @compiler.disable + @disable_compiler_collective def get_reduce_scatter_function(self): if hasattr(torch.distributed, "reduce_scatter_tensor"): return torch.distributed.reduce_scatter_tensor @@ -146,7 +152,7 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size): world_size=world_size) self.using_mpi = torch.distributed.get_backend() == 'mpi' - @compiler.disable + @disable_compiler_collective def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): op = self._reduce_op(op) return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op) @@ -158,7 +164,7 @@ def inference_all_reduce(self, tensor, op, group=None): else: return torch.ops.deepspeed.inference_all_reduce_(tensor) - @compiler.disable + @disable_compiler_collective def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): """ proxy func to torch.distributed.all_reduce_coalesced, which is included in PyTorch 1.13 and above @@ -169,7 +175,7 @@ def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group op = self._reduce_op(op) return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op) - @compiler.disable + @disable_compiler_collective def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): if DS_COMM_REDUCE_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -177,7 +183,7 @@ def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): return Noop() return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op) - @compiler.disable + @disable_compiler_collective def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False): if DS_COMM_REDUCE_SCATTER_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -190,7 +196,7 @@ def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_ group=group, async_op=async_op) - @compiler.disable + @disable_compiler_collective def broadcast(self, tensor, src, group=None, async_op=False): if DS_COMM_BROADCAST_OFF: if 
int(os.getenv('RANK', '0')) == 0: @@ -199,7 +205,7 @@ def broadcast(self, tensor, src, group=None, async_op=False): else: return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op) - @compiler.disable + @disable_compiler_collective def all_gather(self, tensor_list, tensor, group=None, async_op=False): if DS_COMM_ALL_GATHER_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -208,7 +214,7 @@ def all_gather(self, tensor_list, tensor, group=None, async_op=False): else: return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op) - @compiler.disable + @disable_compiler_collective def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False): if self.has_all_gather_into_tensor(): return self.all_gather_function(output_tensor=output_tensor, @@ -216,7 +222,7 @@ def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_ group=group, async_op=async_op) - @compiler.disable + @disable_compiler_collective def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False): if DS_COMM_ALL_GATHER_OFF: if int(os.getenv('RANK', '0')) == 0: @@ -234,7 +240,7 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals "please consider upgrading your pytorch installation.") pass - @compiler.disable + @disable_compiler_collective def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False): """""" assert len(output_tensors) == len(input_tensors), "" @@ -258,7 +264,7 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_ else: reqs[-1].wait() - @compiler.disable + @disable_compiler_collective def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False): if self.has_reduce_scatter_tensor(): return self.reduce_scatter_function(output_tensor, @@ -272,7 +278,7 @@ def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, gr "please consider upgrading your pytorch installation.") pass - @compiler.disable + @disable_compiler_collective def all_to_all_single(self, output, input, @@ -287,27 +293,27 @@ def all_to_all_single(self, group=group, async_op=async_op) - @compiler.disable + @disable_compiler_collective def all_to_all(self, output_tensor_list, input_tensor_list, group=None, async_op=False): return torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op) - @compiler.disable + @disable_compiler_collective def send(self, tensor, dst, group=None, tag=0): return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag) - @compiler.disable + @disable_compiler_collective def recv(self, tensor, src=None, group=None, tag=0): return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag) - @compiler.disable + @disable_compiler_collective def isend(self, tensor, dst, group=None, tag=0): return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag) - @compiler.disable + @disable_compiler_collective def irecv(self, tensor, src=None, group=None, tag=0): return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag) - @compiler.disable + @disable_compiler_collective def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False): return torch.distributed.gather(tensor=tensor, gather_list=gather_list, @@ -315,7 +321,7 @@ def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False): group=group, async_op=async_op) - 
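As an aside on PATCH 01 above: the decorator it introduces simply gates `compiler.disable` on the installed PyTorch version. A minimal, self-contained sketch of the same pattern, assuming PyTorch >= 2.1 so that `torch.compiler.disable` exists (DeepSpeed's real helpers are `required_torch_version` and its own `compiler.disable` wrapper):

```python
# Sketch only: leave collectives traceable by torch.compile on PyTorch >= 2.3,
# otherwise opt the wrapped function out of compilation.
from packaging import version
import torch

def disable_compiler_collective(func):
    if version.parse(torch.__version__.split("+")[0]) >= version.parse("2.3"):
        return func  # new enough: let torch.compile trace the collective
    return torch.compiler.disable(func)  # older: disable tracing for this function

@disable_compiler_collective
def all_reduce_wrapper(tensor, group=None, async_op=False):
    return torch.distributed.all_reduce(tensor, group=group, async_op=async_op)
```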
@compiler.disable + @disable_compiler_collective def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False): return torch.distributed.scatter(tensor=tensor, scatter_list=scatter_list, @@ -323,13 +329,13 @@ def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False): group=group, async_op=async_op) - @compiler.disable + @disable_compiler_collective def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None): if group is None: group = torch.distributed.GroupMember.WORLD return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids) - @compiler.disable + @disable_compiler_collective def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False): if group is None: group = torch.distributed.GroupMember.WORLD From 00ea0c46c2296db158d10497602f9832c4445d84 Mon Sep 17 00:00:00 2001 From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com> Date: Fri, 20 Dec 2024 02:54:45 +0200 Subject: [PATCH 02/13] Zero2: avoid graph breaks in torch.compile by using param_idx (#6803) inside reduce_independent_p_g_buckets_and_remove_grads and in reduce_ipg_grads which are being executed during the BWD hook in zero2, the model param is being stored inside params_in_ipg_bucket. torch.compile has hard time tracing parameters. By using the param's static index inside the group the same logic can be maintain with less complexity. --------- Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Logan Adams --- deepspeed/runtime/zero/stage_1_and_2.py | 9 ++++++--- tests/unit/moe/test_moe.py | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 7ac89a233808..ecb2a527f870 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -310,6 +310,7 @@ def __init__(self, for param in param_group['params']: if param.requires_grad: param.grad_accum = None + param.param_idx_in_group = len(trainable_parameters) trainable_parameters.append(param) self.bit16_groups.append(trainable_parameters) @@ -961,7 +962,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient" self.grads_in_ipg_bucket.append(grad_reduc) - self.params_in_ipg_bucket.append((i, param, param_id)) + self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id)) #make sure the average tensor function knows how to average the gradients if is_moe_param(param): @@ -1067,7 +1068,8 @@ def average_tensor(self, tensor): process_group = self.dp_process_group # count = 0 - for i, param, param_id in self.params_in_ipg_bucket: + for i, param_idx_in_group, param_id in self.params_in_ipg_bucket: + param = self.bit16_groups[i][param_idx_in_group] process_group = self.dp_process_group @@ -1383,7 +1385,8 @@ def reduce_ipg_grads(self): stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - for _, param, param_id in self.params_in_ipg_bucket: + for group_idx, param_idx_in_group, param_id in self.params_in_ipg_bucket: + param = self.bit16_groups[group_idx][param_idx_in_group] assert self.params_already_reduced[param_id] == False, \ f"The parameter {param_id} has already been reduced. 
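To make the `param_idx_in_group` idea from PATCH 02 concrete, here is a small illustrative sketch (the lists and ids below are made up; the real code stores torch Parameters in `bit16_groups`). Keeping only plain integers in the bucket avoids handing Parameter objects to torch.compile, and the parameter is resolved from its group only when the bucket is processed:

```python
# Illustrative sketch: store (group index, index in group, param id) in the
# bucket and look the parameter up lazily when reducing.
bit16_groups = [["param_a", "param_b", "param_c"]]   # stand-in for Parameter lists
params_in_ipg_bucket = []

def add_grad_to_bucket(group_idx, param_idx_in_group, param_id):
    params_in_ipg_bucket.append((group_idx, param_idx_in_group, param_id))

def reduce_ipg_bucket():
    for group_idx, param_idx_in_group, param_id in params_in_ipg_bucket:
        param = bit16_groups[group_idx][param_idx_in_group]
        print(f"reducing param {param_id}: {param}")
    params_in_ipg_bucket.clear()

add_grad_to_bucket(0, 1, 17)
reduce_ipg_bucket()   # reducing param 17: param_b
```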
\ diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py index 9ee546437f6c..c67a907c6785 100644 --- a/tests/unit/moe/test_moe.py +++ b/tests/unit/moe/test_moe.py @@ -93,7 +93,8 @@ def strict_average_tensor(tensor): process_group = optimizer.dp_process_group curr_size = 0 pg_offsets = [] - for i, param, param_id in optimizer.params_in_ipg_bucket: + for i, param_idx, param_id in optimizer.params_in_ipg_bucket: + param = optimizer.bit16_groups[i][param_idx] process_group = optimizer.dp_process_group if optimizer.ipg_bucket_has_moe_params: process_group = optimizer.expert_dp_process_group[param.group_name] if is_moe_param( From eea5304807c6a04d0f2c55cb935ec295235d9b54 Mon Sep 17 00:00:00 2001 From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com> Date: Fri, 20 Dec 2024 07:13:46 +0200 Subject: [PATCH 03/13] hpu_accelerator: use torch.use_deterministic_algorithms (#6897) formal API instead of hpu.setDeterministic --- accelerator/hpu_accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py index 723a66e4c6fb..b46351f8ca43 100644 --- a/accelerator/hpu_accelerator.py +++ b/accelerator/hpu_accelerator.py @@ -21,8 +21,8 @@ def __init__(self): self.apply_hpu_workarounds() try: import habana_frameworks.torch.hpu as hpu - hpu.setDeterministic(True) self.hpu = hpu + torch.use_deterministic_algorithms(True) except ImportError as e: raise ValueError( f"HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.") From 85cc5f9bb3f0175a2d13ea1ed65bf7d202b7f0d9 Mon Sep 17 00:00:00 2001 From: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com> Date: Thu, 26 Dec 2024 09:12:04 -0800 Subject: [PATCH 04/13] Fix error caused by all_reduce call in domino (#6880) Fix #6851 Initialize communication backend to fix error caused by all_reduce call in the Domino transformer layer. Verified correctness in local test. --------- Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/domino/transformer.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/deepspeed/runtime/domino/transformer.py b/deepspeed/runtime/domino/transformer.py index 8eb95e49c29d..88c5494c8147 100644 --- a/deepspeed/runtime/domino/transformer.py +++ b/deepspeed/runtime/domino/transformer.py @@ -6,8 +6,7 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter -import deepspeed -from deepspeed import comm as dist +import deepspeed.comm as dist from deepspeed.accelerator import get_accelerator @@ -97,7 +96,7 @@ def backward(ctx, grad_output): return grad_output # Async All-reduce. - handle = deepspeed.comm.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True) + handle = dist.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True) ctx.handle_dic[ctx.h_id] = handle return None, grad_output, None, None @@ -249,6 +248,10 @@ def __init__(self, output_bias=None): super(DominoTransformerLayer, self).__init__() + if not dist.is_initialized(): + dist.init_distributed() + assert dist.is_initialized(), "deepspeed.comm is not initialized!" 
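The guard added just above in PATCH 04 boils down to "initialize the communication backend once before issuing collectives". A hedged, single-process sketch written against plain `torch.distributed` (DeepSpeed's `dist.init_distributed` wraps this and picks the backend from the accelerator; the address/port defaults below are illustrative):

```python
# Sketch: make sure a process group exists before calling all_reduce and friends.
import os
import torch
import torch.distributed as dist

def ensure_comm_initialized(backend: str = "gloo") -> None:
    if not dist.is_initialized():
        # Single-process defaults, for illustration only.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend=backend, rank=0, world_size=1)
    assert dist.is_initialized(), "communication backend is not initialized!"

ensure_comm_initialized()
t = torch.ones(4)
dist.all_reduce(t)   # no-op reduction in a world of size 1
print(t)
```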
+ self.llama_model = config.llama_model self.layer_number = layer_number self.layer_type = layer_type @@ -358,18 +361,14 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): layernorm_output0, attention_mask, rotary_pos_emb=rotary_pos_emb) - handle0 = deepspeed.comm.all_reduce(attention_output0, - group=self.mpu.get_tensor_model_parallel_group(), - async_op=True) + handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) attention_output1, attention_bias1 = \ self.self_attention( layernorm_output1, attention_mask, rotary_pos_emb=rotary_pos_emb) - handle1 = deepspeed.comm.all_reduce(attention_output1, - group=self.mpu.get_tensor_model_parallel_group(), - async_op=True) + handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) handle0.wait() # Residual0 connection. @@ -413,7 +412,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): output0 = output0 + bias_c output0 = self.mlp_activation_func(output0) output0 = torch.matmul(output0, self.weight_r.t()) - handle2 = deepspeed.comm.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) + handle2 = dist.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) handle1.wait() @@ -425,7 +424,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): if bias_c is not None: output1 = output1 + bias_c output1 = torch.matmul(output1, self.weight_r.t()) - deepspeed.comm.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group()) + dist.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group()) handle2.wait() From cc03c76d57f41752d8cfb84c2e45b8e0da8083da Mon Sep 17 00:00:00 2001 From: Raza Sikander Date: Fri, 27 Dec 2024 01:37:28 +0530 Subject: [PATCH 05/13] Update Gaudi2 jobs to latest 1.19 build (#6905) Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- .github/workflows/hpu-gaudi2-nightly.yml | 2 +- .github/workflows/hpu-gaudi2.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/hpu-gaudi2-nightly.yml b/.github/workflows/hpu-gaudi2-nightly.yml index 5c5caff1ebb0..c0576360cd61 100644 --- a/.github/workflows/hpu-gaudi2-nightly.yml +++ b/.github/workflows/hpu-gaudi2-nightly.yml @@ -21,7 +21,7 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest + image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml index a06f871b7c56..b8b6f3cb5502 100644 --- a/.github/workflows/hpu-gaudi2.yml +++ b/.github/workflows/hpu-gaudi2.yml @@ -39,7 +39,7 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest + image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice From 3573858e7ce2c723b8c43231c6c6b0cf97dca2fc Mon Sep 17 00:00:00 2001 From: Nir 
Sonnenschein Date: Mon, 30 Dec 2024 20:53:41 +0200 Subject: [PATCH 06/13] Change compile for pipeline module torch.compile (#6478) We have encountered and issue with torch.compile and the pipeline module. modifying a member of the module (micro_offset) during the forward function will cause torch compile to restart the analysis and treat the module as dynamic. In order to bypass this issue without significantly changing the way the pipeline module works we propose to compile only the layers in the pipeline module instead of the forward function of pipeline module. this will bypass the issue and should still give most of the benefit of torch compiling the pipeline module while avoiding the issue. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/pipe/module.py | 8 ++++++++ tests/unit/pipe/test_pipe_module.py | 8 ++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 31fec30be788..9fbd91f750a9 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -662,3 +662,11 @@ def get_additional_losses(self): Return a dictionary of {"loss name": loss_value} or None if no additional losses. """ return None + + def compile(self, *args, **kwargs): + for idx, layer in enumerate(self.forward_funcs): + if isinstance(layer, nn.Module): + layer.compile(*args, **kwargs) + else: + new_layer = torch.compile(layer, *args, **kwargs) + self.forward_funcs[idx] = new_layer diff --git a/tests/unit/pipe/test_pipe_module.py b/tests/unit/pipe/test_pipe_module.py index 05c6a82ef55a..2a8a4b9b7d82 100644 --- a/tests/unit/pipe/test_pipe_module.py +++ b/tests/unit/pipe/test_pipe_module.py @@ -60,9 +60,12 @@ def batch_input(): class TestPipeModuleSequential(DistributedTest): world_size = 2 + # needs to be set for torch.compile: running torch.compile with daemonic process causes an error + non_daemonic_procs = True @pytest.mark.parametrize("activation_checkpoints", [False, True]) - def test(self, sequential_model, simple_config, batch_input, activation_checkpoints): + @pytest.mark.parametrize("use_compile", [False, True]) + def test(self, sequential_model, simple_config, batch_input, activation_checkpoints, use_compile): base_model = copy.deepcopy(sequential_model) base_input = batch_input.clone().detach() base_output = base_model(base_input) @@ -71,7 +74,8 @@ def test(self, sequential_model, simple_config, batch_input, activation_checkpoi pipe_model = copy.deepcopy(sequential_model) pipe_model = PipelineModule(layers=pipe_model, num_stages=2) - + if (use_compile): + pipe_model.compile() # Ensure all parameters are accounted for. my_params = sum(p.numel() for p in pipe_model.parameters()) total_pipe_params = torch.LongTensor([my_params]).to(get_accelerator().device_name()) From 456c9ac67975da698e44dfd4f90c4f7b867d08bd Mon Sep 17 00:00:00 2001 From: Max Kovalenko Date: Fri, 3 Jan 2025 17:48:24 +0200 Subject: [PATCH 07/13] Stage3: Use new torch grad accumulation hooks API (#6773) * This commit addresses a Deepspeed issue [#6718](https://github.com/microsoft/DeepSpeed/issues/6718) * The existing code has been using the grad_acc node hook to reduce params grads. The constructs such as `param.data = replicated_tensor.data` used in `allgather_params(..)` are compiled into `param.set()` causing the hook assigned to the grad_acc node not being called. 
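Stepping back briefly to PATCH 06 above: compiling each element of `forward_funcs` instead of the whole pipeline `forward()` can be sketched with a toy container (this is not DeepSpeed's `PipelineModule`; it only illustrates the per-layer compile idea, and `compile_layers` is a made-up name, whereas the real patch overrides `PipelineModule.compile`):

```python
import torch
import torch.nn as nn

class TinyPipe(nn.Module):
    """Toy stand-in for a pipeline container holding a flat list of layer callables."""

    def __init__(self):
        super().__init__()
        self.stages = nn.ModuleList([nn.Linear(8, 8), nn.Linear(8, 4)])
        self.forward_funcs = [self.stages[0], torch.relu, self.stages[1]]

    def compile_layers(self, **kwargs):
        # Compile each layer on its own; mutating container state inside
        # forward() then cannot force torch.compile to restart its analysis.
        self.forward_funcs = [torch.compile(f, **kwargs) for f in self.forward_funcs]

    def forward(self, x):
        for fn in self.forward_funcs:
            x = fn(x)
        return x

pipe = TinyPipe()
pipe.compile_layers()
print(pipe(torch.randn(2, 8)).shape)   # torch.Size([2, 4])
```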
* Starting from PyTorch 2.1 there is a new and robust hook API on a param itself: `param.register_post_accumulate_grad_hook(..)` * This commit will make use of the proper API depending on the PyTorch version * It will also disable compile for PyTorch versions < 2.1 --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> --- deepspeed/runtime/compiler.py | 3 ++- deepspeed/runtime/zero/stage3.py | 7 ++----- deepspeed/utils/torch.py | 9 +++++++++ 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py index fa9220f4fcd0..be778b83f8bb 100644 --- a/deepspeed/runtime/compiler.py +++ b/deepspeed/runtime/compiler.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from deepspeed.utils.torch import required_torch_version try: from torch.compiler import is_compiling as torch_is_compiling @@ -16,7 +17,7 @@ def is_compile_supported(): - return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile") + return required_torch_version(min_version=2.1) def disable(func): diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 459cffce52c8..28f91cb9b3ab 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -16,6 +16,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.utils import logger +from deepspeed.utils.torch import register_grad_hook from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item @@ -1159,7 +1160,6 @@ def overlapping_partition_gradients_reduce_epilogue(self): def create_reduce_and_remove_grad_hooks(self): print_rank_0(f'[Begin] Create gradient reduction hooks') - self.grad_accs = [] self.leaf_parameters = defaultdict(list) for i, param_group in enumerate(self.fp16_groups): for param in param_group: @@ -1172,15 +1172,12 @@ def create_reduce_and_remove_grad_hooks(self): #print(f"After all gather {param.device}, {param.shape}") def wrapper(param): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] @instrument_w_nvtx def reduce_partition_and_remove_grads(*notneeded): self.reduce_ready_partitions_and_remove_grads(param) - self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) - self.grad_accs.append(grad_acc) + self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads)) #print(f"param grad fn {param.expand_as(param).grad_fn}") if z3_leaf_parameter(param): diff --git a/deepspeed/utils/torch.py b/deepspeed/utils/torch.py index eb22d3561035..1d32775fe64a 100644 --- a/deepspeed/utils/torch.py +++ b/deepspeed/utils/torch.py @@ -20,3 +20,12 @@ def required_torch_version(min_version=None, max_version=None): return False return True + + +def register_grad_hook(param, hook): + if required_torch_version(min_version=2.1): + return param.register_post_accumulate_grad_hook(hook) + else: + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + return grad_acc.register_hook(hook) From a8ede3a9df556a5d6beb22a4bf38fa9852b3bada Mon Sep 17 00:00:00 2001 From: Logan Adams 
<114770087+loadams@users.noreply.github.com> Date: Fri, 3 Jan 2025 08:25:50 -0800 Subject: [PATCH 08/13] Cleanup ops/transformer/inference tests (#6830) --- tests/unit/ops/transformer/inference/test_bias_add.py | 2 -- tests/unit/ops/transformer/inference/test_bias_gelu.py | 4 ++-- tests/unit/ops/transformer/inference/test_matmul.py | 2 -- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/unit/ops/transformer/inference/test_bias_add.py b/tests/unit/ops/transformer/inference/test_bias_add.py index f25bbc1be692..eb283924f73c 100644 --- a/tests/unit/ops/transformer/inference/test_bias_add.py +++ b/tests/unit/ops/transformer/inference/test_bias_add.py @@ -15,8 +15,6 @@ if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -torch_minor_version = None - def run_bias_add_reference(activations, bias): return activations + bias diff --git a/tests/unit/ops/transformer/inference/test_bias_gelu.py b/tests/unit/ops/transformer/inference/test_bias_gelu.py index e3a3bad63961..f0a09245e890 100644 --- a/tests/unit/ops/transformer/inference/test_bias_gelu.py +++ b/tests/unit/ops/transformer/inference/test_bias_gelu.py @@ -10,8 +10,8 @@ from deepspeed.ops.op_builder import InferenceBuilder from deepspeed.ops.transformer import DeepSpeedInferenceConfig from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp +from deepspeed.utils.torch import required_torch_version from .inference_test_utils import allclose, get_dtypes -from packaging import version as pkg_version if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) @@ -34,7 +34,7 @@ def run_bias_gelu_ds(activations, bias): @pytest.mark.parametrize("channels", [512, 1232, 4096]) @pytest.mark.parametrize("dtype", get_dtypes()) def test_bias_gelu(batch, sequence, channels, dtype): - if pkg_version.parse(torch.__version__) < pkg_version.parse("1.12"): + if not required_torch_version(min_version=1.12): pytest.skip("gelu implementation matches only after torch 1.12") activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=get_accelerator().device_name()) diff --git a/tests/unit/ops/transformer/inference/test_matmul.py b/tests/unit/ops/transformer/inference/test_matmul.py index 2ab195ee0115..6f5173bbc827 100644 --- a/tests/unit/ops/transformer/inference/test_matmul.py +++ b/tests/unit/ops/transformer/inference/test_matmul.py @@ -11,8 +11,6 @@ if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -inference_module = None - def allclose(x, y): assert x.dtype == y.dtype From 0dbbb70b99f7f251996128115a53c1c8397efa8f Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Fri, 3 Jan 2025 21:57:49 -0800 Subject: [PATCH 09/13] Fix `checkpointable_layers` Logic (#6881) **Problem** There's an edge-case in DeepSpeed, where if all three of the following are true: 1. Deepspeed activation checkpointing is applied 2. The user passes `checkpointable_layers` (e.g. https://github.com/EleutherAI/gpt-neox/blob/f5325805678c2b9e35aae4528283e0132c5f5bbc/megatron/model/gpt2_model.py#L175) 3. The user's model class contains `GPT2ModelPipe` or GPTModelPipe` Then the `checkpointable_layers` will not be activation checkpointed. 
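Returning to PATCH 07 for a moment: the version split it introduces in `register_grad_hook` can be sketched on its own. `register_post_accumulate_grad_hook` is the official PyTorch >= 2.1 API; the `expand_as`/`grad_fn` route reaches the legacy AccumulateGrad node:

```python
import torch

def register_grad_hook(param: torch.nn.Parameter, hook):
    # Prefer the official per-parameter API when this PyTorch build has it.
    if hasattr(param, "register_post_accumulate_grad_hook"):   # PyTorch >= 2.1
        return param.register_post_accumulate_grad_hook(hook)
    # Legacy route: hook the AccumulateGrad node reached through an aliasing view.
    grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]
    return grad_acc.register_hook(hook)

p = torch.nn.Parameter(torch.randn(3))
handle = register_grad_hook(p, lambda *args: print("gradient accumulated for p"))
(p * 2).sum().backward()   # hook fires once the gradient lands in p.grad
handle.remove()
```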
**Reason** This is because in the current logic, `_is_checkpointable` will short-circuit to just return layers matching `ParallelTransformerLayerPipe` in the case of `self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe')`. See https://github.com/microsoft/DeepSpeed/blob/da771ed42e41a44d5047813ca4672f1cfe9d1731/deepspeed/runtime/pipe/module.py#L653 **Proposed Fixes** I think that `checkpointable_layers` should always be checked for, and added logic to this effect. I also found the documentation for `checkpointable_layers` confusing and contradictory, so I updated the docstring. Lastly, I added a unit test for `checkpointable_layers`. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/pipe/module.py | 14 +++++- .../test_activation_checkpointing.py | 50 +++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 9fbd91f750a9..49fa2807c355 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -116,7 +116,9 @@ def forward(self, inputs): partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'. activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing. activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``. - checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering. + checkpointable_layers (list[str], optional): List of layer class names that are eligible for checkpointing. For GPT models, + ParallelTransformerLayerPipe is always checkpointed regardless of this list. If None, all layers with parameters are + considered checkpointable. Defaults to None. dynamic_shape: Allows dynamic shapes of inputs. This might have a performance impact. """ @@ -650,9 +652,17 @@ def _is_checkpointable(self, funcs): # because only non_reentrant_checkpoint can accept inputs with requires_grad=False # otherwise, the backward of the embedding layer won't receive gradients. 
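The decision logic PATCH 09 describes (and that the hunk just below implements inside `_is_checkpointable`) can be summarized as a standalone sketch; `funcs` stands for the callables of one pipeline stage:

```python
# Sketch of the checkpointing decision: GPT-style pipes always checkpoint
# ParallelTransformerLayerPipe plus anything listed in checkpointable_layers;
# other models rely on checkpointable_layers alone; with no list at all, any
# layer that owns parameters is considered checkpointable.
def is_checkpointable(funcs, model_class_name, checkpointable_layers=None):
    names = [f.__class__.__name__ for f in funcs]
    if model_class_name in ("GPTModelPipe", "GPT2ModelPipe"):
        return all(
            "ParallelTransformerLayerPipe" in name
            or (checkpointable_layers is not None and name in checkpointable_layers)
            for name in names)
    if checkpointable_layers is not None:
        return all(name in checkpointable_layers for name in names)
    return any(
        len(list(f.parameters())) > 0 for f in funcs if hasattr(f, "parameters"))
```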
if self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe'): - return all('ParallelTransformerLayerPipe' in f.__class__.__name__ for f in funcs) + # For GPT models, checkpoint both transformer layers and any additional + # layers specified in checkpointable_layers (if provided) + return all('ParallelTransformerLayerPipe' in f.__class__.__name__ or ( + self.checkpointable_layers is not None and f.__class__.__name__ in self.checkpointable_layers) + for f in funcs) + if self.checkpointable_layers is not None: + # For non-GPT models, only checkpoint layers specified in checkpointable_layers return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs) + + # Default behavior: checkpoint any layer that has parameters params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)] return any(len(list(p)) > 0 for p in params) diff --git a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py index 22a61003b31e..dd3bcd7fb6bd 100644 --- a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py +++ b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py @@ -8,6 +8,7 @@ import pytest import torch import deepspeed +from deepspeed.pipe import PipelineModule, LayerSpec from deepspeed.accelerator import get_accelerator from copy import deepcopy from unit.common import DistributedTest @@ -259,3 +260,52 @@ def test_ckpt_non_tensor_output_ordering(self, non_tensor_output): else: ordering += [torch.is_tensor(non_tensor_output)] _test_activation_checkpoint_ordering(module, ordering, inputs) + + +class TestCheckpointableLayersConfig(DistributedTest): + world_size = 1 + + def test_gpt2_checkpointable_layers(self): + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU accelerator does not support this test yet") + + # Create a simple topology for testing + from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology + topo = PipeModelDataParallelTopology(num_pp=1, num_mp=1, num_dp=1) + + # Create test classes that we want to checkpoint + class TestTransformerLayer(torch.nn.Module): + + def forward(self, x): + return x + + class ParallelTransformerLayerPipe(TestTransformerLayer): + pass + + class GMLPBlock(TestTransformerLayer): + pass + + # Create a mock GPT2 model with different layer types + class TestGPT2ModelPipe(PipelineModule): + + def __init__(self): + self.layers_spec = [ + LayerSpec(ParallelTransformerLayerPipe), + LayerSpec(GMLPBlock), + LayerSpec(torch.nn.Linear, 10, 10), # Should not be checkpointed + ] + + super().__init__(layers=self.layers_spec, + topology=topo, + checkpointable_layers=["GMLPBlock", "ParallelTransformerLayerPipe"]) + + model = TestGPT2ModelPipe() + model.to(get_accelerator().device_name()) + + # Build layers manually for testing + layers = [spec.build() for spec in model.layers_spec] + + # Test that _is_checkpointable returns correct values + assert model._is_checkpointable([layers[0]]) == True # ParallelTransformerLayerPipe + assert model._is_checkpointable([layers[1]]) == True # GMLPBlock + assert model._is_checkpointable([layers[2]]) == False # Linear layer From f8c9f314ffe7eddfbd3645a94143301e610f68de Mon Sep 17 00:00:00 2001 From: hj-wei Date: Tue, 7 Jan 2025 01:38:19 +0800 Subject: [PATCH 10/13] [BUG FIX]:fix get torch.version.cuda error when cuda is None in rocm (#6909) HI, I found some error when using deepspeed with rocm-torch ``` torch_cuda_version = 
".".join(torch.version.cuda.split('.')[:2]) ``` will raise an AttributeError when torch.version.cuda is None. This occurs because the CUDA version in rocm-torch/version.py is set to always be None, leading to potential runtime errors in environments where ROCm is being used. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- op_builder/builder.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/op_builder/builder.py b/op_builder/builder.py index 461281d4a569..ab26054bda7d 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -415,10 +415,11 @@ def cpu_arch(self): return '-mcpu=native' return '-march=native' - def is_cuda_enable(self): + def get_cuda_compile_flag(self): try: - assert_no_cuda_mismatch(self.name) - return '-D__ENABLE_CUDA__' + if not self.is_rocm_pytorch(): + assert_no_cuda_mismatch(self.name) + return "-D__ENABLE_CUDA__" except MissingCUDAException: print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, " "only cpu ops can be compiled!") @@ -839,7 +840,7 @@ def cxx_args(self): CPU_ARCH = self.cpu_arch() SIMD_WIDTH = self.simd_width() - CUDA_ENABLE = self.is_cuda_enable() + CUDA_ENABLE = self.get_cuda_compile_flag() args += [ CPU_ARCH, '-fopenmp', From c5e48f49d8368216b4f99ef4023d2855f1ce3983 Mon Sep 17 00:00:00 2001 From: Omar Elayan <142979319+oelayan7@users.noreply.github.com> Date: Mon, 6 Jan 2025 20:54:57 +0200 Subject: [PATCH 11/13] Add fp8_gemm fallback for non-triton systems (#6916) - Removed try/except from __init__ file in fp_quantizer and added a single entry point instead - Renamed file fp8_gemm to fp8_gemm_triton, and the function matmul_fp8 to matmul_fp8_triton - Added a new entry point fp8_gemm with matmul_fp8 inside, and if the system supports triton it calls the triton implementation and if not it calls the fallback Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/ops/fp_quantizer/__init__.py | 7 +- deepspeed/ops/fp_quantizer/fp8_gemm.py | 163 +---------------- deepspeed/ops/fp_quantizer/fp8_gemm_triton.py | 171 ++++++++++++++++++ tests/unit/ops/fp_quantizer/test_fp8_gemm.py | 16 +- 4 files changed, 189 insertions(+), 168 deletions(-) create mode 100644 deepspeed/ops/fp_quantizer/fp8_gemm_triton.py diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py index 51377bc6092c..f9cf23373c26 100644 --- a/deepspeed/ops/fp_quantizer/__init__.py +++ b/deepspeed/ops/fp_quantizer/__init__.py @@ -4,9 +4,4 @@ # DeepSpeed Team from .quantize import FP_Quantize, Quantizer - -try: - import triton - from .fp8_gemm import matmul_fp8 -except ImportError: - pass +from .fp8_gemm import matmul_fp8 diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm.py b/deepspeed/ops/fp_quantizer/fp8_gemm.py index 55504e3af8c9..db4fa5ae2c92 100644 --- a/deepspeed/ops/fp_quantizer/fp8_gemm.py +++ b/deepspeed/ops/fp_quantizer/fp8_gemm.py @@ -11,161 +11,18 @@ ################################### import torch -import triton -import triton.language as tl -@triton.jit -def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, - stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - quantization_group_size: tl.constexpr): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = 
pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m +def matmul_fp8(inp, weight, scale, quantization_group_size, quantizer): + from deepspeed import get_accelerator - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) + if not get_accelerator().is_triton_supported(): + return matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer) + else: + # Import dynamically to prevent failures on systems without triton. + from .fp8_gemm_triton import matmul_fp8_triton + return matmul_fp8_triton(inp, weight, scale, quantization_group_size) - inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( - (pid_n * BLOCK_SIZE_N) // quantization_group_size) - weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) - scale = tl.load(scale_ptr + weight_ptrs_offset) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - # Dequantize weight (fp8 -> bf16) - w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16) - w = (w + 0x3C00).to(tl.uint16) - w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16) - - inp_data += BLOCK_SIZE_K * stride_ak - weight_data += BLOCK_SIZE_K * stride_bk - weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K - weight = tl.load(weight_data, mask=weight_mask, other=0.0) - scale = tl.load(scale_ptr + (weight_ptrs_offset + - (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)), - mask=weight_mask, - other=0.0) - - accumulator += tl.dot(inp, w) - - out = accumulator.to(tl.bfloat16) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) - - -@triton.jit -def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, - stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - quantization_group_size: tl.constexpr): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - - inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - weight_ptrs_offset = offs_k[:, None] * (stride_bk // 
quantization_group_size) + ( - (pid_n * BLOCK_SIZE_N) // quantization_group_size) - - weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) - scale = tl.load(scale_ptr + weight_ptrs_offset) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - # Dequantize weight (fp8 -> fp16) - w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16) - w = (w + 0x2000).to(tl.uint16) - w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16) - - inp_data += BLOCK_SIZE_K * stride_ak - weight_data += BLOCK_SIZE_K * stride_bk - - weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0) - scale = tl.load(scale_ptr + (weight_ptrs_offset + - (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size))) - - accumulator += tl.dot(inp, w) - - out = accumulator.to(tl.float16) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) - - -def matmul_fp8(inp, weight, scale, quantization_group_size): - - assert inp.shape[1] == weight.shape[0], \ - f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})" - - M, K = inp.shape - K, N = weight.shape - - out = torch.empty((M, N), device=inp.device, dtype=inp.dtype) - - # GEMM tuning parameters! - # TODO: Add a more configurable tuning for selecting the best GeMM - BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128 - BLOCK_SIZE_N = 64 - BLOCK_SIZE_K = max(64, quantization_group_size) - GROUP_SIZE_M = 8 - num_stages = 4 - num_warps = 4 - if M >= 256: - BLOCK_SIZE_M = 256 - BLOCK_SIZE_N = 128 - BLOCK_SIZE_K = max(128, quantization_group_size) - num_stages = 3 - num_warps = 8 - - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) - kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16 - kernel[grid](inp, - weight, - out, - scale, - M, - N, - K, - inp.stride(0), - inp.stride(1), - weight.stride(0), - weight.stride(1), - out.stride(0), - out.stride(1), - quantization_group_size=quantization_group_size, - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, - GROUP_SIZE_M=GROUP_SIZE_M, - num_stages=num_stages, - num_warps=num_warps) - return out +def matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer): + return torch.matmul(inp, quantizer.dequantize(weight, scale=scale)) diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py new file mode 100644 index 000000000000..746e217d4194 --- /dev/null +++ b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +######## Fused MoE kernel ######### +# These kernels are implemented for +# fusing GeMM with dequantization of +# fp8 weight data when using bit-16 +# activation. 
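For orientation while reading this new triton file: the `matmul_fp8` entry point added to `fp8_gemm.py` a little earlier in PATCH 11 is a plain capability dispatch. A hedged sketch of that shape, with `has_triton` standing in for `get_accelerator().is_triton_supported()`:

```python
import torch

def matmul_fp8(inp, weight, scale, quantization_group_size, quantizer, has_triton):
    if has_triton:
        # Imported lazily so systems without triton never import the kernel module.
        from deepspeed.ops.fp_quantizer.fp8_gemm_triton import matmul_fp8_triton
        return matmul_fp8_triton(inp, weight, scale, quantization_group_size)
    # Fallback: dequantize the fp8 weight back to bf16/fp16 and use torch.matmul.
    return torch.matmul(inp, quantizer.dequantize(weight, scale=scale))
```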
+################################### + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # Dequantize weight (fp8 -> bf16) + w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16) + w = (w + 0x3C00).to(tl.uint16) + w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K + weight = tl.load(weight_data, mask=weight_mask, other=0.0) + scale = tl.load(scale_ptr + (weight_ptrs_offset + + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)), + mask=weight_mask, + other=0.0) + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.bfloat16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +@triton.jit +def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + 
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # Dequantize weight (fp8 -> fp16) + w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16) + w = (w + 0x2000).to(tl.uint16) + w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + + weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0) + scale = tl.load(scale_ptr + (weight_ptrs_offset + + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size))) + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +def matmul_fp8_triton(inp, weight, scale, quantization_group_size): + + assert inp.shape[1] == weight.shape[0], \ + f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})" + + M, K = inp.shape + K, N = weight.shape + + out = torch.empty((M, N), device=inp.device, dtype=inp.dtype) + + # GEMM tuning parameters! + # TODO: Add a more configurable tuning for selecting the best GeMM + BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = max(64, quantization_group_size) + GROUP_SIZE_M = 8 + num_stages = 4 + num_warps = 4 + if M >= 256: + BLOCK_SIZE_M = 256 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = max(128, quantization_group_size) + num_stages = 3 + num_warps = 8 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16 + kernel[grid](inp, + weight, + out, + scale, + M, + N, + K, + inp.stride(0), + inp.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + quantization_group_size=quantization_group_size, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + num_stages=num_stages, + num_warps=num_warps) + return out diff --git a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py index d66f7c8cb4cc..a4cf579f5943 100644 --- a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py +++ b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py @@ -14,6 +14,8 @@ from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8 +from deepspeed import get_accelerator + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) @pytest.mark.parametrize("q_bits", [8], ids=[ @@ -21,23 +23,19 @@ ]) @pytest.mark.parametrize("M", [1, 2, 4, 8, 32, 64, 128, 256, 512, 1024, 2048]) def test_fp_quant(dtype, q_bits, M): + device_name = get_accelerator().device_name() quantization_group_size = 128 fpq = FP_Quantize(group_size=quantization_group_size) N = 8192 H = 4096 - x = torch.randn(M, H, dtype=dtype, 
device='cuda') - weight_bf16 = torch.randn(H, N, dtype=dtype, device='cuda') + x = torch.randn(M, H, dtype=dtype, device=device_name) + weight_bf16 = torch.randn(H, N, dtype=dtype, device=device_name) - weight, _ = fpq.quantize(weight_bf16.data, q_bits=8, return_meta_tensor=True) + weight, _ = fpq.quantize(weight_bf16.data, q_bits=q_bits, return_meta_tensor=True) scale = fpq.get_scales() - out = matmul_fp8( - x, - weight, - scale, - quantization_group_size, - ) + out = matmul_fp8(x, weight, scale, quantization_group_size, fpq) out_q = torch.matmul(x, fpq.dequantize(weight, scale=fpq.scale)) From b0040b6ca4799c34ebb7543e4edf6658505d9dc6 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 7 Jan 2025 04:06:06 +0800 Subject: [PATCH 12/13] Reduce the device bubble introduced by heavy loop synchronization in coalesced fetch/release(z3_leaf_module) (#6694) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit depend on https://github.com/microsoft/DeepSpeed/pull/6649 When performing fetch/release operations on Z3 leaf modules, the loop time is excessively long in fine-grained module. Compared to non-leaf modules, Z3 leaf modules may include a larger number of parameters. Although each loop unit does not consume much time, the overall loop length can be significant. ![image](https://github.com/user-attachments/assets/9891835a-2620-47f3-aba6-ea22b8905d1c) **The fetch time is impacted by:** Post-allgather operations (narrow, slice ,cat, difficult to avoid) Memory pressure(record_stream/fetch event create&sync) **The release time is impacted by:** slice Free parameter record_stream Considering the fine-grained leaf modules, where each parameter is relatively small, we can treat the parameters within each leaf module as a unified entity to handle memory pressure. This approach can approximately halve the CPU time required for fetch/release operations. 
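The batching idea described above can be sketched with placeholder types: when `handle_dependency` is False, the handle skips per-partition stream bookkeeping and parks the partitions in a class-level buffer, and the caller releases the whole buffer once after looping over a leaf module's many small parameters. The real methods are `AllGatherCoalescedHandle.wait`/`free_buffer`; `record_stream` below is a stand-in for `tensor.record_stream`:

```python
# Sketch only: defer per-partition stream bookkeeping and drop it in one shot.
def record_stream(partition):
    pass   # stand-in for partition.record_stream(current_stream())

class AllGatherHandleSketch:
    data_buffer = []   # keeps partitions alive until free_buffer() is called

    def __init__(self, partitions):
        self.partitions = partitions

    def wait(self, handle_dependency=True):
        # ... the real handle finishes the allgather and rebuilds param.data here ...
        if handle_dependency:
            for part in self.partitions:        # default path: per-partition work
                record_stream(part)
        else:
            AllGatherHandleSketch.data_buffer.append(self.partitions)  # fast path

    @staticmethod
    def free_buffer():
        AllGatherHandleSketch.data_buffer = []

# Caller side for a leaf module: cheap waits in the loop, one release afterwards.
handles = [AllGatherHandleSketch([f"part_{i}"]) for i in range(4)]
for h in handles:
    h.wait(handle_dependency=False)
AllGatherHandleSketch.free_buffer()
```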
--------- Co-authored-by: Ma, Guokai Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/zero/mics.py | 2 +- deepspeed/runtime/zero/parameter_offload.py | 19 ++++--- .../runtime/zero/partition_parameters.py | 39 +++++++++------ .../zero/partitioned_param_coordinator.py | 50 ++++++++++++------- 4 files changed, 69 insertions(+), 41 deletions(-) diff --git a/deepspeed/runtime/zero/mics.py b/deepspeed/runtime/zero/mics.py index c9ae58a121de..628bf86a61da 100755 --- a/deepspeed/runtime/zero/mics.py +++ b/deepspeed/runtime/zero/mics.py @@ -38,7 +38,7 @@ class MiCS_AllGatherCoalescedHandle(AllGatherCoalescedHandle): def __init__(self, allgather_handle, params: List[Parameter], partitions: List[Tensor], world_size: int) -> None: super().__init__(allgather_handle, params, partitions, world_size) - def wait(self) -> None: + def wait(self, **kwargs) -> None: """ """ # let the current stream to op diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 0be88a1e1ba6..d5b7bac55146 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -145,6 +145,16 @@ def __init__( module.ds_inflight_param_registry = InflightParamRegistry() self.__inflight_param_registry = module.ds_inflight_param_registry + self.fast_sharding_for_leaf_module = False + + if zero_module_granularity_threshold > 0: + self.min_granularity_value = sys.maxsize + self.min_granularity_layer = None + self.granularity_info = set() + self.z3_leaf_layers = [] + self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold) + self.fast_sharding_for_leaf_module = True + self.param_coordinator = PartitionedParameterCoordinator( prefetch_bucket_sz=self._prefetch_bucket_sz, max_reuse_distance_in_numel=self._max_reuse_distance_in_numel, @@ -155,14 +165,7 @@ def __init__( timers=self.timers, zero_quantized_weights=self.zero_quantized_weights, zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights, - ) - - if zero_module_granularity_threshold > 0: - self.min_granularity_value = sys.maxsize - self.min_granularity_layer = None - self.granularity_info = set() - self.z3_leaf_layers = [] - self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold) + fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module) self.forward_hooks = [] self.backward_hooks = [] diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index cb0cd7c8017d..e8cb797b8a5b 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -55,7 +55,7 @@ def __init__(self, param: Parameter) -> None: non_blocking=True).view(param.ds_shape) self.__param = param - def wait(self) -> None: + def wait(self, **kwargs) -> None: if not get_accelerator().resolves_data_dependency(): get_accelerator().current_stream().synchronize() self.__param.ds_status = ZeroParamStatus.AVAILABLE @@ -78,7 +78,7 @@ def __init__(self, params: List[Parameter]) -> None: non_blocking=True).view(param.ds_shape) @instrument_w_nvtx - def wait(self) -> None: + def wait(self, **kwargs) -> None: if self.__complete: return @@ -639,7 +639,7 @@ def __init__(self, handle, param: Parameter, quantization=None) -> None: self.__param = param self.__quantization = quantization - def wait(self) -> None: + def wait(self, handle_dependency=True) -> None: 
instrument_w_nvtx(self.__handle.wait)() if self.__quantization: instrument_w_nvtx(self.__quantization.quant_handle.wait)() @@ -650,6 +650,8 @@ def wait(self) -> None: class AllGatherCoalescedHandle: + data_buffer = [] + def __init__( self, allgather_handle, @@ -672,7 +674,7 @@ def __init__( raise RuntimeError(f"expected param {param.ds_summary()} to not be available") @instrument_w_nvtx - def wait(self) -> None: + def wait(self, handle_dependency=True) -> None: if self.complete: return @@ -704,14 +706,20 @@ def wait(self) -> None: partitions.append(part_to_copy) param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape) param.ds_status = ZeroParamStatus.AVAILABLE - - for part_to_copy in partitions: - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().is_synchronized_device() and handle_dependency: + for part_to_copy in partitions: part_to_copy.record_stream(get_accelerator().current_stream()) param_offset += ds_tensor_numel self.complete = True + if not get_accelerator().is_synchronized_device() and not handle_dependency: + # if the device needs to handle dependencies and opts for explicit processing outside the function. + AllGatherCoalescedHandle.data_buffer.append(partitions) + + @staticmethod + def free_buffer(): + AllGatherCoalescedHandle.data_buffer = [] class MultipleAllGatherHandles: @@ -719,9 +727,9 @@ class MultipleAllGatherHandles: def __init__(self, handles: List[AllGatherCoalescedHandle]): self.handles = handles - def wait(self) -> None: + def wait(self, handle_dependency=True) -> None: for handle in self.handles: - handle.wait() + handle.wait(handle_dependency) class AllReduceCoalescedHandle: @@ -1377,13 +1385,13 @@ def all_gather_coalesced(params: Iterable[Parameter], quantization=quant_info, ) - def partition(param_list=None, hierarchy=0, has_been_updated=False): + def partition(param_list=None, hierarchy=0, has_been_updated=False, free_data=True): cls = param print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}", force=False) if param_list is None: param_list = [cls] - self._partition(param_list, has_been_updated=has_been_updated) + self._partition(param_list, has_been_updated=has_been_updated, free_data=True) def reduce_gradients_at_owner(param_list=None, hierarchy=0): cls = param @@ -1527,12 +1535,12 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None): return handles - def _partition(self, param_list, force=False, has_been_updated=False): + def _partition(self, param_list, force=False, has_been_updated=False, free_data=True): for param in param_list: print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False) if self.zero_param_process_group is not None: self._partition_param_sec(param) - self._partition_param(param, has_been_updated=has_been_updated) + self._partition_param(param, has_been_updated=has_been_updated, free_data=True) param.ds_status = ZeroParamStatus.NOT_AVAILABLE # if param.ds_tensor is not None: @@ -1540,7 +1548,7 @@ def _partition(self, param_list, force=False, has_been_updated=False): # "After the parameters are initially partitioned, make sure we are not recreating the partition." 
#print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False) @instrument_w_nvtx - def _partition_param(self, param, buffer=None, has_been_updated=False): + def _partition_param(self, param, buffer=None, has_been_updated=False, free_data=True): assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight" global reuse_buffers print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False) @@ -1565,7 +1573,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False) # param.data does not store anything meaningful in partitioned state - free_param(param) + if free_data: + free_param(param) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) if param.ds_tensor.final_location == OffloadDeviceEnum.nvme: diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 596d0e9c20f9..08cb6c0de54f 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -76,18 +76,17 @@ class __ParamInTrace: param: Parameter step_id_last_used_at: int - def __init__( - self, - prefetch_bucket_sz: int, - max_reuse_distance_in_numel: int, - max_available_parameters_in_numel: int, - allgather_stream: get_accelerator().Stream, - inflight_param_registry: InflightParamRegistry, - prefetch_nvme: bool = False, - timers=None, - zero_quantized_weights=False, - zero_quantized_nontrainable_weights=False, - ) -> None: + def __init__(self, + prefetch_bucket_sz: int, + max_reuse_distance_in_numel: int, + max_available_parameters_in_numel: int, + allgather_stream: get_accelerator().Stream, + inflight_param_registry: InflightParamRegistry, + prefetch_nvme: bool = False, + timers=None, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + fast_sharding_for_leaf_module=False) -> None: # mapping of param -> handle for each param that is currently in flight self.__inflight_param_registry = inflight_param_registry # keeps track of the number of submodules invoked so far. @@ -130,6 +129,10 @@ def __init__( self.__max_ongoing_fetch_events: int = 2 self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None) + # whether to enable fast fetch for the z3 leaf module. + # this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure. + self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module + """Tracing and Tracking TODO. consider performing trace before initializing PartitionedParameterCoordinator and passing trace results into constructor. 
This way all the code in here can @@ -308,6 +311,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: wait_numel = 0 wait_event_name = __class__.FORWARD_FETCH_WAIT if forward else __class__.BACKWARD_FETCH_WAIT self.__profiler.start_event(wait_event_name) + fast_fetch = self.fast_sharding_for_leaf_module and z3_leaf_module(current_submodule) # wait for parameters in the immediately needed submodule to become available for param in params_to_fetch: param.ds_active_sub_modules.add(current_submodule.id) @@ -321,9 +325,9 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events: self.__ongoing_fetch_events.popleft().synchronize() - self.__inflight_param_registry.pop(param).wait() + self.__inflight_param_registry.pop(param).wait(handle_dependency=not fast_fetch) - if not get_accelerator().handles_memory_backpressure(): + if not get_accelerator().handles_memory_backpressure() and not fast_fetch: event = get_accelerator().Event() event.record() self.__ongoing_fetch_events.append(event) @@ -331,6 +335,8 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() if not get_accelerator().resolves_data_dependency(): get_accelerator().current_stream().wait_stream(self.__allgather_stream) + if fast_fetch: + AllGatherCoalescedHandle.free_buffer() self.__profiler.stop_event(wait_event_name, wait_numel) # kick off parameter prefetches for upcoming modules @@ -412,10 +418,20 @@ def release_sub_module(self, submodule: Module) -> None: be released.""" params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set( p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule)))) + + free_data = not z3_leaf_module(submodule) or not self.fast_sharding_for_leaf_module + if not free_data: + # wait for the computation to finish and launch as early as possible. 
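+            # the shared 1-element tensor allocated below is assigned to each released leaf-module
+            # param in place of free_param(), so the large gathered storage is dropped without
+            # per-tensor record_stream calls.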
+ empty_buffer = torch.empty(1, device=get_accelerator().current_device()) + for param in iter_params(submodule, recurse=z3_leaf_module(submodule)): param.ds_active_sub_modules.discard(submodule.id) if param.ds_id in params_to_release and not param.is_external_param: - self.__release_param(param) + self.__release_param(param, free_data) + if not free_data: + if param.ds_id in params_to_release and not param.is_external_param: + # empty buffer ensures that all computations are complete + param.data = empty_buffer @instrument_w_nvtx @torch.no_grad() @@ -490,11 +506,11 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize: @compiler.disable @instrument_w_nvtx - def __release_param(self, param: Parameter) -> None: + def __release_param(self, param: Parameter, free_data: bool = True) -> None: if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: if logger.isEnabledFor(logging.DEBUG): debug_rank0(f"-release: {param.ds_summary()}") - param.partition() + param.partition(free_data=free_data) self.__n_available_params -= param.ds_numel @instrument_w_nvtx From c348c5b11a4fd3f70a53f5c9445f260472de47a8 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Mon, 6 Jan 2025 14:35:50 -0800 Subject: [PATCH 13/13] Cleanup ops/transformer/inference tests (#6925) --- .../inference/inference_test_utils.py | 42 +++++++------------ .../transformer/inference/test_attention.py | 4 +- .../transformer/inference/test_layer_norm.py | 4 +- 3 files changed, 18 insertions(+), 32 deletions(-) diff --git a/tests/unit/ops/transformer/inference/inference_test_utils.py b/tests/unit/ops/transformer/inference/inference_test_utils.py index 9cfcae809f09..d63c51267e51 100644 --- a/tests/unit/ops/transformer/inference/inference_test_utils.py +++ b/tests/unit/ops/transformer/inference/inference_test_utils.py @@ -3,6 +3,8 @@ # DeepSpeed Team +from typing import Tuple + import torch from deepspeed.accelerator import get_accelerator @@ -23,38 +25,22 @@ def get_tolerances(): DTYPES = None -def get_dtypes(): +def get_dtypes(include_float=True): global DTYPES if DTYPES is None: - DTYPES = get_accelerator().supported_dtypes() + DTYPES = [torch.float16, torch.float32] if include_float else [torch.float16] + try: + if get_accelerator().is_bf16_supported(): + DTYPES.append(torch.bfloat16) + except (AssertionError, AttributeError): + pass return DTYPES -def allclose(x, y): +def allclose(x, y, tolerances: Tuple[int, int] = None): assert x.dtype == y.dtype - rtol, atol = get_tolerances()[x.dtype] + if tolerances is None: + rtol, atol = get_tolerances()[x.dtype] + else: + rtol, atol = tolerances return torch.allclose(x, y, rtol=rtol, atol=atol) - - -def assert_almost_equal(x, y, decimal=2, err_msg=''): - import numpy.testing as npt - if isinstance(x, torch.Tensor): - if x.dtype == torch.bfloat16: - x = x.float() - x = x.cpu().detach().numpy() - if isinstance(y, torch.Tensor): - if y.dtype == torch.bfloat16: - y = y.float() - y = y.cpu().detach().numpy() - npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal) - - -def max_diff(a, b): - a = a.to(torch.float32).flatten() - b = b.to(torch.float32).flatten() - diff = torch.abs(a - b) - max_diff_indices = torch.argsort(diff)[-1] - print("Max difference indices:", max_diff_indices) - print("Max difference values:", diff[max_diff_indices]) - print(f"{a[max_diff_indices]} vs {b[max_diff_indices]}") - return max_diff_indices diff --git a/tests/unit/ops/transformer/inference/test_attention.py 
b/tests/unit/ops/transformer/inference/test_attention.py index ecf681542ff6..cae201d747a3 100644 --- a/tests/unit/ops/transformer/inference/test_attention.py +++ b/tests/unit/ops/transformer/inference/test_attention.py @@ -7,7 +7,7 @@ import torch import deepspeed from deepspeed.accelerator import get_accelerator -from .inference_test_utils import assert_almost_equal +from .inference_test_utils import allclose # reference timplementation @@ -88,4 +88,4 @@ def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float use_triton_flash=False, use_ds_attention=False) tri_out = tri_out.reshape((BATCH, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3) - assert_almost_equal(ref_out, tri_out) + assert (allclose(ref_out, tri_out)) diff --git a/tests/unit/ops/transformer/inference/test_layer_norm.py b/tests/unit/ops/transformer/inference/test_layer_norm.py index 7711daf0d887..4a84add16046 100644 --- a/tests/unit/ops/transformer/inference/test_layer_norm.py +++ b/tests/unit/ops/transformer/inference/test_layer_norm.py @@ -9,7 +9,7 @@ from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp -from .inference_test_utils import allclose, get_dtypes, assert_almost_equal +from .inference_test_utils import allclose, get_dtypes try: import triton # noqa: F401 # type: ignore from deepspeed.ops.transformer.inference.triton import ( @@ -188,4 +188,4 @@ def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device=' y_ref = torch.nn.functional.layer_norm(x + res + (x_bias if input_bias else 0), w_shape, weight, bias, eps).to(dtype) # compare - assert_almost_equal(y_tri, y_ref) + assert (allclose(y_tri, y_ref))
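
A minimal usage sketch for the refactored test helpers above. It is illustrative only and not part of either patch: the bare import path, tensor shape, and tolerance values are assumptions (inside the repo the tests import the helpers relatively, as shown in the hunks above).

import torch

# assumes inference_test_utils.py (as patched above) is importable from the working directory;
# the in-repo tests use a relative import instead.
from inference_test_utils import allclose, get_dtypes

for dtype in get_dtypes():
    # fp16 and fp32 by default, plus bf16 when the accelerator reports support for it
    a = torch.randn(4, 8).to(dtype)
    # default path: per-dtype (rtol, atol) come from get_tolerances()
    assert allclose(a, a.clone())
    # explicit override: pass (rtol, atol) directly when a test needs looser bounds
    assert allclose(a, a + 0.05, tolerances=(0.0, 0.1))

Since assert_almost_equal and max_diff were removed in the cleanup commit, new inference tests only need the allclose pattern shown here.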