Skip to content

Commit

Permalink
[Unity][Op] Group normalization (#14194)
Browse files Browse the repository at this point in the history
* [TOPI] Group normalization

As more and more ML models nowadays contain the group normalization
computation, we find it beneficial to introduce this op to TOPI level.
It will enable us to optimize the group normalization operation as a
whole in a more convenient way.

This PR introduces the group normalization op to TOPI. The group norm
operation was introduced in https://arxiv.org/abs/1803.08494. The
implementation uses tuple reduction, same as the implementation of layer
norm. Implemented with tuple reduction, the corresponding generated TIR
function can be optimized by cross-thread reduction or rfactor through
MetaSchedule.

Prior to this PR, the group normalization operations in frontend models
are translated to a series of operations, which brings inconvenience
when we want to optimize the group norm op as a whole.

With the TOPI implementation of group norm being introduced by #14193,
we can now use it to legalize the high-level group norm op and optimize
it using cross-thread reduction or rfactor via MetaSchedule.


Co-authored-by: Bohan Hou <[email protected]>
  • Loading branch information
2 people authored and tqchen committed Apr 1, 2023
1 parent 3f66edc commit 283a3db
Show file tree
Hide file tree
Showing 11 changed files with 638 additions and 57 deletions.
21 changes: 21 additions & 0 deletions include/tvm/relax/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,27 @@ struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
}
}; // struct LayerNormAttrs

/*! \brief Attributes used in group_norm operator */
struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
int num_groups;
int channel_axis;
Array<Integer> axes;
double epsilon;
bool center;
bool scale;

TVM_DECLARE_ATTRS(GroupNormAttrs, "relax.attrs.GroupNormAttrs") {
TVM_ATTR_FIELD(num_groups).describe("The number of groups to separate the channels into.");
TVM_ATTR_FIELD(channel_axis).describe("The axis that represents the channel.");
TVM_ATTR_FIELD(axes).describe(
"The axes that along which the normalization is applied (excluding the channel axis).");
TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero");
TVM_ATTR_FIELD(center).describe(
"Indicating if the beta offset will be added to the normalized tensor.");
TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied.");
}
}; // struct GroupNormAttrs

/*! \brief Attributes used in dropout operator */
struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
double rate;
Expand Down
54 changes: 20 additions & 34 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,44 +465,30 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var:
)

def _group_norm(self, node: fx.node.Node) -> relax.Var:
# torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05,
# affine=True, device=None, dtype=None)
import torch # type: ignore

x = self.env[node.args[0]]
module = self.named_modules[node.target]
num_groups = module.num_groups
num_channels = module.num_channels
eps = module.eps
affine = module.affine

shape = self.shape_of(x)
assert len(shape) == 4
N, C, H, W = shape[0], shape[1], shape[2], shape[3]
assert C == num_channels
assert C % num_groups == 0
grouped_x = self.block_builder.emit(
relax.op.reshape(x, [N, num_groups, C // num_groups, H, W])
)
mean_x = self.block_builder.emit(relax.op.mean(grouped_x, [2, 3, 4], keepdims=True))
sub_x = self.block_builder.emit(relax.op.subtract(grouped_x, mean_x))
square_x = self.block_builder.emit(relax.op.multiply(sub_x, sub_x))
sum_square_x = self.block_builder.emit(relax.op.sum(square_x, [2, 3, 4], keepdims=True))
var_x = self._call_binary_op(relax.op.divide, sum_square_x, (C // num_groups * H * W).value)
var_x_eps = self._call_binary_op(relax.op.add, var_x, eps)
std_x = self.block_builder.emit(relax.op.sqrt(var_x_eps))
norm_x = self.block_builder.emit(relax.op.divide(sub_x, std_x))

if affine:
weight = self.params[module.weight]
bias = self.params[module.bias]
weight_reshape = self.block_builder.emit(
relax.op.reshape(weight, (1, num_groups, C // num_groups, 1, 1))
)
bias_reshape = self.block_builder.emit(
relax.op.reshape(bias, (1, num_groups, C // num_groups, 1, 1))
if module.affine:
gamma = self.params[module.weight]
beta = self.params[module.bias]
else:
gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type)
beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type)

dim = len(self.shape_of(x))
return self.block_builder.emit(
relax.op.nn.group_norm(
x,
gamma,
beta,
num_groups=module.num_groups,
channel_axis=1,
axes=list(range(2, dim)),
epsilon=module.eps,
)
norm_x = self.block_builder.emit(relax.op.multiply(norm_x, weight_reshape))
norm_x = self.block_builder.emit(relax.op.add(norm_x, bias_reshape))
return self.block_builder.emit(relax.op.reshape(norm_x, (N, C, H, W)))
)

def _embedding(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
Expand Down
58 changes: 58 additions & 0 deletions python/tvm/relax/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,64 @@ def layer_norm(
return _ffi_api.layer_norm(data, gamma, beta, axes, epsilon, center, scale) # type: ignore


def group_norm(
data: Expr,
gamma: Expr,
beta: Expr,
num_groups: int,
channel_axis: int,
axes: Union[int, List[int]],
epsilon: float = 1e-5,
center: bool = True,
scale: bool = True,
) -> Expr:
r"""
Group normalization (Yuxin Wu and et al., 2016).
Applies group normalization to the n-dimensional input array.
This operator takes an n-dimensional input array. First separate the input array
into groups along the channel axis. Then apply layer normalization to each group.
Parameters
----------
data : relax.Expr
Input to which group_norm will be applied.
gamma : relax.Expr
The gamma scale factor.
beta : relax.Expr
The beta offset factor.
num_groups : int
Number of groups to separate the channels into.
channel_axis : int
The index of the channel axis in the input data.
axes : Union[int, List[int]]
The axes that along which the normalization is applied (excluding the group axis)
epsilon : float
Small float added to variance to avoid dividing by zero.
center : bool
Indicating if the beta offset will be added to the normalized tensor.
scale : bool
Indicating if the gamma scale will be multiplied.
Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(axes, int):
axes = [axes]
return _ffi_api.group_norm( # type: ignore
data, gamma, beta, num_groups, channel_axis, axes, epsilon, center, scale
)


def dropout(data: Expr, rate: float = 0.5) -> Expr:
"""Applies the dropout operation to the input tensor.
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,20 @@ def _nn_layer_norm(bb: BlockBuilder, call: Call) -> Expr:
)


@register_legalize("relax.nn.group_norm")
def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
topi.nn.group_norm,
call.args[0],
call.args[1],
call.args[2],
call.attrs.num_groups,
call.attrs.channel_axis,
call.attrs.axes,
call.attrs.epsilon,
)


@register_legalize("relax.nn.dropout")
def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
Expand Down
83 changes: 83 additions & 0 deletions src/relax/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,89 @@ TVM_REGISTER_OP("relax.nn.layer_norm")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm);

/* relax.nn.group_norm */
TVM_REGISTER_NODE_TYPE(GroupNormAttrs);

Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis,
Array<Integer> axes, double epsilon, bool center, bool scale) {
ObjectPtr<GroupNormAttrs> attrs = make_object<GroupNormAttrs>();
attrs->num_groups = num_groups;
attrs->channel_axis = channel_axis;
attrs->axes = std::move(axes);
attrs->epsilon = epsilon;
attrs->center = center;
attrs->scale = scale;

static const Op& op = Op::Get("relax.nn.group_norm");
return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {});
}

TVM_REGISTER_GLOBAL("relax.op.nn.group_norm").set_body_typed(group_norm);

StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) {
Op op = Downcast<Op>(call->op);
Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
const auto* attrs = call->attrs.as<GroupNormAttrs>();

TensorStructInfo data_sinfo = input_sinfo[0];
int channel_axis = -1;
if (!data_sinfo->IsUnknownNdim()) {
channel_axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->channel_axis);
std::vector<int> axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes);
// channel_axis must be in axes.
if (std::find(axes.begin(), axes.end(), channel_axis) != axes.end()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op
<< " expects that channel_axis must not be in axes, but got channel_axis: "
<< channel_axis << ", axes: " << attrs->axes);
}
}
if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " expects that data must be float, but got " << data_sinfo->dtype);
}
arith::Analyzer* analyzer = ctx->GetAnalyzer();
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (data_shape != nullptr && channel_axis != -1 &&
analyzer->CanProve(floormod(data_shape->values[channel_axis], attrs->num_groups) != 0)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " expects that the size of channel_axis must be divisible by "
<< attrs->num_groups << ", but got " << data_shape->values[channel_axis]);
}
for (int i = 1; i < static_cast<int>(op->arguments.size()); ++i) {
if (input_sinfo[i]->dtype != data_sinfo->dtype) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " expects that all inputs must have the same dtype, but got "
<< input_sinfo[i]->dtype << " and " << data_sinfo->dtype);
} else if (input_sinfo[i]->ndim != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " expects that all inputs must have ndim=1, but got "
<< input_sinfo[i]->ndim);
} else if (channel_axis != -1) {
const auto* shape = input_sinfo[i]->shape.as<ShapeExprNode>();
if (shape != nullptr && data_shape != nullptr) {
PrimExpr channel_size = data_shape->values[channel_axis];
PrimExpr input_size = shape->values[0];
if (analyzer->CanProve(channel_size != input_size)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " expects that the size of input " << i
<< " must be equal to the size of channel_axis, but got " << input_size
<< " and " << channel_size);
}
}
}
}
return data_sinfo;
}

TVM_REGISTER_OP("relax.nn.group_norm")
.set_attrs_type<GroupNormAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm);

/* relax.nn.dropout */
TVM_REGISTER_NODE_TYPE(DropoutAttrs);

Expand Down
4 changes: 4 additions & 0 deletions src/relax/op/nn/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_
Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double epsilon, bool center,
bool scale);

/*! \brief Compute group normalization. */
Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis,
Array<Integer> axes, double epsilon, bool center, bool scale);

/*!
* \brief Applies the dropout operation to the input tensor.
* \param data The input data to the operator.
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relax/test_ast_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def f(
y: R.Tensor(("m",), "float32"),
r: R.Tensor(dtype="int64"),
) -> R.Object:
m = T.var("int64")
m = T.int64()
z: R.Tensor((32, m), "float32") = R.multiply(x, y)
w: R.Tensor = R.multiply(z, z)
q: R.Tensor(ndim=2) = R.add(w, w)
Expand Down Expand Up @@ -431,7 +431,7 @@ def test_call_tir():
# also from test_parser
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")):
m, n = T.var("int64"), T.var("int64")
m, n = T.int64(), T.int64()
gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
return gv0

Expand Down
32 changes: 11 additions & 21 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,29 +708,19 @@ def main(
w1: R.Tensor((3,), dtype="float32"),
w2: R.Tensor((3,), dtype="float32"),
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.reshape(
input_1, (1, 3, 1, 10, 10)
)
lv1: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.mean(
lv, axis=[2, 3, 4], keepdims=True
)
lv2: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.subtract(lv, lv1)
lv3: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv2, lv2)
lv4: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sum(
lv3, axis=[2, 3, 4], keepdims=True
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.group_norm(
input_1,
w1,
w2,
num_groups=3,
channel_axis=1,
axes=[2, 3],
epsilon=1.0000000000000001e-05,
center=True,
scale=True,
)
lv5: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.divide(lv4, R.const(100.0))
lv6: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.add(lv5, R.const(1e-05))
lv7: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sqrt(lv6)
lv8: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.divide(lv2, lv7)
lv9: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w1, (1, 3, 1, 1, 1))
lv10: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w2, (1, 3, 1, 1, 1))
lv11: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv8, lv9)
lv12: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.add(lv11, lv10)
lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = R.reshape(lv12, (1, 3, 10, 10))
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv13
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

Expand Down
Loading

0 comments on commit 283a3db

Please sign in to comment.