From bd8d97d81c5570912e9dbd4fe126e2b63924290b Mon Sep 17 00:00:00 2001
From: Akshay Sonawane
Date: Wed, 14 Aug 2024 00:59:04 -0700
Subject: [PATCH] Fix for issue: https://github.com/microsoft/onnxruntime-genai/issues/552

---
 src/python/py/models/builder.py | 257 +++++++++++++++++++++-----------
 1 file changed, 167 insertions(+), 90 deletions(-)

diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index fb6523b33..4be62d84c 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -550,6 +550,11 @@ def make_tile(self, name, inputs, dtype, shape):
         self.make_node("Tile", inputs=inputs, outputs=[output], name=name)
         self.make_value_info(output, dtype, shape=shape)
 
+    def make_trilu(self, name, inputs, upper, dtype, shape):
+        output = f"{name}/output_0"
+        self.make_node("Trilu", inputs=inputs, outputs=[output], name=name, upper=upper)
+        self.make_value_info(output, dtype, shape=shape)
+
     def make_equal(self, name, inputs, shape):
         output = f"{name}/output_0"
         self.make_node("Equal", inputs=inputs, outputs=[output], name=name)
@@ -660,7 +665,7 @@ def make_matmul_int4(self, matmul, basename, root_input, **kwargs):
             # print(f"Quantizing to {self.onnx_dtype} on-the-fly is not currently supported.")
             # print(f"Saving as {self.io_dtype} on-the-fly and quantizing to {self.onnx_dtype} at the end.")
             return self.make_matmul_fp16_or_fp32(matmul, basename, root_input, **kwargs)
-    
+
         name = f"{basename}NBits"
 
         # Input weights are quantized, save quantized MatMul numpy weights for onnx model
@@ -1608,7 +1613,7 @@ def make_model(self, input_path):
             from onnxruntime_genai.models.quantized_model import QuantModel
             q_size = self.num_attn_heads * self.head_size
             kv_size = self.num_kv_heads * self.head_size
-            model = QuantModel.from_pretrained(self.quant_type, input_path, self.quant_attrs["bits"], self.quant_attrs["group_size"], self.quant_attrs["use_g_idx"], q_size, kv_size, self.intermediate_size, self.num_layers)    
+            model = QuantModel.from_pretrained(self.quant_type, input_path, self.quant_attrs["bits"], self.quant_attrs["group_size"], self.quant_attrs["use_g_idx"], q_size, kv_size, self.intermediate_size, self.num_layers)
         else:
             # Load PyTorch model
             extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {}
@@ -1689,61 +1694,73 @@ def make_attention_mask_reformatting_for_mha(self):
         # Make nodes for the attention mask subgraphs that reformat the
         # 2D attention mask (B, S) to 4D causal attention mask (B, N, S, T)
         #
-        # input_ids past_key_values.0.key
-        # / \ |
-        # Shape Shape Shape
-        # | | |
-        # Gather Gather Gather
-        # (idx=0) (idx=1) (idx=2)
-        # | | |\ /
-        # | | | \ /
-        # | | | Add attention_mask--------+
-        # | | | | / \ |
-        # Unsqueeze Unsqueeze | Unsqueeze Shape Shape |
-        # \ | | / | | |
-        # \ | +-/--------+----------+----------+ Gather Gather Unsqueeze
-        # \ | / | | | (idx=0) (idx=1) |
-        # \ | / | | | | | |
-        # \ | / Unsqueeze Unsqueeze Unsqueeze Unsqueeze Unsqueeze Unsqueeze
-        # \ | / \ / \ | / |
-        # Concat Concat \ | / |
-        # / | \ | \ | / |
-        # / | \ | \ | / |
-        # / | \ ConstantOfShape Concat |
-        # / | \ / \ \ / | \ |
-        # / Shape \ Shape Shape | / | \ |
-        # / | \ | | | / | \ |
-        # / | \ Slice Slice | / | \ |
-        # \ ConstantOfShape | | | | / Shape \ |
-        # \ | | | Squeeze Squeeze | / | \ |
-        # \ Mul | | | | | / | \ |
-        # \ | | | Unsqueeze Range | \ ConstantOfShape \ /
-        # \ | | | | | | | \ | | | /
-        # \ | | | Concat Add | | \ Mul | | /
-        # \ | | | | / | | \ | | | /
-        # Equal | / Reshape | | \ | | | /
-        # Where Less---------+ | \ | | | /
-        # | | | Equal | / /
-        # | Where-----------+ \ | / /
-        # | | Where /
-        # | Unsqueeze | /
-        # | | Expand
-        # | Unsqueeze |
-        # \ / Cast
-        # Expand |
-        # | Sub
-        # | / |
-        # | / Cast
-        # | | |
-        # | Where
-        # | |
-        # +----------------------+----------------------+
-        # |
-        # Add
-        # |
-        # Concat
-
+        # input_ids past_key_values.0.key attention_mask--------+
+        # / \ | | | |
+        # Shape Shape Shape Shape Shape |
+        # | | | | | |
+        # Gather Gather Gather Gather Gather |
+        # (idx=0) (idx=1) (idx=2) (idx=0) (idx=1) |
+        # | | |\ / | | |
+        # | | | \ / Unsqueeze Unsqueeze |
+        # | | | Add / | | |
+        # +---------------|------------+ | | / | | |
+        # | +------|-----------------|---Sub / | | |
+        # | | | | /| / | | |
+        # | | | ----------|-+/Add / | | |
+        # | | | | +----|----|-------------------------------+ | | |
+        # | | Unsqueeze | | | Unsqueeze | | |
+        # | | \ | | | / | | |
+        # Unsqueeze Unsqueeze \ | | +-/--------+----------+----------+ | | Unsqueeze
+        # \ / \ | | / | | | | | |
+        # Concat -----+\+ | / | | | | / |
+        # | | \ | / Unsqueeze Unsqueeze Unsqueeze | / Unsqueeze
+        # ConstantOfShape | \ | / \ / \ | / |
+        # | | Concat Concat \ | / |
+        # | | / | \ | \ | / |
+        # | | / | \ | \ | / |
+        # | | / | \ ConstantOfShape Concat |
+        # | | / | \ / \ \ / | \ |
+        # | | / Shape \ Shape Shape | / | \ |
+        # | | / | \ | | | / | \ |
+        # | | / | \ Slice Slice | / | \ |
+        # | | \ ConstantOfShape | | | | / Shape \ |
+        # | | \ | | | Squeeze Squeeze | / | \ |
+        # | | \ Mul | | | | | / | \ |
+        # | | \ | | | Unsqueeze Range | \ ConstantOfShape \ /
+        # | | \ | | | | | | | \ | | | /
+        # | | \ | | | Concat Add | | \ Mul | | /
+        # | | \ | | | | / | | \ | | | /
+        # | | Equal | / Reshape | | \ | | | /
+        # | | \ | / | | | \ | | | /
+        # | | Where Less---------+ | \ | | | /
+        # | | | | | Equal | / /
+        # Concat--------------------------|-------+Where-----------+ \ | / /
+        # / \ | | Where /
+        # Shape \ | | | /
+        # | \ | | Expand
+        # ConstantOfShape | | |
+        # | \ | | Cast
+        # Trilu +------|--+| | |
+        # | | | Sub
+        # Sub | | / |
+        # | | | / Cast
+        # Cast | | | |
+        # | | | Where
+        # Cast | | |
+        # | | | |
+        # Where-------+ | |
+        # | | |
+        # Unsqueeze | |
+        # | | |
+        # Unsqueeze +----------------------+ Expand |
+        # \ |
+        # \ |
+        # +-----------------------+---------------------+
+        # | |
+        # Add
+        # |
+        # Tile
+
         basename = "/model/attn_mask_reformat"
         input_ids_basename = f"{basename}/input_ids_subgraph"
         past_key_basename = f"{basename}/past_key_subgraph"
@@ -1781,23 +1798,44 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name):
         #
         # Gather Gather
         # (idx=1) (idx=2)
-        # \ /
-        # \ /
-        # \ /
-        # Add
-        # |
-        # Unsqueeze
+        # | \ /
+        # | \ /
+        # | \ /
+        # | Add
+        # | |
+        # | Sub
+        # | |
+        # Unsqueeze Unsqueeze
+        # \ /
+        # Concat
+        # |
+        # ConstantOfShape
+        #
+
         shared_add_name = f"{basename}/Add_1"
         shared_add_inputs = [f"{basename}/Gather_2/output_0", f"{past_key_gather_name}/output_0"]
         self.make_add(shared_add_name, shared_add_inputs, dtype=TensorProto.INT64, shape=[])
+        sub_name = f"{basename}/Sub"
+        sub_inputs = [f"{shared_add_name}/output_0", f"{basename}/Gather_2/output_0"]
+        self.make_sub(sub_name, sub_inputs, dtype=TensorProto.INT64, shape=["unk"])
         unsqueeze_3_name = f"{basename}/Unsqueeze_3" # shared unsqueeze for input_ids and past_key_values.0.key
-        unsqueeze_3_inputs = [f"{shared_add_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"]
+        unsqueeze_3_inputs = [f"{sub_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"]
         self.make_unsqueeze(unsqueeze_3_name, unsqueeze_3_inputs, dtype=TensorProto.INT64, shape=[1])
+        unsqueeze_10_inputs = [f"{basename}/Gather_2/output_0", "/model/constants/TensorProto.INT64/1D/0"]
+        unsqueeze_10_name = f"{basename}/Unsqueeze_10"
+        self.make_unsqueeze(unsqueeze_10_name, unsqueeze_10_inputs, dtype=TensorProto.INT64, shape=[1])
+        concat_3_name = f"{basename}/Concat_3"
+        concat_3_inputs = [f"{unsqueeze_10_name}/output_0", f"{unsqueeze_3_name}/output_0"]
+        self.make_concat(concat_3_name, concat_3_inputs, dtype=TensorProto.INT64, shape=[2], axis=0)
+        constant_shape_name_1 = f"{basename}/ConstantOfShape_3"
+        constant_shape_numpy_dtype = self.to_numpy_dtype[self.io_dtype]
+        constant_shape_value = numpy_helper.from_array(np.array([0], dtype=constant_shape_numpy_dtype))
+        self.make_constant_of_shape(constant_shape_name_1, f"{concat_3_name}/output_0", value=constant_shape_value, dtype=self.io_dtype, shape=['unk'])
 
-        # Make the additional subgraph for input_ids
-        #
-        # Unsqueeze (unsqueeze_4) Shape --> Slice --> Squeeze --> Unsqueeze --> Concat
-        # / \ / \
+        # Make the additional subgraph for input_ids ConstantOfShape
+        # |
+        # Unsqueeze (unsqueeze_4) Shape --> Slice --> Squeeze --> Unsqueeze --> Concat Concat
+        # / \ / \ |
         # Gather (idx=1) --> Concat --> ConstantOfShape Reshape --> Less --> Where --> Unsqueeze --> Unsqueeze --> Expand
         # \ / \ |
         # Unsqueeze (unsqueeze_5) Shape --> Slice --> Squeeze --> Range --> Add -------+
@@ -1825,12 +1863,12 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name):
         squeeze_1_name = f"{basename}/Squeeze_1"
         squeeze_1_inputs = [f"{slice_1_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"]
         self.make_squeeze(squeeze_1_name, squeeze_1_inputs)
-        unsqueeze_7_name = f"{basename}/output_0"
-        unsqueeze_7_inputs = [f"{squeeze_1_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"]
-        self.make_unsqueeze(unsqueeze_7_name, unsqueeze_7_inputs, dtype=TensorProto.INT64, shape=[1])
-        concat_3_name = f"{basename}/Concat_3"
-        concat_3_inputs = [f"{unsqueeze_7_name}/output_0", "/model/constants/TensorProto.INT64/1D/1"]
-        self.make_concat(concat_3_name, concat_3_inputs, dtype=TensorProto.INT64, shape=[2], axis=0)
+        unsqueeze_8_name = f"{basename}/output_0"
+        unsqueeze_8_inputs = [f"{squeeze_1_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"]
+        self.make_unsqueeze(unsqueeze_8_name, unsqueeze_8_inputs, dtype=TensorProto.INT64, shape=[1])
+        concat_5_name = f"{basename}/Concat_5"
+        concat_5_inputs = [f"{unsqueeze_8_name}/output_0", "/model/constants/TensorProto.INT64/1D/1"]
+        self.make_concat(concat_5_name, concat_5_inputs, dtype=TensorProto.INT64, shape=[2], axis=0)
 
         # Bottom path
         shape_5_name = f"{basename}/Shape_5"
@@ -1848,24 +1886,59 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name):
         add_inputs = [f"{range_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"]
         self.make_add(add_2_name, add_inputs, dtype=TensorProto.INT64, shape=["unk"])
 
-        # Merged path
-        reshape_name = f"{basename}/Reshape"
-        reshape_inputs = [f"{add_2_name}/output_0", f"{concat_3_name}/output_0"]
-        self.make_reshape(reshape_name, reshape_inputs, dtype=TensorProto.INT64, shape=None)
+        # Merged path
+        reshape_5_name = f"{basename}/Reshape_5"
+        reshape_5_inputs = [f"{add_2_name}/output_0", f"{concat_5_name}/output_0"]
+        self.make_reshape(reshape_5_name, reshape_5_inputs, dtype=TensorProto.INT64, shape=None)
         less_name = f"{basename}/Less"
-        less_inputs = [f"{range_name}/output_0", f"{reshape_name}/output_0"]
+        less_inputs = [f"{range_name}/output_0", f"{reshape_5_name}/output_0"]
         self.make_less(less_name, less_inputs)
         where_2_name = f"{basename}/Where_2"
         where_2_inputs = [f"{less_name}/output_0", f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/0", f"{constant_shape_name}/output_0"]
         self.make_where(where_2_name, where_2_inputs, dtype=self.io_dtype, shape=None)
+        concat_4_name = f"{basename}/Concat_4"
+        concat_inputs = [f"{constant_shape_name_1}/output_0", f"{where_2_name}/output_0"]
+        self.make_concat(concat_4_name, concat_inputs, dtype=self.io_dtype, shape=[2], axis=-1)
+
+        sub_1_name = f"{basename}/Sub_1"
+        sub_1_inputs = [f"{sub_name}/output_0", "/model/constants/TensorProto.INT64/0D/2047"]
+        self.make_sub(sub_1_name, sub_1_inputs, dtype=TensorProto.INT64, shape=["unk"])
+        add_3_name = f"{basename}/Add_3"
+        add_3_inputs = [f"{sub_1_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"]
+        self.make_add(add_3_name, add_3_inputs, dtype=TensorProto.INT64, shape=["unk"])
+        shape_6_name = f"{basename}/Shape_6"
+        self.make_shape(shape_6_name, f"{concat_4_name}/output_0", shape=[2])
+        constant_shape_name_2 = f"{basename}/ConstantOfShape_4"
+        constant_shape_value = numpy_helper.from_array(np.array([1], dtype="int64"))
+        self.make_constant_of_shape(constant_shape_name_2, f"{shape_6_name}/output_0", value=constant_shape_value, dtype=TensorProto.INT64, shape=['unk', 'unk'])
+        trilu_name = f"{basename}/Trilu_1"
+        trilu_inputs = [f"{constant_shape_name_2}/output_0", f"{add_3_name}/output_0"]
+        self.make_trilu(trilu_name, trilu_inputs, dtype=TensorProto.INT64, shape=["unk"], upper=1)
+        sub_2_name = f"{basename}/Sub_2"
+        sub_2_inputs = [f"{trilu_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"]
+        self.make_sub(sub_2_name, sub_2_inputs, dtype=TensorProto.INT64, shape=["unk"])
+        cast_3_name = f"{basename}/Cast_3"
+        self.make_cast(cast_3_name, f"{sub_2_name}/output_0", dtype=TensorProto.BOOL, shape=["unk", "unk", "unk", "unk"])
+        cast_4_name = f"{basename}/Cast_4"
+        self.make_cast(cast_4_name, f"{cast_3_name}/output_0", dtype=TensorProto.BOOL, shape=["unk"])
+        add_4_name = f"{basename}/Add_4"
+        add_4_inputs = [f"{basename}/Gather_2/output_0", f"{sub_name}/output_0"]
+        self.make_add(add_4_name, add_4_inputs, dtype=TensorProto.INT64, shape=["unk"])
+        unsqueeze_11_name = f"{basename}/Unsqueeze_11"
+        unsqueeze_11_inputs = [f"{add_4_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"]
+        self.make_unsqueeze(unsqueeze_11_name, unsqueeze_11_inputs, dtype=TensorProto.INT64, shape=[1])
+
+        where_3_name = f"{basename}/Where_3"
+        where_3_inputs = [f"{cast_4_name}/output_0", f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{np.finfo(self.to_numpy_dtype[self.io_dtype]).min}", f"{concat_4_name}/output_0"]
+        self.make_where(where_3_name, where_3_inputs, dtype=self.io_dtype, shape=None)
         unsqueeze_8_name = f"{basename}/Unsqueeze_8"
-        unsqueeze_8_inputs = [f"{where_2_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"]
+        unsqueeze_8_inputs = [f"{where_3_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"]
         self.make_unsqueeze(unsqueeze_8_name, unsqueeze_8_inputs, dtype=self.io_dtype, shape=None)
         unsqueeze_9_name = f"{basename}/Unsqueeze_9"
         unsqueeze_9_inputs = [f"{unsqueeze_8_name}/output_0", "/model/constants/TensorProto.INT64/1D/1"]
         self.make_unsqueeze(unsqueeze_9_name, unsqueeze_9_inputs, dtype=self.io_dtype, shape=None)
-        expand_name = self.make_common_mask_reformat_subgraph(basename, root_input="input_ids" if not self.exclude_embeds else "inputs_embeds", unsqueeze_for_concat=unsqueeze_3_name, unsqueeze_for_expand=unsqueeze_9_name, input_ids_subgraph=True)
+        expand_name = self.make_common_mask_reformat_subgraph(basename, root_input="input_ids" if not self.exclude_embeds else "inputs_embeds", unsqueeze_for_concat=unsqueeze_11_name, unsqueeze_for_expand=unsqueeze_9_name, input_ids_subgraph=True)
 
         return unsqueeze_6_name, expand_name
 
     def make_attention_mask_subgraph(self, basename, unsqueeze_for_concat):
@@ -1896,8 +1969,10 @@
         self.make_sub(sub_name, sub_inputs, dtype=self.io_dtype, shape=["unk", "unk", "unk", "unk"])
         cast_2_name = f"{basename}/Cast_2"
         self.make_cast(cast_2_name, f"{sub_name}/output_0", dtype=TensorProto.BOOL, shape=["unk", "unk", "unk", "unk"])
+        cast_5_name = f"{basename}/Cast_5"
+        self.make_cast(cast_5_name, f"{cast_2_name}/output_0", dtype=TensorProto.BOOL, shape=["unk"])
         where_2_name = f"{basename}/Where_2"
-        where_2_inputs = [f"{cast_2_name}/output_0", f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{np.finfo(self.to_numpy_dtype[self.io_dtype]).min}", f"{sub_name}/output_0"]
+        where_2_inputs = [f"{cast_5_name}/output_0", f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/{np.finfo(self.to_numpy_dtype[self.io_dtype]).min}", f"{sub_name}/output_0"]
         self.make_where(where_2_name, where_2_inputs, dtype=self.io_dtype, shape=["unk", "unk", "unk", "unk"])
 
         return where_2_name
@@ -1939,17 +2014,17 @@ def make_common_mask_reformat_subgraph(self, basename, root_input, unsqueeze_for
         # \ /
         # Expand
 
-        shape_1_name = f"{basename}/Shape_1"
-        self.make_shape(shape_1_name, root_input, shape=[3] if self.exclude_embeds and input_ids_subgraph else [2])
+        shape_1_name = f"/model/attn_mask_reformat/attn_mask_subgraph/Shape_1"
+        self.make_shape(shape_1_name, "attention_mask", shape=[3] if self.exclude_embeds and input_ids_subgraph else [2])
         shape_2_name = f"{basename}/Shape_2"
         self.make_shape(shape_2_name, root_input, shape=[3] if self.exclude_embeds and input_ids_subgraph else [2])
-        gather_1_name = f"{basename}/Gather_1"
+        gather_1_name = f"/model/attn_mask_reformat/attn_mask_subgraph/Gather_1"
         gather_1_inputs = [f"{shape_1_name}/output_0", "/model/constants/TensorProto.INT64/0D/0"]
         self.make_gather(gather_1_name, gather_1_inputs, axis=0)
         gather_2_name = f"{basename}/Gather_2"
         gather_2_inputs = [f"{shape_2_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"]
         self.make_gather(gather_2_name, gather_2_inputs, axis=0)
-        unsqueeze_1_name = f"{basename}/Unsqueeze_1"
+        unsqueeze_1_name = f"/model/attn_mask_reformat/attn_mask_subgraph/Unsqueeze_1"
        unsqueeze_1_inputs = [f"{gather_1_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"]
         self.make_unsqueeze(unsqueeze_1_name, unsqueeze_1_inputs, dtype=TensorProto.INT64, shape=[1])
         unsqueeze_2_name = f"{basename}/Unsqueeze_2"
@@ -1961,8 +2036,11 @@ def make_common_mask_reformat_subgraph(self, basename, root_input, unsqueeze_for
         concat_last_two_inputs = [f"{unsqueeze_for_concat}/output_0", f"{unsqueeze_2_name}/output_0"] if not input_ids_subgraph else [f"{unsqueeze_2_name}/output_0", f"{unsqueeze_for_concat}/output_0"]
         concat_inputs = concat_first_two_inputs + concat_last_two_inputs
         self.make_concat(concat_name, concat_inputs, dtype=TensorProto.INT64, shape=[4], axis=0)
+        reshape_name = f"{basename}/Reshape"
+        reshape_inputs = [f"{concat_name}/output_0", "/model/constants/TensorProto.INT64/1D/-1"]
+        self.make_reshape(reshape_name, reshape_inputs, dtype=TensorProto.INT64, shape=[1])
         shape_3_name = f"{basename}/Shape_3"
-        self.make_shape(shape_3_name, f"{concat_name}/output_0", shape=[1])
+        self.make_shape(shape_3_name, f"{reshape_name}/output_0", shape=[1])
         constant_shape_name = f"{basename}/ConstantOfShape" if not input_ids_subgraph else f"{basename}/ConstantOfShape_1"
         constant_shape_value = numpy_helper.from_array(np.array([1], dtype="int64"))
         self.make_constant_of_shape(constant_shape_name, f"{shape_3_name}/output_0", value=constant_shape_value, dtype=TensorProto.INT64, shape=["unk"])
@@ -1970,7 +2048,7 @@ def make_common_mask_reformat_subgraph(self, basename, root_input, unsqueeze_for
         mul_inputs = [f"{constant_shape_name}/output_0", "/model/constants/TensorProto.INT64/0D/-1"]
         self.make_mul(mul_name, mul_inputs, dtype=TensorProto.INT64, shape=["unk"])
         equal_name = f"{basename}/Equal"
-        equal_inputs = [f"{concat_name}/output_0", f"{mul_name}/output_0"]
+        equal_inputs = [f"{reshape_name}/output_0", f"{mul_name}/output_0"]
         self.make_equal(equal_name, equal_inputs, shape=[4])
 
         where_name = f"{basename}/Where_1"
@@ -1984,6 +2062,7 @@ def make_common_mask_reformat_subgraph(self, basename, root_input, unsqueeze_for
 
         return expand_name
 
+
     def make_attention_mask_reformatting_for_gqa(self):
         # Make nodes for the attention mask subgraph that calculates
         # attributes about the 2D attention mask to use in GroupQueryAttention
@@ -2109,12 +2188,10 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
     def make_attention(self, layer_id, attention, root_input, **kwargs):
         super().make_attention(layer_id, attention, root_input, position_ids=self.position_ids_name, **kwargs)
 
-
 class QwenModel(MistralModel):
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
 
-
 class PhiModel(Model):
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
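
For reference, the node chain added in make_input_ids_subgraph (ConstantOfShape into Trilu(upper=1), then Sub, Cast, and a Where that selects np.finfo(io_dtype).min, with offsets derived from the past-sequence length and the hardcoded 2047 constant) amounts to a sliding-window causal attention mask: keys further than the window behind a query position are pushed to the dtype's minimum value before the mask is tiled to its final 4D shape. Below is a minimal NumPy sketch of that behavior, assuming 2047 is the sliding-window size (it matches Phi-3-mini's sliding_window config value, which is presumably what issue #552 concerns); the function and parameter names are illustrative and do not appear in builder.py:

    import numpy as np

    def sliding_window_causal_mask(seq_len, total_len, window=2047, dtype=np.float32):
        # Query positions are the last `seq_len` of the `total_len` tokens
        # processed so far; key positions cover everything seen so far.
        rows = np.arange(total_len - seq_len, total_len)[:, None]
        cols = np.arange(total_len)[None, :]
        causal = cols <= rows              # a query attends to keys at or before it...
        in_window = cols > rows - window   # ...but only within the last `window` keys
        # 0 where attention is allowed, dtype-min (effectively -inf) where blocked,
        # mirroring the Where nodes in the graph that select np.finfo(io_dtype).min.
        return np.where(causal & in_window, 0.0, np.finfo(dtype).min).astype(dtype)

Under this reading, sliding_window_causal_mask(1, 4096) leaves exactly the 2047 most recent key positions visible to a newly generated token; older keys receive the minimum value and are zeroed out by the softmax. Note that the patch hardcodes 2047 as a graph constant rather than reading the model's sliding_window value from config, so the subgraph as written only matches models with that exact window size.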