diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 694a510706839..61b1622a6082c 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -174,6 +174,27 @@ struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
   }
 };  // struct LayerNormAttrs
 
+/*! \brief Attributes used in group_norm operator */
+struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
+  int num_groups;
+  int channel_axis;
+  Array<Integer> axes;
+  double epsilon;
+  bool center;
+  bool scale;
+
+  TVM_DECLARE_ATTRS(GroupNormAttrs, "relax.attrs.GroupNormAttrs") {
+    TVM_ATTR_FIELD(num_groups).describe("The number of groups to separate the channels into.");
+    TVM_ATTR_FIELD(channel_axis).describe("The axis that represents the channel.");
+    TVM_ATTR_FIELD(axes).describe(
+        "The axes along which the normalization is applied (excluding the channel axis).");
+    TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero");
+    TVM_ATTR_FIELD(center).describe(
+        "Indicating if the beta offset will be added to the normalized tensor.");
+    TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied.");
+  }
+};  // struct GroupNormAttrs
+
 /*! \brief Attributes used in dropout operator */
 struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
   double rate;
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index e80f73096c594..24fcf0caca649 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -465,44 +465,30 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var:
         )
 
     def _group_norm(self, node: fx.node.Node) -> relax.Var:
-        # torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05,
-        # affine=True, device=None, dtype=None)
+        import torch  # type: ignore
+
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
-        num_groups = module.num_groups
-        num_channels = module.num_channels
-        eps = module.eps
-        affine = module.affine
-        shape = self.shape_of(x)
-        assert len(shape) == 4
-        N, C, H, W = shape[0], shape[1], shape[2], shape[3]
-        assert C == num_channels
-        assert C % num_groups == 0
-        grouped_x = self.block_builder.emit(
-            relax.op.reshape(x, [N, num_groups, C // num_groups, H, W])
-        )
-        mean_x = self.block_builder.emit(relax.op.mean(grouped_x, [2, 3, 4], keepdims=True))
-        sub_x = self.block_builder.emit(relax.op.subtract(grouped_x, mean_x))
-        square_x = self.block_builder.emit(relax.op.multiply(sub_x, sub_x))
-        sum_square_x = self.block_builder.emit(relax.op.sum(square_x, [2, 3, 4], keepdims=True))
-        var_x = self._call_binary_op(relax.op.divide, sum_square_x, (C // num_groups * H * W).value)
-        var_x_eps = self._call_binary_op(relax.op.add, var_x, eps)
-        std_x = self.block_builder.emit(relax.op.sqrt(var_x_eps))
-        norm_x = self.block_builder.emit(relax.op.divide(sub_x, std_x))
-
-        if affine:
-            weight = self.params[module.weight]
-            bias = self.params[module.bias]
-            weight_reshape = self.block_builder.emit(
-                relax.op.reshape(weight, (1, num_groups, C // num_groups, 1, 1))
-            )
-            bias_reshape = self.block_builder.emit(
-                relax.op.reshape(bias, (1, num_groups, C // num_groups, 1, 1))
+        if module.affine:
+            gamma = self.params[module.weight]
+            beta = self.params[module.bias]
+        else:
+            gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type)
+            beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type)
+
+        dim = len(self.shape_of(x))
+        return self.block_builder.emit(
+            relax.op.nn.group_norm(
+                x,
+                gamma,
+                beta,
+                num_groups=module.num_groups,
+                channel_axis=1,
+                axes=list(range(2, dim)),
+                epsilon=module.eps,
             )
-            norm_x = self.block_builder.emit(relax.op.multiply(norm_x, weight_reshape))
-            norm_x = self.block_builder.emit(relax.op.add(norm_x, bias_reshape))
-        return self.block_builder.emit(relax.op.reshape(norm_x, (N, C, H, W)))
+        )
 
     def _embedding(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 2fef37249703b..bbb1268f1c963 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -527,6 +527,64 @@ def layer_norm(
     return _ffi_api.layer_norm(data, gamma, beta, axes, epsilon, center, scale)  # type: ignore
 
 
+def group_norm(
+    data: Expr,
+    gamma: Expr,
+    beta: Expr,
+    num_groups: int,
+    channel_axis: int,
+    axes: Union[int, List[int]],
+    epsilon: float = 1e-5,
+    center: bool = True,
+    scale: bool = True,
+) -> Expr:
+    r"""
+    Group normalization (Yuxin Wu et al., 2018).
+    Applies group normalization to the n-dimensional input array.
+    This operator takes an n-dimensional input array, first separates the input array
+    into groups along the channel axis, and then applies layer normalization to each group.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        Input to which group_norm will be applied.
+
+    gamma : relax.Expr
+        The gamma scale factor.
+
+    beta : relax.Expr
+        The beta offset factor.
+
+    num_groups : int
+        Number of groups to separate the channels into.
+
+    channel_axis : int
+        The index of the channel axis in the input data.
+
+    axes : Union[int, List[int]]
+        The axes along which the normalization is applied (excluding the channel axis).
+
+    epsilon : float
+        Small float added to variance to avoid dividing by zero.
+
+    center : bool
+        Indicating if the beta offset will be added to the normalized tensor.
+
+    scale : bool
+        Indicating if the gamma scale will be multiplied.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    if isinstance(axes, int):
+        axes = [axes]
+    return _ffi_api.group_norm(  # type: ignore
+        data, gamma, beta, num_groups, channel_axis, axes, epsilon, center, scale
+    )
+
+
 def dropout(data: Expr, rate: float = 0.5) -> Expr:
     """Applies the dropout operation to the input tensor.
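Note (not part of the patch): the semantics that `relax.op.nn.group_norm` introduces above can be summarized with a small NumPy reference. This is a minimal sketch for the NCHW case the tests below exercise (`channel_axis=1`, `axes=[2, 3]`); `np_group_norm` and its arguments are illustrative names, not anything defined by the patch.

```python
import numpy as np


def np_group_norm(x, gamma, beta, num_groups, epsilon=1e-5):
    """NumPy reference for group_norm on an NCHW tensor with per-channel gamma/beta."""
    n, c, h, w = x.shape
    assert c % num_groups == 0
    # Split the channel axis into groups and normalize over each group plus the spatial axes.
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = grouped.mean(axis=(2, 3, 4), keepdims=True)
    var = grouped.var(axis=(2, 3, 4), keepdims=True)
    normed = ((grouped - mean) / np.sqrt(var + epsilon)).reshape(n, c, h, w)
    # gamma/beta are applied per channel; the `scale`/`center` attrs toggle these two terms.
    return normed * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)


# Mirrors R.nn.group_norm(x, gamma, beta, num_groups=3, channel_axis=1, axes=[2, 3]).
x = np.random.randn(1, 3, 10, 10).astype("float32")
out = np_group_norm(x, np.ones(3, "float32"), np.zeros(3, "float32"), num_groups=3)
```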
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 70bb2513dda36..a61e0cd09ee18 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -196,6 +196,20 @@ def _nn_layer_norm(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.nn.group_norm")
+def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        topi.nn.group_norm,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        call.attrs.num_groups,
+        call.attrs.channel_axis,
+        call.attrs.axes,
+        call.attrs.epsilon,
+    )
+
+
 @register_legalize("relax.nn.dropout")
 def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
     logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index e63b3306f25d1..430d2268cec31 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -233,6 +233,89 @@ TVM_REGISTER_OP("relax.nn.layer_norm")
     .add_argument("beta", "Tensor", "The beta offset factor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm);
 
+/* relax.nn.group_norm */
+TVM_REGISTER_NODE_TYPE(GroupNormAttrs);
+
+Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis,
+                Array<Integer> axes, double epsilon, bool center, bool scale) {
+  ObjectPtr<GroupNormAttrs> attrs = make_object<GroupNormAttrs>();
+  attrs->num_groups = num_groups;
+  attrs->channel_axis = channel_axis;
+  attrs->axes = std::move(axes);
+  attrs->epsilon = epsilon;
+  attrs->center = center;
+  attrs->scale = scale;
+
+  static const Op& op = Op::Get("relax.nn.group_norm");
+  return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.group_norm").set_body_typed(group_norm);
+
+StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) {
+  Op op = Downcast<Op>(call->op);
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<GroupNormAttrs>();
+
+  TensorStructInfo data_sinfo = input_sinfo[0];
+  int channel_axis = -1;
+  if (!data_sinfo->IsUnknownNdim()) {
+    channel_axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->channel_axis);
+    std::vector<int> axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes);
+    // channel_axis must not be in axes.
+    if (std::find(axes.begin(), axes.end(), channel_axis) != axes.end()) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << op
+                       << " expects that channel_axis must not be in axes, but got channel_axis: "
+                       << channel_axis << ", axes: " << attrs->axes);
+    }
+  }
+  if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << op << " expects that data must be float, but got " << data_sinfo->dtype);
+  }
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape != nullptr && channel_axis != -1 &&
+      analyzer->CanProve(floormod(data_shape->values[channel_axis], attrs->num_groups) != 0)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << op << " expects that the size of channel_axis must be divisible by "
+                     << attrs->num_groups << ", but got " << data_shape->values[channel_axis]);
+  }
+  for (int i = 1; i < static_cast<int>(op->arguments.size()); ++i) {
+    if (input_sinfo[i]->dtype != data_sinfo->dtype) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << op << " expects that all inputs must have the same dtype, but got "
+                       << input_sinfo[i]->dtype << " and " << data_sinfo->dtype);
+    } else if (input_sinfo[i]->ndim != 1) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << op << " expects that all inputs must have ndim=1, but got "
+                       << input_sinfo[i]->ndim);
+    } else if (channel_axis != -1) {
+      const auto* shape = input_sinfo[i]->shape.as<ShapeExprNode>();
+      if (shape != nullptr && data_shape != nullptr) {
+        PrimExpr channel_size = data_shape->values[channel_axis];
+        PrimExpr input_size = shape->values[0];
+        if (analyzer->CanProve(channel_size != input_size)) {
+          ctx->ReportFatal(Diagnostic::Error(call)
+                           << op << " expects that the size of input " << i
+                           << " must be equal to the size of channel_axis, but got " << input_size
+                           << " and " << channel_size);
+        }
+      }
+    }
+  }
+  return data_sinfo;
+}
+
+TVM_REGISTER_OP("relax.nn.group_norm")
+    .set_attrs_type<GroupNormAttrs>()
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "Input to which group_norm will be applied.")
+    .add_argument("gamma", "Tensor", "The gamma scale factor.")
+    .add_argument("beta", "Tensor", "The beta offset factor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm);
+
 /* relax.nn.dropout */
 TVM_REGISTER_NODE_TYPE(DropoutAttrs);
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index f13b930fc246c..f578f89346f79 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -68,6 +68,10 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_
 Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double epsilon, bool center,
                 bool scale);
 
+/*! \brief Compute group normalization. */
+Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis,
+                Array<Integer> axes, double epsilon, bool center, bool scale);
+
 /*!
  * \brief Applies the dropout operation to the input tensor.
  * \param data The input data to the operator.
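Note (not part of the patch): before the tests, here is a minimal sketch of how the checks in `InferStructInfoGroupNorm` surface on the Python side, assuming a TVM build that includes this patch. The divisibility rule is the one enforced by the `floormod(..., num_groups)` check above; the variable names are illustrative.

```python
import pytest
import tvm
from tvm import relax
from tvm.script import relax as R

bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((2, 5, 4, 4), "float32"))  # 5 channels
gamma = relax.Var("gamma", R.Tensor((5,), "float32"))
beta = relax.Var("beta", R.Tensor((5,), "float32"))

# 5 channels cannot be split into num_groups=2, so struct-info inference reports an error.
with pytest.raises(tvm.TVMError):
    bb.normalize(
        relax.op.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3])
    )
```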
diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index ba3c930a456fb..c21dbd2bd1f51 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -362,7 +362,7 @@ def f( y: R.Tensor(("m",), "float32"), r: R.Tensor(dtype="int64"), ) -> R.Object: - m = T.var("int64") + m = T.int64() z: R.Tensor((32, m), "float32") = R.multiply(x, y) w: R.Tensor = R.multiply(z, z) q: R.Tensor(ndim=2) = R.add(w, w) @@ -431,7 +431,7 @@ def test_call_tir(): # also from test_parser @R.function def foo(x: R.Tensor(("m", "n"), "float32")): - m, n = T.var("int64"), T.var("int64") + m, n = T.int64(), T.int64() gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) return gv0 diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 137713869e919..73cfacf1e526a 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -708,29 +708,19 @@ def main( w1: R.Tensor((3,), dtype="float32"), w2: R.Tensor((3,), dtype="float32"), ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): - # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.reshape( - input_1, (1, 3, 1, 10, 10) - ) - lv1: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.mean( - lv, axis=[2, 3, 4], keepdims=True - ) - lv2: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.subtract(lv, lv1) - lv3: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv2, lv2) - lv4: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sum( - lv3, axis=[2, 3, 4], keepdims=True + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.group_norm( + input_1, + w1, + w2, + num_groups=3, + channel_axis=1, + axes=[2, 3], + epsilon=1.0000000000000001e-05, + center=True, + scale=True, ) - lv5: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.divide(lv4, R.const(100.0)) - lv6: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.add(lv5, R.const(1e-05)) - lv7: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sqrt(lv6) - lv8: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.divide(lv2, lv7) - lv9: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w1, (1, 3, 1, 1, 1)) - lv10: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w2, (1, 3, 1, 1, 1)) - lv11: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv8, lv9) - lv12: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.add(lv11, lv10) - lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = R.reshape(lv12, (1, 3, 10, 10)) - gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv13 + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv R.output(gv) return gv diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index 5294596cee340..51144784638af 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -849,6 +849,244 @@ def test_layer_norm_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.nn.layer_norm(x0, gamma1, beta, axes=[-2, -1])) +def test_group_norm_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor("float32", ndim=1)) + gamma2 = relax.Var("gamma", R.Tensor((4,))) + beta0 = relax.Var("beta", R.Tensor((4,), "float32")) + beta1 = relax.Var("beta", R.Tensor((4,))) + + 
_check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x2, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma1, beta0, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype="float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x3, gamma2, beta1, num_groups=2, channel_axis=-2, axes=[-1]), + relax.TensorStructInfo((2, 3, 4, 5), dtype=""), + ) + + +def test_group_norm_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c0 = tir.Var("c", "int64") + c1 = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((n, a, b, c0), "float32")) + x1 = relax.Var("x", R.Tensor((n, a, b, c1), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=4)) + gamma0 = relax.Var("gamma", R.Tensor((a,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((a,), "float32")) + beta = relax.Var("beta", R.Tensor((a,), "float32")) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c0), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c1), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma1, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo((n, a, b, c0), "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x2, gamma0, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x2, gamma1, beta, num_groups=2, channel_axis=-3, axes=[-2, -1]), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + + +def test_group_norm_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s1", relax.ShapeStructInfo()) + s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=1)) + s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32")) + beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32")) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma, beta, num_groups=2, channel_axis=-2, axes=[1, 3]), + relax.TensorStructInfo(s0, "float32"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma, beta, num_groups=2, channel_axis=-2, axes=[1, 3]), + relax.TensorStructInfo(s1, "float32"), + ) + + +def test_group_norm_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + gamma0 = relax.Var("gamma", R.Tensor((3,), "float16")) + beta0 = relax.Var("beta", R.Tensor((3,), 
"float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float64")) + gamma1 = relax.Var("gamma", R.Tensor((3,), "float64")) + beta1 = relax.Var("beta", R.Tensor((3,), "float64")) + + _check_inference( + bb, + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=3, channel_axis=1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float16"), + ) + _check_inference( + bb, + relax.op.nn.group_norm(x1, gamma1, beta1, num_groups=3, channel_axis=1, axes=[-2, -1]), + relax.TensorStructInfo((2, 3, 4, 5), "float64"), + ) + + +def test_group_norm_infer_struct_info_invalid_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + gamma0 = relax.Var("gamma", R.Tensor((4,), "int8")) + beta0 = relax.Var("beta", R.Tensor((4,), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "int32")) + beta1 = relax.Var("beta", R.Tensor((4,), "int32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x1, gamma1, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_axis_out_of_range_and_repetitive(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4,), "float32")) + beta = relax.Var("beta", R.Tensor((4,), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=-2, axes=[3, 4]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=-2, axes=[3, -1]) + ) + + +def test_group_norm_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4,), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "int8")) + beta0 = relax.Var("beta", R.Tensor((4,), "float32")) + beta1 = relax.Var("beta", R.Tensor((4,))) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma1, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma0, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4,), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma1, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x, gamma0, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_shape_mismatch(): + bb = relax.BlockBuilder() + c0 = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, c0), "float32")) + gamma0 = relax.Var("gamma", R.Tensor((4, 6), "float32")) + gamma1 = relax.Var("gamma", R.Tensor((4, c0), "float32")) + beta0 = relax.Var("beta", R.Tensor((4, 5), "float32")) + beta1 = relax.Var("beta", R.Tensor((4, c0 - 2), "float32")) + + 
with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x0, gamma0, beta0, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x1, gamma1, beta1, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + +def test_group_norm_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32")) + gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((4, 5), "float32"))) + beta = relax.Var("beta", R.Tensor((4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x1, gamma0, beta, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + with pytest.raises(TVMError): + bb.normalize( + relax.op.nn.group_norm(x0, gamma1, beta, num_groups=2, channel_axis=-2, axes=[-2, -1]) + ) + + def test_dropout_infer_struct_info(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((2, 3), "float32")) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 698ad2727456d..8fb398f15d2b1 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -1452,5 +1452,167 @@ def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_r tvm.ir.assert_structural_equal(mod, Expected) +def test_group_norm(): + # fmt: off + @tvm.script.ir_module + class GroupNorm: + @R.function + def main(x: R.Tensor((2, 4, 4, 5), "float32"), gamma: R.Tensor((4,), "float32"), beta: R.Tensor((4,), "float32")) -> R.Tensor((2, 4, 4, 5), "float32"): + gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3]) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4),), "float32"), rxplaceholder_2: T.Buffer((T.int64(4),), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float32")): + T.func_attr({"tir.noalias": True}) + T_reshape_1 = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5))) + rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2), T.int64(2))) + rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(2), T.int64(2))) + T_reshape_2 = T.alloc_buffer((T.int64(2), T.int64(2))) + T_reshape_3 = T.alloc_buffer((T.int64(2), T.int64(2))) + T_group_norm = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5))) + for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % 
T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)] + for ax0, ax1, k2, k3, k4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("rxplaceholder_red_temp"): + v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]) + T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] * T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 + for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): + with T.block("T_reshape_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) + T.writes(T_reshape_2[v_ax0, v_ax1]) + T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] + for ax0, ax1 in T.grid(T.int64(2), T.int64(2)): + with T.block("T_reshape_2"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)]) + T.writes(T_reshape_3[v_ax0, v_ax1]) + T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) % T.int64(4)] + for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)): + with T.block("T_group_norm"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) + T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] * T.float32(0.025000000000000001) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] + for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(4), T.int64(5)): + with T.block("T_reshape_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4), v_ax3 % T.int64(5)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) 
% T.int64(4), v_ax3 % T.int64(5)] + + @R.function + def main(x: R.Tensor((2, 4, 4, 5), dtype="float32"), gamma: R.Tensor((4,), dtype="float32"), beta: R.Tensor((4,), dtype="float32")) -> R.Tensor((2, 4, 4, 5), dtype="float32"): + gv = R.call_tir(group_norm, (x, gamma, beta), out_sinfo=R.Tensor((2, 4, 4, 5), dtype="float32")) + return gv + # fmt: on + + mod = LegalizeOps()(GroupNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_group_norm_symbolic(): + # fmt: off + @tvm.script.ir_module + class GroupNorm: + @R.function + def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), "float32"), gamma: R.Tensor(("4 * c",), "float32"), beta: R.Tensor(("4 * c",), "float32")) -> R.Tensor(("n", "4 * c", "h", "w"), "float32"): + n = T.int64() + c = T.int64() + h = T.int64() + w = T.int64() + gv: R.Tensor((n, 4 * c, h, w), "float32") = R.nn.group_norm(x, gamma, beta, num_groups=4, channel_axis=1, axes=[2, 3]) + return gv + + @tvm.script.ir_module + class Expected: + @T.prim_func + def group_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_reshape: T.handle, c: T.int64): + T.func_attr({"tir.noalias": True}) + n = T.int64() + h = T.int64() + w = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (n, T.int64(4) * c, h, w)) + rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4) * c,)) + rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4) * c,)) + T_reshape = T.match_buffer(var_T_reshape, (n, T.int64(4) * c, h, w)) + # with T.block("root"): + T_reshape_1 = T.alloc_buffer((n, T.int64(4), T.int64(4) * c // T.int64(4), h, w)) + rxplaceholder_red_temp_v0 = T.alloc_buffer((n, T.int64(4))) + rxplaceholder_red_temp_v1 = T.alloc_buffer((n, T.int64(4))) + T_reshape_2 = T.alloc_buffer((T.int64(4), T.int64(4) * c // T.int64(4))) + T_reshape_3 = T.alloc_buffer((T.int64(4), T.int64(4) * c // T.int64(4))) + T_group_norm = T.alloc_buffer((n, T.int64(4), T.int64(4) * c // T.int64(4), h, w)) + for ax0, ax1, ax2, ax3, ax4 in T.grid(n, T.int64(4), c, h, w): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(rxplaceholder[((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h // (c * T.int64(4)) % n, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h % (c * T.int64(4)), ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w % h, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) % w]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = rxplaceholder[((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h // (c * T.int64(4)) % n, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w // h % (c * T.int64(4)), ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) // w % h, ((((v_ax0 * T.int64(4) + v_ax1) * c + v_ax2) * h + v_ax3) * w + v_ax4) % w] + for ax0, ax1, k2, k3, k4 in T.grid(n, T.int64(4), c, h, w): + with T.block("rxplaceholder_red_temp"): + v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR", [ax0, ax1, k2, k3, k4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4]) + T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1]) + with T.init(): + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0) + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0) + 
v_rxplaceholder_red_temp_v0: T.float32 = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + v_rxplaceholder_red_temp_v1: T.float32 = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] * T_reshape_1[v_ax0, v_ax1, v_k2, v_k3, v_k4] + rxplaceholder_red_temp_v0[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v0 + rxplaceholder_red_temp_v1[v_ax0, v_ax1] = v_rxplaceholder_red_temp_v1 + for ax0, ax1 in T.grid(T.int64(4), c): + with T.block("T_reshape_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_1[(v_ax0 * c + v_ax1) % (c * T.int64(4))]) + T.writes(T_reshape_2[v_ax0, v_ax1]) + T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 * c + v_ax1) % (c * T.int64(4))] + for ax0, ax1 in T.grid(T.int64(4), c): + with T.block("T_reshape_2"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder_2[(v_ax0 * c + v_ax1) % (c * T.int64(4))]) + T.writes(T_reshape_3[v_ax0, v_ax1]) + T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 * c + v_ax1) % (c * T.int64(4))] + for ax0, ax1, ax2, ax3, ax4 in T.grid(n, T.int64(4), c, h, w): + with T.block("T_group_norm"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2]) + T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w))) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w)) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w)) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / (T.Cast("float32", c) * T.Cast("float32", h) * T.Cast("float32", w))) + T.float32(1.0000000000000001e-05)) * T_reshape_2[v_ax1, v_ax2] + T_reshape_3[v_ax1, v_ax2] + for ax0, ax1, ax2, ax3 in T.grid(n, c * T.int64(4), h, w): + with T.block("T_reshape_3"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(T_group_norm[(((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) % w]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) % w] + + @R.function + def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), dtype="float32"), gamma: R.Tensor(("4 * c",), dtype="float32"), beta: R.Tensor(("4 * c",), dtype="float32")) -> R.Tensor(("n", "4 * c", "h", "w"), dtype="float32"): + n = T.int64() + c = T.int64() 
+ h = T.int64() + w = T.int64() + gv = R.call_tir(group_norm, (x, gamma, beta), out_sinfo=R.Tensor((n, 4 * c, h, w), dtype="float32"), tir_vars=R.shape([c])) + return gv + # fmt: on + + mod = LegalizeOps()(GroupNorm) + tvm.ir.assert_structural_equal(mod, Expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py index 781700af7b82d..c2bfa5b7a9e99 100644 --- a/tests/python/relax/test_tvmscript_parser_op_nn.py +++ b/tests/python/relax/test_tvmscript_parser_op_nn.py @@ -185,6 +185,31 @@ def foo( _check(foo, bb.get()["foo"]) +def test_group_norm(): + @R.function + def foo( + x: R.Tensor((2, 4, 4, 5), "float32"), + gamma: R.Tensor((4,), "float32"), + beta: R.Tensor((4,), "float32"), + ) -> R.Tensor((2, 4, 4, 5), "float32"): + gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.group_norm( + x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3] + ) + return gv + + x = relax.Var("x", R.Tensor((2, 4, 4, 5), "float32")) + gamma = relax.Var("gamma", R.Tensor((4,), "float32")) + beta = relax.Var("beta", R.Tensor((4,), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, gamma, beta]): + gv = bb.emit( + relax.op.nn.group_norm(x, gamma, beta, num_groups=2, channel_axis=1, axes=[2, 3]) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + def test_dropout(): @R.function def foo(