From 218b6b0a734f6233845313d2c61531b7558e0b5f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Thu, 25 Apr 2024 00:46:48 +0200
Subject: [PATCH] Fix missing argument when calling _get_quantize_input_nodes
 (#20245)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### Description

The quantizer calls `_get_quantize_input_nodes` without the `initial_type`
argument, so dynamic quantization fails with a `TypeError` (missing required
positional argument `initial_type`) as soon as
`_get_dynamic_input_quantization_params` is reached. This change threads
`initial_type` through the call chain, resolving it from `value_infos`
(shape inference) with a fallback to `self.tensor_names`.
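Schematically, the new type resolution in `ONNXQuantizer.__quantize_inputs`
behaves like the sketch below (illustrative only: `resolve_initial_type` is a
made-up helper name, while `value_infos` and `tensor_names` mirror the
quantizer's attributes):

```python
def resolve_initial_type(value_infos: dict, tensor_names: dict, input_name: str) -> int:
    """Return the ONNX elem_type (e.g. TensorProto.FLOAT) to quantize from."""
    if input_name in value_infos:
        # Prefer the element type recorded by shape inference.
        value_info = value_infos[input_name]
        assert value_info.type.HasField("tensor_type"), f"{input_name!r} is not a tensor"
        return value_info.type.tensor_type.elem_type
    # Shape inference failed: fall back to the element types collected
    # while walking the graph.
    return tensor_names[input_name]
```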
""" input_name = node.input[input_index] @@ -606,12 +608,16 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N ql_node_name, ) else: + assert initial_type is not None, ( + f"Cannot quantize input without knowing the initial type, " + f"input_name={input_name!r}, input_index={input_index}, qType={qType}, node={node}" + ) ( scale_name, zp_name, scale_shape, zp_shape, - ) = self._get_dynamic_input_quantization_params(input_name, nodes, qType) + ) = self._get_dynamic_input_quantization_params(input_name, nodes, qType, initial_type=initial_type) qlinear_node = onnx.helper.make_node( "QuantizeLinear", [input_name, scale_name, zp_name], @@ -794,7 +800,23 @@ def __quantize_inputs( node_input + "_QuantizeLinear", self.new_nodes, self.model.graph() ) if qlinear_node is None: - quantize_input_nodes = self._get_quantize_input_nodes(node, input_index, self.activation_qType) + input_name = node.input[input_index] + if input_name in self.value_infos: + value_info = self.value_infos[input_name] + assert value_info.HasField("type"), f"value_info={value_info} has no type." + assert value_info.type.HasField("tensor_type"), f"value_info={value_info} is not a tensor." + initial_type = value_info.type.tensor_type.elem_type + else: + # Shape inference failed. Fallback to self.tensor_names. + assert input_name in self.tensor_names, ( + f"shape inference failed for {input_name!r} and " + f"attribute 'tensor_names' does not have any value for " + f"this tensor." + ) + initial_type = self.tensor_names[input_name] + quantize_input_nodes = self._get_quantize_input_nodes( + node, input_index, self.activation_qType, initial_type=initial_type + ) if quantize_input_nodes is None: return (None, None, None, None) if from_subgraph: diff --git a/onnxruntime/python/tools/quantization/operators/pad.py b/onnxruntime/python/tools/quantization/operators/pad.py index 25818de1b76bd..5f3c1231e62d6 100644 --- a/onnxruntime/python/tools/quantization/operators/pad.py +++ b/onnxruntime/python/tools/quantization/operators/pad.py @@ -68,6 +68,7 @@ def quantize(self): self.quantizer.activation_qType, quantized_input_value.scale_name, quantized_input_value.zp_name, + initial_type=scale_tensor.data_type, ) self.quantizer.new_nodes.extend(pad_value_qnodes) node.input[2] = pad_value_qnodes[0].output[0] diff --git a/onnxruntime/test/python/quantization/test_op_gemm.py b/onnxruntime/test/python/quantization/test_op_gemm.py index 843b34a6398b3..96a7ab6b9d9f3 100644 --- a/onnxruntime/test/python/quantization/test_op_gemm.py +++ b/onnxruntime/test/python/quantization/test_op_gemm.py @@ -784,7 +784,52 @@ def test_qgemm_ref_uint8_specific_example(self): got = ref.run(None, feeds)[0] assert_allclose(expected, got) + def test_dynamic_quantization(self): + # dummy_model.onnx from Olive + model = helper.make_model( + helper.make_graph( + [ + helper.make_node( + "Gemm", ["input", "fc1.weight", "fc1.bias"], ["gemm0"], alpha=1.0, beta=1.0, transB=1 + ), + helper.make_node("Relu", ["gemm0"], ["output"]), + ], + "g", + [helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 1])], + [helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 10])], + [ + onnx.numpy_helper.from_array(np.random.randn(10, 1).astype(np.float32), name="fc1.weight"), + onnx.numpy_helper.from_array(np.random.randn(10).astype(np.float32), name="fc1.bias"), + ], + ), + opset_imports=[helper.make_opsetid("", 18)], + ir_version=9, + ) + onnx.checker.check_model(model) + run_config = { + "weight_type": QuantType.QInt8, + 
"op_types_to_quantize": None, + "nodes_to_quantize": None, + "nodes_to_exclude": None, + "per_channel": False, + "reduce_range": False, + "extra_options": { + "extra.Sigmoid.nnapi": False, + "ActivationSymmetric": False, + "WeightSymmetric": True, + "EnableSubgraph": False, + "ForceQuantizeNoInputCheck": False, + "MatMulConstBOnly": True, + }, + } + model_path = "test_dynamic_quantization.onnx" + with open(model_path, "wb") as f: + f.write(model.SerializeToString()) + qpath = "test_dynamic_quantization.quantized.onnx" + quantize_dynamic(model_input=model_path, model_output=qpath, use_external_data_format=True, **run_config) + onx = onnx.load(qpath) + self.assertIn("DynamicQuantizeLinear", set(n.op_type for n in onx.graph.node)) + if __name__ == "__main__": - TestOpGemm().test_quantize_gemm_e4m3fn_p3() unittest.main(verbosity=2)