From 283a3db675ae0b668e42c524eaf76ac93409e8b3 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sat, 4 Mar 2023 18:33:48 -0500
Subject: [PATCH] [Unity][Op] Group normalization (#14194)

* [TOPI] Group normalization

As more and more ML models nowadays contain the group normalization
computation, we find it beneficial to introduce this op at the TOPI level.
It will enable us to optimize the group normalization operation as a whole
in a more convenient way.

This PR introduces the group normalization op to TOPI. The group norm
operation was introduced in https://arxiv.org/abs/1803.08494. The
implementation uses tuple reduction, the same as the implementation of
layer norm. Implemented with tuple reduction, the corresponding generated
TIR function can be optimized by cross-thread reduction or rfactor through
MetaSchedule.

Prior to this PR, the group normalization operations in frontend models
were translated to a series of finer-grained operations, which made it
inconvenient to optimize the group norm op as a whole. With the TOPI
implementation of group norm introduced by #14193, we can now use it to
legalize the high-level group norm op and optimize it with cross-thread
reduction or rfactor via MetaSchedule.

Co-authored-by: Bohan Hou
---
 include/tvm/relax/attrs/nn.h                  |  21 ++
 .../tvm/relax/frontend/torch/fx_translator.py |  54 ++--
 python/tvm/relax/op/nn/nn.py                  |  58 +++++
 python/tvm/relax/transform/legalize_ops/nn.py |  14 ++
 src/relax/op/nn/nn.cc                         |  83 ++++++
 src/relax/op/nn/nn.h                          |   4 +
 tests/python/relax/test_ast_printer.py        |   4 +-
 tests/python/relax/test_frontend_from_fx.py   |  32 +--
 tests/python/relax/test_op_nn.py              | 238 ++++++++++++++++++
 .../relax/test_transform_legalize_ops_nn.py   | 162 ++++++++++++
 .../relax/test_tvmscript_parser_op_nn.py      |  25 ++
 11 files changed, 638 insertions(+), 57 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 694a51070683..61b1622a6082 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -174,6 +174,27 @@ struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
   }
 };  // struct LayerNormAttrs
 
+/*! \brief Attributes used in group_norm operator */
+struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
+  int num_groups;
+  int channel_axis;
+  Array<Integer> axes;
+  double epsilon;
+  bool center;
+  bool scale;
+
+  TVM_DECLARE_ATTRS(GroupNormAttrs, "relax.attrs.GroupNormAttrs") {
+    TVM_ATTR_FIELD(num_groups).describe("The number of groups to separate the channels into.");
+    TVM_ATTR_FIELD(channel_axis).describe("The axis that represents the channel.");
+    TVM_ATTR_FIELD(axes).describe(
+        "The axes that along which the normalization is applied (excluding the channel axis).");
+    TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero");
+    TVM_ATTR_FIELD(center).describe(
+        "Indicating if the beta offset will be added to the normalized tensor.");
+    TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied.");
+  }
+};  // struct GroupNormAttrs
+
 /*!
\brief Attributes used in dropout operator */ struct DropoutAttrs : public tvm::AttrsNode { double rate; diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index e80f73096c59..24fcf0caca64 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -465,44 +465,30 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var: ) def _group_norm(self, node: fx.node.Node) -> relax.Var: - # torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, - # affine=True, device=None, dtype=None) + import torch # type: ignore + x = self.env[node.args[0]] module = self.named_modules[node.target] - num_groups = module.num_groups - num_channels = module.num_channels - eps = module.eps - affine = module.affine - shape = self.shape_of(x) - assert len(shape) == 4 - N, C, H, W = shape[0], shape[1], shape[2], shape[3] - assert C == num_channels - assert C % num_groups == 0 - grouped_x = self.block_builder.emit( - relax.op.reshape(x, [N, num_groups, C // num_groups, H, W]) - ) - mean_x = self.block_builder.emit(relax.op.mean(grouped_x, [2, 3, 4], keepdims=True)) - sub_x = self.block_builder.emit(relax.op.subtract(grouped_x, mean_x)) - square_x = self.block_builder.emit(relax.op.multiply(sub_x, sub_x)) - sum_square_x = self.block_builder.emit(relax.op.sum(square_x, [2, 3, 4], keepdims=True)) - var_x = self._call_binary_op(relax.op.divide, sum_square_x, (C // num_groups * H * W).value) - var_x_eps = self._call_binary_op(relax.op.add, var_x, eps) - std_x = self.block_builder.emit(relax.op.sqrt(var_x_eps)) - norm_x = self.block_builder.emit(relax.op.divide(sub_x, std_x)) - - if affine: - weight = self.params[module.weight] - bias = self.params[module.bias] - weight_reshape = self.block_builder.emit( - relax.op.reshape(weight, (1, num_groups, C // num_groups, 1, 1)) - ) - bias_reshape = self.block_builder.emit( - relax.op.reshape(bias, (1, num_groups, C // num_groups, 1, 1)) + if module.affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type) + beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type) + + dim = len(self.shape_of(x)) + return self.block_builder.emit( + relax.op.nn.group_norm( + x, + gamma, + beta, + num_groups=module.num_groups, + channel_axis=1, + axes=list(range(2, dim)), + epsilon=module.eps, ) - norm_x = self.block_builder.emit(relax.op.multiply(norm_x, weight_reshape)) - norm_x = self.block_builder.emit(relax.op.add(norm_x, bias_reshape)) - return self.block_builder.emit(relax.op.reshape(norm_x, (N, C, H, W))) + ) def _embedding(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 2fef37249703..bbb1268f1c96 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -527,6 +527,64 @@ def layer_norm( return _ffi_api.layer_norm(data, gamma, beta, axes, epsilon, center, scale) # type: ignore +def group_norm( + data: Expr, + gamma: Expr, + beta: Expr, + num_groups: int, + channel_axis: int, + axes: Union[int, List[int]], + epsilon: float = 1e-5, + center: bool = True, + scale: bool = True, +) -> Expr: + r""" + Group normalization (Yuxin Wu and et al., 2016). + Applies group normalization to the n-dimensional input array. + This operator takes an n-dimensional input array. First separate the input array + into groups along the channel axis. 
Then apply layer normalization to each group. + + Parameters + ---------- + data : relax.Expr + Input to which group_norm will be applied. + + gamma : relax.Expr + The gamma scale factor. + + beta : relax.Expr + The beta offset factor. + + num_groups : int + Number of groups to separate the channels into. + + channel_axis : int + The index of the channel axis in the input data. + + axes : Union[int, List[int]] + The axes that along which the normalization is applied (excluding the group axis) + + epsilon : float + Small float added to variance to avoid dividing by zero. + + center : bool + Indicating if the beta offset will be added to the normalized tensor. + + scale : bool + Indicating if the gamma scale will be multiplied. + + Returns + ------- + result : relax.Expr + The computed result. + """ + if isinstance(axes, int): + axes = [axes] + return _ffi_api.group_norm( # type: ignore + data, gamma, beta, num_groups, channel_axis, axes, epsilon, center, scale + ) + + def dropout(data: Expr, rate: float = 0.5) -> Expr: """Applies the dropout operation to the input tensor. diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 70bb2513dda3..a61e0cd09ee1 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -196,6 +196,20 @@ def _nn_layer_norm(bb: BlockBuilder, call: Call) -> Expr: ) +@register_legalize("relax.nn.group_norm") +def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.nn.group_norm, + call.args[0], + call.args[1], + call.args[2], + call.attrs.num_groups, + call.attrs.channel_axis, + call.attrs.axes, + call.attrs.epsilon, + ) + + @register_legalize("relax.nn.dropout") def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr: logging.info("Dropout is handled by frontend translator at this moment and is not legalized.") diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index e63b3306f25d..430d2268cec3 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -233,6 +233,89 @@ TVM_REGISTER_OP("relax.nn.layer_norm") .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr("FInferStructInfo", InferStructInfoLayerNorm); +/* relax.nn.group_norm */ +TVM_REGISTER_NODE_TYPE(GroupNormAttrs); + +Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis, + Array axes, double epsilon, bool center, bool scale) { + ObjectPtr attrs = make_object(); + attrs->num_groups = num_groups; + attrs->channel_axis = channel_axis; + attrs->axes = std::move(axes); + attrs->epsilon = epsilon; + attrs->center = center; + attrs->scale = scale; + + static const Op& op = Op::Get("relax.nn.group_norm"); + return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.group_norm").set_body_typed(group_norm); + +StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { + Op op = Downcast(call->op); + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + + TensorStructInfo data_sinfo = input_sinfo[0]; + int channel_axis = -1; + if (!data_sinfo->IsUnknownNdim()) { + channel_axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->channel_axis); + std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes); + // channel_axis must be in axes. 
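+  // (The check below rejects the call when channel_axis appears among the normalization
+  //  axes: the channel dimension is normalized group-wise via num_groups and therefore
+  //  must not be listed in axes, matching the error message that follows.)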
+ if (std::find(axes.begin(), axes.end(), channel_axis) != axes.end()) { + ctx->ReportFatal(Diagnostic::Error(call) + << op + << " expects that channel_axis must not be in axes, but got channel_axis: " + << channel_axis << ", axes: " << attrs->axes); + } + } + if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that data must be float, but got " << data_sinfo->dtype); + } + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape != nullptr && channel_axis != -1 && + analyzer->CanProve(floormod(data_shape->values[channel_axis], attrs->num_groups) != 0)) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that the size of channel_axis must be divisible by " + << attrs->num_groups << ", but got " << data_shape->values[channel_axis]); + } + for (int i = 1; i < static_cast(op->arguments.size()); ++i) { + if (input_sinfo[i]->dtype != data_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that all inputs must have the same dtype, but got " + << input_sinfo[i]->dtype << " and " << data_sinfo->dtype); + } else if (input_sinfo[i]->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that all inputs must have ndim=1, but got " + << input_sinfo[i]->ndim); + } else if (channel_axis != -1) { + const auto* shape = input_sinfo[i]->shape.as(); + if (shape != nullptr && data_shape != nullptr) { + PrimExpr channel_size = data_shape->values[channel_axis]; + PrimExpr input_size = shape->values[0]; + if (analyzer->CanProve(channel_size != input_size)) { + ctx->ReportFatal(Diagnostic::Error(call) + << op << " expects that the size of input " << i + << " must be equal to the size of channel_axis, but got " << input_size + << " and " << channel_size); + } + } + } + } + return data_sinfo; +} + +TVM_REGISTER_OP("relax.nn.group_norm") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_attr("FInferStructInfo", InferStructInfoGroupNorm); + /* relax.nn.dropout */ TVM_REGISTER_NODE_TYPE(DropoutAttrs); diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index f13b930fc246..f578f89346f7 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -68,6 +68,10 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double epsilon, bool center, bool scale); +/*! \brief Compute group normalization. */ +Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis, + Array axes, double epsilon, bool center, bool scale); + /*! * \brief Applies the dropout operation to the input tensor. * \param data The input data to the operator. 
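For reference while reading the patch, below is a minimal NumPy sketch of the computation relax.op.nn.group_norm is expected to perform for the common NCHW layout (channel_axis=1, axes=[2, 3], as produced by the PyTorch frontend change above). This is an illustrative sketch by analogy with the legalization tests further down; the helper name group_norm_ref and the use of NumPy are assumptions of this note, not part of the patch.

import numpy as np

def group_norm_ref(x, gamma, beta, num_groups, epsilon=1e-5, center=True, scale=True):
    # NCHW reference: split the channel axis into `num_groups` groups, normalize each
    # group over its (channels-in-group, H, W) elements, then apply per-channel
    # gamma/beta. This mirrors the reshape -> tuple-reduction -> reshape structure of
    # the TOPI lowering exercised in the legalization tests below.
    n, c, h, w = x.shape
    assert c % num_groups == 0
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = grouped.mean(axis=(2, 3, 4), keepdims=True)
    var = grouped.var(axis=(2, 3, 4), keepdims=True)  # population variance, i.e. E[x^2] - E[x]^2
    out = ((grouped - mean) / np.sqrt(var + epsilon)).reshape(n, c, h, w)
    if scale:
        out = out * gamma.reshape(1, c, 1, 1)
    if center:
        out = out + beta.reshape(1, c, 1, 1)
    return out

# Shapes matching the tests below: x of shape (2, 4, 4, 5), gamma and beta of shape (4,),
# num_groups=2.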
diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index ba3c930a456f..c21dbd2bd1f5 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -362,7 +362,7 @@ def f( y: R.Tensor(("m",), "float32"), r: R.Tensor(dtype="int64"), ) -> R.Object: - m = T.var("int64") + m = T.int64() z: R.Tensor((32, m), "float32") = R.multiply(x, y) w: R.Tensor = R.multiply(z, z) q: R.Tensor(ndim=2) = R.add(w, w) @@ -431,7 +431,7 @@ def test_call_tir(): # also from test_parser @R.function def foo(x: R.Tensor(("m", "n"), "float32")): - m, n = T.var("int64"), T.var("int64") + m, n = T.int64(), T.int64() gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) return gv0 diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 137713869e91..73cfacf1e526 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -708,29 +708,19 @@ def main( w1: R.Tensor((3,), dtype="float32"), w2: R.Tensor((3,), dtype="float32"), ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): - # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.reshape( - input_1, (1, 3, 1, 10, 10) - ) - lv1: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.mean( - lv, axis=[2, 3, 4], keepdims=True - ) - lv2: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.subtract(lv, lv1) - lv3: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv2, lv2) - lv4: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sum( - lv3, axis=[2, 3, 4], keepdims=True + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.group_norm( + input_1, + w1, + w2, + num_groups=3, + channel_axis=1, + axes=[2, 3], + epsilon=1.0000000000000001e-05, + center=True, + scale=True, ) - lv5: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.divide(lv4, R.const(100.0)) - lv6: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.add(lv5, R.const(1e-05)) - lv7: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sqrt(lv6) - lv8: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.divide(lv2, lv7) - lv9: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w1, (1, 3, 1, 1, 1)) - lv10: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w2, (1, 3, 1, 1, 1)) - lv11: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv8, lv9) - lv12: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.add(lv11, lv10) - lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = R.reshape(lv12, (1, 3, 10, 10)) - gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv13 + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv R.output(gv) return gv diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index 5294596cee34..51144784638a 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -849,6 +849,244 @@ def test_layer_norm_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.layer_norm(x0, gamma1, beta, axes=[-2, -1])) +def test_group_norm_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor("float32", ndim=1)) + gamma2 = relax.Var("gamma", R.Tensor((4,))) + beta0 = relax.Var("beta", R.Tensor((4,), "float32")) + beta1 = relax.Var("beta", R.Tensor((4,))) + + 
_check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x2, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma1, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x3, gamma2, beta1, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype=""), + ) + + +def test_group_norm_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c0 = tir.Var("c", "int64") + c1 = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((n, a, b, c0), "float32")) + x1 = relax.Var("x", R.Tensor((n, a, b, c1), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + gamma0 = relax.Var("gamma", R.Tensor((a,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((a,), "float32")) + beta = relax.Var("beta", R.Tensor((a,), "float32")) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c0), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c1), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma1, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c0), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x2, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x2, gamma1, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_group_norm_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s1", relax.ShapeStructInfo()) + s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=1)) + s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32")) + beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32")) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma, beta, num_groups=2, channel_axis=-2, axes=[1, 3]), + relax.TensorStructInfo(s0, "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma, beta, num_groups=2, channel_axis=-2, axes=[1, 3]), + relax.TensorStructInfo(s1, "float32"), + ) + + +def test_group_norm_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float16")) + beta0 = relax.Var("beta", R.Tensor((3,), 
"float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float64")) + gamma1 = relax.Var("gamma", R.Tensor((3,), "float64")) + beta1 = relax.Var("beta", R.Tensor((3,), "float64")) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=3, channel_axis=1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float16"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma1, beta1, num_groups=3, channel_axis=1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float64"), + ) + + +def test_group_norm_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + gamma0 = relax.Var("gamma", R.Tensor((4,), "int8")) + beta0 = relax.Var("beta", R.Tensor((4,), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "int32")) + beta1 = relax.Var("beta", R.Tensor((4,), "int32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x1, gamma1, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_axis_out_of_range_and_repetitive(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4,), "float32")) + beta = relax.Var("beta", R.Tensor((4,), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=-2, axes=[3, 4]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=-2, axes=[3, -1]) + ) + + +def test_group_norm_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "int8")) + beta0 = relax.Var("beta", R.Tensor((4,), "float32")) + beta1 = relax.Var("beta", R.Tensor((4,))) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma1, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma0, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma1, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma0, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + c0 = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, c0), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 6), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4, c0), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((4, c0 - 2), "float32")) + + 
with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x1, gamma1, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((4, 5), "float32"))) + beta = relax.Var("beta", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x1, gamma0, beta, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x0, gamma1, beta, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + def test_dropout_infer_struct_info(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 698ad2727456..8fb398f15d2b 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -1452,5 +1452,167 @@ def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_r tvm.ir.assert_structural_equal(mod, Expected) +def test_group_norm(): + # fmt: off + @tvm.script.ir_module + class GroupNorm: + @R.function + def main(x: R.Tensor((2, 4, 4, 5), "float32"), gamma: R.Tensor((4,), "float32"), beta: R.Tensor((4,), "float32")) -> R.Tensor((2, 4, 4, 5), "float32"): + gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3]) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4),), "float32"), rxplaceholder_2: T.Buffer((T.int64(4),), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float32")): + T.func_attr({"tir.noalias": True}) + T_reshape_1 = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5))) + rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2), T.int64(2))) + rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(2), T.int64(2))) + T_reshape_2 = T.alloc_buffer((T.int64(2), T.int64(2))) + T_reshape_3 = T.alloc_buffer((T.int64(2), T.int64(2))) + T_group_norm = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5))) + for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % 
T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)] + for ax0, ax1, k2, k3, k4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red_temp"): + v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]) + T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] * T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 + for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): + with T.block("T_reshape_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) + T.writes(T_reshape_2[v_ax0, v_ax1]) + T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] + for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): + with T.block("T_reshape_2"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) + T.writes(T_reshape_3[v_ax0, v_ax1]) + T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] + for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("T_group_norm"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) + T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] * T.float32(0.025000000000000001) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(4), T.int64(5)): + with T.block("T_reshape_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) 
% T.int64(4), v_ax3 % T.int64(5)] + + @R.function + def main(x: R.Tensor((2, 4, 4, 5), dtype="float32"), gamma: R.Tensor((4,), dtype="float32"), beta: R.Tensor((4,), dtype="float32")) -> R.Tensor((2, 4, 4, 5), dtype="float32"): + gv = R.call_tir(group_norm, (x, gamma, beta), out_sinfo=R.Tensor((2, 4, 4, 5), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(GroupNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_group_norm_symbolic(): + # fmt: off + @tvm.script.ir_module + class GroupNorm: + @R.function + def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), "float32"), gamma: R.Tensor(("4 * c",), "float32"), beta: R.Tensor(("4 * c",), "float32")) -> R.Tensor(("n", "4 * c", "h", "w"), "float32"): + n = T.int64() + c = T.int64() + h = T.int64() + w = T.int64() + gv: R.Tensor((n, 4 * c, h, w), "float32") = R.nn.group_norm(x, gamma, beta, num_groups=4, channel_axis=1, axes=[2, 3]) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def group_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_reshape: T.handle, c: T.int64): + T.func_attr({"tir.noalias": True}) + n = T.int64() + h = T.int64() + w = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (n, T.int64(4) * c, h, w)) + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4) * c,)) + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4) * c,)) + T_reshape = T.match_buffer(var_T_reshape, (n, T.int64(4) * c, h, w)) + # with T.block("root"): + T_reshape_1 = T.alloc_buffer((n, T.int64(4), T.int64(4) * c // T.int64(4), h, w)) + rxplaceholder_red_temp_v0 = T.alloc_buffer((n, T.int64(4))) + rxplaceholder_red_temp_v1 = T.alloc_buffer((n, T.int64(4))) + T_reshape_2 = T.alloc_buffer((T.int64(4), T.int64(4) * c // T.int64(4))) + T_reshape_3 = T.alloc_buffer((T.int64(4), T.int64(4) * c // T.int64(4))) + T_group_norm = T.alloc_buffer((n, T.int64(4), T.int64(4) * c // T.int64(4), h, w)) + for ax0, ax1, ax2, ax3, ax4 in T.grid(n, T.int64(4), c, h, w): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(rxplaceholder[((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h // (c * T.int64(4)) % n, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h % (c * T.int64(4)), ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w % h, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) % w]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h // (c * T.int64(4)) % n, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h % (c * T.int64(4)), ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w % h, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) % w] + for ax0, ax1, k2, k3, k4 in T.grid(n, T.int64(4), c, h, w): + with T.block("rxplaceholder_red_temp"): + v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]) + T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + 
v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] * T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 + for ax0, ax1 in T.grid(T.int64(4), c): + with T.block("T_reshape_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_1[(v_ax0 * c + v_ax1) % (c * T.int64(4))]) + T.writes(T_reshape_2[v_ax0, v_ax1]) + T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * c + v_ax1) % (c * T.int64(4))] + for ax0, ax1 in T.grid(T.int64(4), c): + with T.block("T_reshape_2"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_2[(v_ax0 * c + v_ax1) % (c * T.int64(4))]) + T.writes(T_reshape_3[v_ax0, v_ax1]) + T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * c + v_ax1) % (c * T.int64(4))] + for ax0, ax1, ax2, ax3, ax4 in T.grid(n, T.int64(4), c, h, w): + with T.block("T_group_norm"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) + T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w))) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w)) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w)) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w))) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] + for ax0, ax1, ax2, ax3 in T.grid(n, c * T.int64(4), h, w): + with T.block("T_reshape_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_group_norm[(((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) % w]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) % w] + + @R.function + def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), dtype="float32"), gamma: R.Tensor(("4 * c",), dtype="float32"), beta: R.Tensor(("4 * c",), dtype="float32")) -> R.Tensor(("n", "4 * c", "h", "w"), dtype="float32"): + n = T.int64() + c = T.int64() 
+ h = T.int64() + w = T.int64() + gv = R.call_tir(group_norm, (x, gamma, beta), out_sinfo=R.Tensor((n, 4 * c, h, w), dtype="float32"), tir_vars=R.shape([c])) + return gv + # fmt: on + + mod = LegalizeOps()(GroupNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py index 781700af7b82..c2bfa5b7a9e9 100644 --- a/tests/python/relax/test_tvmscript_parser_op_nn.py +++ b/tests/python/relax/test_tvmscript_parser_op_nn.py @@ -185,6 +185,31 @@ def foo( _check(foo, bb.get()["foo"]) +def test_group_norm(): + @R.function + def foo( + x: R.Tensor((2, 4, 4, 5), "float32"), + gamma: R.Tensor((4,), "float32"), + beta: R.Tensor((4,), "float32"), + ) -> R.Tensor((2, 4, 4, 5), "float32"): + gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.group_norm( + x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3] + ) + return gv + + x = relax.Var("x", R.Tensor((2, 4, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4,), "float32")) + beta = relax.Var("beta", R.Tensor((4,), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, gamma, beta]): + gv = bb.emit( + relax.op.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3]) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + def test_dropout(): @R.function def foo(