diff --git a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
index 925930c870188..2448bfc766306 100644
--- a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
+++ b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
@@ -67,8 +67,7 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
     Expr final_call = post;
     call = post.as<CallNode>();
 
-    // Create a new variable argument that is of the same shape as the neighbouring argument
-    // in the binary op. This needs to be done only when one of the arguments is a scalar.
+    // Substitute scalar variable with a tensor variable.
     if (call->op.as<OpNode>()) {
      final_call = ReplaceScalarWithTensorVariable(GetRef<Call>(call));
     }
@@ -86,63 +85,78 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
       final_call = Call(global_var, call->args);
     }
 
-    // Substitute scalar constant with a tensor constant in the call to composite function
-    // comprising partitioned binary ops. Shape of the new constant should be same as its
-    // neighbouring tensor's shape.
+    // Substitute scalar constant with a tensor constant in the call to the composite function.
     if (auto* func_node = call->op.as<FunctionNode>()) {
       Function func = GetRef<Function>(func_node);
+      final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
+    }
+
+    return final_call;
+  }
+
+  // Checks if expr can undergo scalar to tensor replacement.
+  bool WorthyOfScalarToTensorReplacement(const Expr& expr) {
+    if (const CallNode* call = expr.as<CallNode>()) {
+      if (const OpNode* opnode = call->op.as<OpNode>()) {
+        if (opnode->name == "qnn.add" || opnode->name == "qnn.mul") {
+          return true;
+        }
+      }
+    }
+    if (const FunctionNode* func = expr.as<FunctionNode>()) {
       auto func_name = func->GetAttr<String>(attr::kComposite);
       if (func_name.defined() &&
           (func_name == "cmsis-nn.qnn_add" || func_name == "cmsis-nn.qnn_mul")) {
-        final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
+        return true;
       }
     }
-
-    return final_call;
+    return false;
   }
 
-  // Replaces scalar variable with a tensor variable with same shape as that of the neibouring
-  // operand tensor in a binary op
+  // Replaces a scalar variable with a tensor variable of the same shape as the neighbouring
+  // operand tensor in a binary op (add or multiply, via the CMSIS-NN path). This applies only
+  // to the 1st and 2nd arguments of the ops.
   Call ReplaceScalarWithTensorVariable(Call call) {
-    const OpNode* opnode = call->op.as<OpNode>();
-    if (opnode == nullptr) {
+    if (!WorthyOfScalarToTensorReplacement(call)) {
       return call;
     }
-    String op_name = opnode->name;
-    Array<Expr> new_args;
-    for (uint32_t i = 0; i < call->args.size(); ++i) {
-      Expr arg = call->args[i];
-      new_args.push_back(arg);
-      if (!arg->checked_type_.defined()) {
+    Array<Expr> new_args(call->args);
+    for (uint32_t i = 0; i < 2; ++i) {
+      Expr scalar_arg = call->args[i];
+      if (!scalar_arg->IsInstance<VarNode>() || !scalar_arg->checked_type_.defined() ||
+          !scalar_arg->checked_type_->IsInstance<TensorTypeNode>()) {
         continue;
       }
-      auto* arg_type = arg->type_as<TensorTypeNode>();
-      if (arg_type->shape.size() != 0 || arg.as<ConstantNode>()) {
+      Array<PrimExpr> scalar_shape = scalar_arg->type_as<TensorTypeNode>()->shape;
+      if (scalar_shape.size() != 0) {
         continue;
       }
-      String arg_name = arg.as<VarNode>()->name_hint();
       int tensor_arg_id = (i + 1) % 2;
       Expr tensor_arg = call->args[tensor_arg_id];
       if (!tensor_arg->checked_type_.defined()) {
         continue;
       }
-      TensorType tensor_type = GetRef<TensorType>(tensor_arg->type_as<TensorTypeNode>());
-      new_args.Set(i, Var(arg_name, tensor_type));
+      String arg_name = scalar_arg.as<VarNode>()->name_hint();
+      new_args.Set(i, Var(arg_name, tensor_arg->checked_type_));
     }
     return Call(call->op, new_args, call->attrs, {});
   }
 
-  // Makes tensor constant of same shape as tensor_arg with values from scalar_arg
+  // Replaces a scalar constant with a tensor constant of the same shape as the neighbouring
+  // operand tensor in a binary op (add or multiply, via the CMSIS-NN path). This applies only
+  // to the 1st and 2nd arguments of the ops.
   Call ReplaceScalarWithTensorConstant(Call call, Function func) {
-    Array<Expr> new_args;
-    for (uint32_t i = 0; i < call->args.size(); ++i) {
-      new_args.push_back(call->args[i]);
+    if (!WorthyOfScalarToTensorReplacement(func)) {
+      return call;
+    }
+    Array<Expr> new_args(call->args);
+    for (uint32_t i = 0; i < 2; ++i) {
       Expr scalar_arg = call->args[i];
       if (!scalar_arg->checked_type_.defined()) {
         continue;
       }
       Array<PrimExpr> scalar_shape = scalar_arg->type_as<TensorTypeNode>()->shape;
-      if (scalar_shape.size() != 0 || scalar_arg.as<ConstantNode>() == nullptr) {
+      if (scalar_shape.size() != 0 || !scalar_arg->IsInstance<ConstantNode>()) {
         continue;
       }
       int tensor_arg_id = (i + 1) % 2;
diff --git a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
index 223a2b65e9346..9c665053e2cf4 100644
--- a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
+++ b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
@@ -26,6 +26,34 @@
 tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)
 
 
+def generate_variable(name, shape, dtype="int8"):
+    return relay.var(name, shape=shape, dtype=dtype)
+
+
+def make_binary_op(
+    op,
+    input_0,
+    input_1,
+    input_0_scale,
+    input_0_zero_point,
+    input_1_scale,
+    input_1_zero_point,
+    out_scale=1.0 / 256,
+    out_zero_point=-128,
+):
+    """Create a QNN binary op call with the given quantization parameters."""
+    return op(
+        input_0,
+        input_1,
+        relay.const(input_0_scale, "float32"),
+        relay.const(input_0_zero_point, "int32"),
+        relay.const(input_1_scale, "float32"),
+        relay.const(input_1_zero_point, "int32"),
+        relay.const(out_scale, "float32"),
+        relay.const(out_zero_point, "int32"),
+    )
+
+
 class CheckFunctionsForConstants(tvm.relay.ExprVisitor):
     def __init__(self):
         super().__init__()
@@ -55,22 +83,33 @@ def set_composite_func_attr(func, name):
 
 @tvm.testing.requires_cmsisnn
 def test_single_scalar_position_0():
-    x0 = relay.var("x0", shape=None)
-    x1 = relay.var("x1", shape=(8, 8))
-    z1 = x0 + x1
-    lf = relay.Function([x0, x1], z1, relay.TensorType((8, 8), "float32"))
+    dtype = "int8"
+    shape = (8, 8)
+    x0 = generate_variable("x0", None, dtype)
+    x1 = generate_variable("x1", shape, dtype)
+    z1 = make_binary_op(
+        relay.qnn.op.add,
+        x0,
+        x1,
+        input_0_scale=0.0128,
+        input_0_zero_point=32,
+        input_1_scale=0.256,
+        input_1_zero_point=-64,
+    )
+
+    lf = relay.Function([x0, x1], z1, relay.TensorType(shape, dtype))
     lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add")
 
-    y0 = relay.expr.const(3, "float32")
-    y1 = relay.var("y1", shape=(8, 8))
+    y0 = relay.expr.const(3, dtype)
+    y1 = relay.var("y1", shape=shape, dtype=dtype)
     c0 = relay.Call(lf, [y0, y1])
-    ef = relay.Function([y1], c0, relay.TensorType((8, 8), "float32"))
+    ef = relay.Function([y1], c0, relay.TensorType(shape, dtype))
 
-    x = relay.var("x", shape=(8, 8))
+    x = relay.var("x", shape=shape, dtype=dtype)
     ev = relay.GlobalVar("external_function")
     ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
     c = relay.Call(ev, [x])
-    mf = relay.Function([x], c, relay.TensorType((8, 8), "float32"))
+    mf = relay.Function([x], c, relay.TensorType(shape, dtype))
     mv = relay.GlobalVar("main")
 
     mod = tvm.IRModule()
@@ -79,6 +118,7 @@ def test_single_scalar_position_0():
 
     mod = relay.transform.InferType()(mod)
     mod = ScalarToTensorConstants()(mod)
+    mod = relay.transform.InferType()(mod)
     check_for_constants = CheckFunctionsForConstants()
     check_for_constants.visit_call(mod[ev].body)
     assert (
@@ -88,22 +128,33 @@
 
 @tvm.testing.requires_cmsisnn
 def test_single_scalar_position_1():
-    x0 = relay.var("x0", shape=(8, 8))
-    x1 = relay.var("x1", shape=None)
-    z1 = x0 + x1
-    lf = relay.Function([x0, x1], z1, relay.TensorType((8, 8), "float32"))
+    dtype = "int8"
+    shape = (8, 8)
+    x0 = generate_variable("x0", shape, dtype)
+    x1 = generate_variable("x1", None, dtype)
+    z1 = make_binary_op(
+        relay.qnn.op.add,
+        x0,
+        x1,
+        input_0_scale=0.0128,
+        input_0_zero_point=32,
+        input_1_scale=0.256,
+        input_1_zero_point=-64,
+    )
+
+    lf = relay.Function([x0, x1], z1, relay.TensorType(shape, dtype))
     lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add")
 
-    y0 = relay.var("y0", shape=(8, 8))
-    y1 = relay.expr.const(3, "float32")
+    y0 = relay.var("y0", shape=shape, dtype=dtype)
+    y1 = relay.expr.const(3, dtype)
     c0 = relay.Call(lf, [y0, y1])
-    ef = relay.Function([y0], c0, relay.TensorType((8, 8), "float32"))
+    ef = relay.Function([y0], c0, relay.TensorType(shape, dtype))
 
-    x = relay.var("x", shape=(8, 8))
+    x = relay.var("x", shape=shape, dtype=dtype)
     ev = relay.GlobalVar("external_function")
     ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
     c = relay.Call(ev, [x])
-    mf = relay.Function([x], c, relay.TensorType((8, 8), "float32"))
+    mf = relay.Function([x], c, relay.TensorType(shape, dtype))
     mv = relay.GlobalVar("main")
 
     mod = tvm.IRModule()
@@ -112,6 +163,7 @@ def test_single_scalar_position_1():
 
     mod = relay.transform.InferType()(mod)
     mod = ScalarToTensorConstants()(mod)
+    mod = relay.transform.InferType()(mod)
     check_for_constants = CheckFunctionsForConstants()
     check_for_constants.visit_call(mod[ev].body)
     assert (
@@ -120,22 +172,33 @@
 
 
 @tvm.testing.requires_cmsisnn
-def test_two_scalars():
-    x1 = relay.var("x1", shape=None)
-    x2 = relay.var("x2", shape=None)
-    z1 = x1 + x2
-    lf = relay.Function([x1, x2], z1, relay.TensorType((), "float32"))
+def test_primary_operands_all_scalars():
+    dtype = "int8"
+    shape = None
+    x0 = generate_variable("x0", None, dtype)
+    x1 = generate_variable("x1", None, dtype)
+    z1 = make_binary_op(
+        relay.qnn.op.add,
+        x0,
+        x1,
+        input_0_scale=0.0128,
+        input_0_zero_point=32,
+        input_1_scale=0.256,
+        input_1_zero_point=-64,
+    )
+
+    lf = relay.Function([x0, x1], z1, relay.TensorType(shape, dtype))
     lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add")
 
-    y0 = relay.expr.const(5, "float32")
-    y1 = relay.expr.const(3, "float32")
+    y0 = relay.expr.const(7, dtype)
+    y1 = relay.expr.const(3, dtype)
     c0 = relay.Call(lf, [y0, y1])
-    ef = relay.Function([], c0, relay.TensorType((), "float32"))
+    ef = relay.Function([], c0, relay.TensorType(shape, dtype))
 
     ev = relay.GlobalVar("external_function")
     ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
     c = relay.Call(ev, [])
-    mf = relay.Function([], c, relay.TensorType((), "float32"))
+    mf = relay.Function([], c, relay.TensorType(shape, dtype))
     mv = relay.GlobalVar("main")
 
     mod = tvm.IRModule()
@@ -144,30 +207,39 @@
 
     mod = relay.transform.InferType()(mod)
     mod = ScalarToTensorConstants()(mod)
-    check_for_constants = CheckFunctionsForConstants()
-    check_for_constants.visit_call(mod[ev].body)
-    assert (
-        check_for_constants.num_constants_ == 0
-    ), "Scalar constant wasn't converted into tensor constant"
+    new_mod = relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(mod[ev].body, new_mod[ev].body)
 
 
 @tvm.testing.requires_cmsisnn
-def test_two_tensor_constants():
-    x0 = relay.var("x0", shape=(8, 8))
-    x1 = relay.var("x1", shape=(8, 8))
-    z1 = x0 + x1
-    lf = relay.Function([x0, x1], z1, relay.TensorType((8, 8), "float32"))
+def test_all_primary_operands_tensor_constants():
+    dtype = "int8"
+    shape = (1, 3, 3, 32)
+    x0 = generate_variable("x0", shape, dtype)
+    x1 = generate_variable("x1", shape, dtype)
+    z1 = make_binary_op(
+        relay.qnn.op.add,
+        x0,
+        x1,
+        input_0_scale=0.0128,
+        input_0_zero_point=32,
+        input_1_scale=0.256,
+        input_1_zero_point=-64,
+    )
+
+    lf = relay.Function([x0, x1], z1, relay.TensorType(shape, dtype))
     lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add")
 
-    y0 = relay.const(np.random.uniform(0, 1, (8, 8)).astype("float32"), "float32")
-    y1 = relay.const(np.random.uniform(0, 1, (8, 8)).astype("float32"), "float32")
+    rng = np.random.default_rng(12345)
+    y0 = relay.const(rng.integers(-128, high=127, size=shape, dtype=dtype))
+    y1 = relay.const(rng.integers(-128, high=127, size=shape, dtype=dtype))
     c0 = relay.Call(lf, [y0, y1])
-    ef = relay.Function([], c0, relay.TensorType((8, 8), "float32"))
+    ef = relay.Function([], c0, relay.TensorType(shape, dtype))
 
     ev = relay.GlobalVar("external_function")
     ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
     c = relay.Call(ev, [])
-    mf = relay.Function([], c, relay.TensorType((8, 8), "float32"))
+    mf = relay.Function([], c, relay.TensorType(shape, dtype))
     mv = relay.GlobalVar("main")
 
     mod = tvm.IRModule()
@@ -176,11 +248,8 @@
 
     mod = relay.transform.InferType()(mod)
     mod = ScalarToTensorConstants()(mod)
-    check_for_constants = CheckFunctionsForConstants()
-    check_for_constants.visit_call(mod[ev].body)
-    assert (
-        check_for_constants.num_constants_ == 2
-    ), "Scalar constant wasn't converted into tensor constant"
+    new_mod = relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(mod[ev].body, new_mod[ev].body)
 
 
 @tvm.testing.requires_cmsisnn
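
For reviewers who want to exercise the pass outside pytest, here is a minimal sketch that rebuilds the `test_single_scalar_position_0` graph by hand. It assumes a TVM build with CMSIS-NN enabled (the tests gate on `@tvm.testing.requires_cmsisnn` for the same reason), reuses the FFI import from the top of the test file to expose `ScalarToTensorConstants`, and inlines the function attributes that the test helpers `set_composite_func_attr` / `set_external_func_attr` apply; the attribute names used below ("Composite", "Primitive", "Compiler", "global_symbol") are an assumption based on common TVM BYOC conventions, not something this diff shows.

```python
import tvm
from tvm import relay

# Injects ScalarToTensorConstants() into this module, as the test file does.
tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)

dtype, shape = "int8", (8, 8)

# Composite qnn.add whose first operand is a scalar (rank-0) variable.
x0 = relay.var("x0", shape=None, dtype=dtype)
x1 = relay.var("x1", shape=shape, dtype=dtype)
z1 = relay.qnn.op.add(
    x0,
    x1,
    relay.const(0.0128, "float32"),     # input_0_scale
    relay.const(32, "int32"),           # input_0_zero_point
    relay.const(0.256, "float32"),      # input_1_scale
    relay.const(-64, "int32"),          # input_1_zero_point
    relay.const(1.0 / 256, "float32"),  # output_scale
    relay.const(-128, "int32"),         # output_zero_point
)
lf = relay.Function([x0, x1], z1, relay.TensorType(shape, dtype))
lf = lf.with_attr("Composite", "cmsis-nn.qnn_add")

# External "cmsis-nn" function feeding a scalar constant into the composite.
y1 = relay.var("y1", shape=shape, dtype=dtype)
c0 = relay.Call(lf, [relay.expr.const(3, dtype), y1])
ef = relay.Function([y1], c0, relay.TensorType(shape, dtype))
ef = ef.with_attr("Primitive", tvm.tir.IntImm("int32", 1))  # assumed BYOC attrs
ef = ef.with_attr("Compiler", "cmsis-nn")
ef = ef.with_attr("global_symbol", "external_function")

ev = relay.GlobalVar("external_function")
x = relay.var("x", shape=shape, dtype=dtype)
mf = relay.Function([x], relay.Call(ev, [x]), relay.TensorType(shape, dtype))

mod = tvm.IRModule({ev: ef, relay.GlobalVar("main"): mf})
mod = relay.transform.InferType()(mod)
mod = ScalarToTensorConstants()(mod)
mod = relay.transform.InferType()(mod)  # re-infer types, as the updated tests do
print(mod[ev])  # the scalar const 3 should now be an (8, 8) tensor constant
```

Note that `ScalarToTensorConstants` is resolved at runtime by `_init_api`, so static linters will flag it as undefined; that mirrors how the test file itself uses it.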