diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index 09831929e527..8d714b7269d9 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -223,7 +223,6 @@ def qnn_max_pool2d_pattern():
 
 def check_qnn_max_pool2d(pattern):
     """Check if max pool2d is supported by CMSIS-NN."""
     output = pattern
-    input_op = None
     if str(pattern.op.name) == "clip":
         pooling = pattern.args[0]
diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc
index 1cbe36e30f76..c6ed7af9ff03 100644
--- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc
+++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc
@@ -164,7 +164,18 @@ class ExtractConstantsMutator : public MixedModeMutator {
       function_signature.push_back(arg);
     } else {
       if (arg.as<VarNode>()) {
-        function_signature.push_back(arg);
+        // Only push the argument if it is not already present: multiple consumers of an
+        // input var should appear only once in the function signature.
+        bool found_in_existing_signature = false;
+        for (auto& sign : function_signature) {
+          if (arg.same_as(sign)) {
+            found_in_existing_signature = true;
+            break;
+          }
+        }
+        if (!found_in_existing_signature) {
+          function_signature.push_back(arg);
+        }
       }
       new_args.push_back(arg);
     }
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index 524735caa9d6..5c99061fa854 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -556,7 +556,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
 
     BufferCreator buffer_creator;
     tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
-    tir::Var input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+    tir::Var input_1;
+    if (mul_call->args[0].same_as(mul_call->args[1])) {
+      input_1 = input_0;
+    } else {
+      input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+    }
     tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
 
     tvm::Array<PrimExpr> args = {
@@ -626,7 +631,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
 
     BufferCreator buffer_creator;
     tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
-    tir::Var input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+    tir::Var input_1;
+    if (add_call->args[0].same_as(add_call->args[1])) {
+      input_1 = input_0;
+    } else {
+      input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
+    }
     tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
 
     tvm::Array<PrimExpr> args = {
diff --git a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
index 2448bfc76630..40fd773eb209 100644
--- a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
+++ b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
@@ -179,6 +179,12 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
     auto new_body = VisitExpr(func->body);
     Function new_func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
                                    FreeTypeVars(new_body, mod_), func->attrs);
+
+    // Updating new_func parameters could deduplicate repeated function parameters.
+    // Call arguments need to be aligned with the number of arguments new_func expects.
+    if (new_args[0].same_as(new_args[1])) {
+      new_args.erase(new_args.begin());
+    }
     return Call(new_func, new_args);
   }
 
diff --git a/tests/python/contrib/test_cmsisnn/test_binary_ops.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py
index fec18c197e04..26604da0a64a 100644
--- a/tests/python/contrib/test_cmsisnn/test_binary_ops.py
+++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py
@@ -101,7 +101,7 @@ def make_model(
 def test_op_int8(
     op, relu_type, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point
 ):
-    """Tests QNN Conv2D operator for CMSIS-NN"""
+    """Tests QNN binary operator for CMSIS-NN"""
     interface_api = "c"
     use_unpacked_api = True
     test_runner = AOT_USMP_CORSTONE300_RUNNER
@@ -145,6 +145,65 @@ def test_op_int8(
     )
 
 
+@skip_if_no_reference_system
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
+@pytest.mark.parametrize("relu_type", ["RELU", "NONE"])
+def test_same_input_to_binary_op(op, relu_type):
+    """Tests QNN binary operator for CMSIS-NN where both inputs are the same"""
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_USMP_CORSTONE300_RUNNER
+
+    dtype = "int8"
+    shape = [1, 16, 16, 3]
+    input_ = generate_variable("input")
+    input_scale = 0.256
+    input_zero_point = 33
+
+    model = make_model(
+        op,
+        input_,
+        input_,
+        input_scale,
+        input_zero_point,
+        input_scale,
+        input_zero_point,
+        relu_type,
+    )
+    orig_mod = make_module(model)
+
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    # validate pattern matching
+    assert_partitioned_function(orig_mod, cmsisnn_mod)
+
+    # Check that the number of internal function parameters is 1
+    cmsisnn_global_func = cmsisnn_mod["tvmgen_default_cmsis_nn_main_0"]
+    assert (
+        isinstance(cmsisnn_global_func.body, tvm.relay.expr.Call)
+        and len(cmsisnn_global_func.body.args) == 1
+    ), "Composite function for the binary op should have only 1 parameter."
+
+    # validate the output
+    in_min, in_max = get_range_for_dtype_str(dtype)
+    inputs = {
+        "input": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
+    }
+    output_list = generate_ref_data(orig_mod["main"], inputs)
+    compile_and_run(
+        AOTTestModel(
+            module=cmsisnn_mod,
+            inputs=inputs,
+            outputs=output_list,
+            output_tolerance=1,
+        ),
+        test_runner,
+        interface_api,
+        use_unpacked_api,
+    )
+
+
 def parameterize_for_constant_inputs(test):
     """Generates parameters in such a way so that at least one of the inputs
     is a constant, both can't be variables, both can't be scalars.
diff --git a/tests/python/contrib/test_cmsisnn/test_extract_constants.py b/tests/python/contrib/test_cmsisnn/test_extract_constants.py
index 8831596d40e6..7d3e81a9c79d 100644
--- a/tests/python/contrib/test_cmsisnn/test_extract_constants.py
+++ b/tests/python/contrib/test_cmsisnn/test_extract_constants.py
@@ -116,6 +116,40 @@ def test_nested_function():
     relay.transform.InferType()(mod)
 
 
+@tvm.testing.requires_cmsisnn
+def test_internal_function_with_duplicate_arguments():
+    """Tests the pass ExternConstants when a composite function
+    that repeats one of its arguments to a binary op is present
+    within the global function.
+    """
+    input0 = relay.var("input0", shape=(8, 8))
+    binary_op0 = input0 + input0
+    binary_op1 = binary_op0 * relay.const(5.0, "float32")
+    local_func = relay.Function([input0], binary_op1, relay.TensorType((8, 8), "float32"))
+    local_func = set_composite_func_attr(local_func, "cmsis-nn")
+
+    arg = relay.var("arg", shape=(8, 8))
+    call_local_func = relay.Call(local_func, [arg])
+    extern_func = relay.Function([arg], call_local_func, relay.TensorType((8, 8), "float32"))
+
+    global_arg = relay.var("global_var", shape=(8, 8))
+    global_var = relay.GlobalVar("external_function")
+    extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint)
+    call_extern_func = relay.Call(global_var, [global_arg])
+    main_func = relay.Function([global_arg], call_extern_func, relay.TensorType((8, 8), "float32"))
+    main_var = relay.GlobalVar("main")
+
+    mod = tvm.IRModule()
+    mod[global_var] = extern_func
+    mod[main_var] = main_func
+
+    mod = ExtractConstantsFromPartitionedFunction()(mod)
+    constant_verifier = CheckFunctionsForConstants()
+    constant_verifier.visit_function(mod[global_var])
+    constant_verifier.check_num_constants()
+    relay.transform.InferType()(mod)
+
+
 @tvm.testing.requires_cmsisnn
 def test_multiple_functions():
     """Tests the pass ExternConstants when global function
diff --git a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
index 557a65aeffca..df54f7ce55f1 100644
--- a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
+++ b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
@@ -256,6 +256,47 @@ def test_all_primary_operands_tensor_constants():
     assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body)
 
 
+@tvm.testing.requires_cmsisnn
+def test_duplicate_constant_arguments():
+    """Tests the pass when repeated operands are arguments to the binary op"""
+    dtype = "int8"
+    shape = (1, 3, 3, 32)
+    operand0 = generate_variable("operand0", shape, dtype)
+    operand1 = generate_variable("operand1", shape, dtype)
+    binary_op = make_binary_op(
+        relay.qnn.op.add,
+        operand0,
+        operand0,
+        input_0_scale=0.0128,
+        input_0_zero_point=32,
+        input_1_scale=0.256,
+        input_1_zero_point=-64,
+    )
+
+    local_func = relay.Function([operand0, operand1], binary_op, relay.TensorType(shape, dtype))
+    local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_add")
+
+    rng = np.random.default_rng(12345)
+    arg0 = relay.const(rng.integers(-128, high=127, size=shape, dtype=dtype))
+    call_local_func = relay.Call(local_func, [arg0, arg0])
+    extern_func = relay.Function([], call_local_func, relay.TensorType(shape, dtype))
+
+    global_var = relay.GlobalVar("external_function")
+    extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint)
+    call_extern_func = relay.Call(global_var, [])
+    main_func = relay.Function([], call_extern_func, relay.TensorType(shape, dtype))
+    main_var = relay.GlobalVar("main")
+
+    mod = tvm.IRModule()
+    mod[global_var] = extern_func
+    mod[main_var] = main_func
+
+    mod = relay.transform.InferType()(mod)
+    mod = ScalarToTensorConstants()(mod)
+    new_mod = relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body)
+
+
 @tvm.testing.requires_cmsisnn
 def test_non_cmsisnn_ext_func():
     """Non CMSISNN functions should not be altered."""