Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI] Add generic batch norm #9694

Merged
merged 11 commits into from
Dec 13, 2021
5 changes: 4 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,10 @@ def _impl_v1(cls, inputs, attr, params):
op_name="batch_norm",
ignores=["spatial", "is_test", "consumed_inputs", "momentum", "training_mode"],
)(inputs, attr, params)
return out[0]
# We only support test mode, so we return data, moving_mean, moving_var,
# and then moving_mean and moving_var again as placeholders for
# the expected "saved_mean", "saved_var".
return _expr.TupleWrapper(_expr.Tuple((*out, out[1], out[2])), 5)


class InstanceNorm(OnnxOpConverter):
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ def legalize_batch_matmul(attrs, inputs, types):
reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# batch_norm
reg.register_strategy("nn.batch_norm", strategy.batch_norm_strategy)
reg.register_pattern("nn.batch_norm", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# sparse_dense
@reg.register_compute("nn.sparse_dense")
def compute_sparse_dense(attrs, inputs, out_type):
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,29 @@ def batch_matmul_strategy(attrs, inputs, out_type, target):
return strategy


# batch_norm
def wrap_compute_batch_norm(topi_compute):
"""wrap batch_norm topi compute"""

def _compute_batch_norm(attrs, inputs, out_type):
return topi_compute(*inputs, attrs.axis, attrs.epsilon, attrs.center, attrs.scale)

return _compute_batch_norm


@override_native_generic_func("batch_norm_strategy")
def batch_norm_strategy(attrs, inputs, out_type, target):
"""batch_norm generic strategy"""
logger.warning("batch_norm is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_batch_norm(topi.nn.batch_norm),
wrap_topi_schedule(topi.generic.schedule_batch_norm),
name="batch_norm.generic",
)
return strategy


# sparse dense
def wrap_compute_sparse_dense(topi_compute):
"""wrap sparse dense topi compute"""
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/topi/generic/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,23 @@ def schedule_batch_matmul(outs):
return _default_schedule(outs, False)


def schedule_batch_norm(outs):
"""Schedule for batch_norm
Parameters
----------
outs: Array of Tensor
The computation graph description of sparse_transpose
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_correlation_nchw(outs):
"""Schedule for correlation_nchw
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .bitserial_conv2d import *
from .bitserial_dense import *
from .batch_matmul import *
from .batch_norm import *
from .sparse import *
from .pad import *
from .fifo_buffer import *
Expand Down
110 changes: 110 additions & 0 deletions python/tvm/topi/nn/batch_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Batch normalization."""
import typing

from tvm import te
from tvm import topi


def batch_norm(
data: te.Tensor,
gamma: te.Tensor,
beta: te.Tensor,
moving_mean: te.Tensor,
moving_var: te.Tensor,
axis: typing.Optional[int] = None,
epsilon: typing.Optional[float] = None,
center: typing.Optional[bool] = None,
scale: typing.Optional[bool] = None,
) -> typing.List[te.Tensor]:
"""Batch normalization layer (Ioffe and Szegedy, 2014).

Normalizes the input at each batch, i.e. applies a transformation
that maintains the mean activation close to 0 and the activation
standard deviation close to 1.

Parameters
----------
data : tvm.te.Tensor
Input to be batch-normalized.

gamma : tvm.te.Tensor
Scale factor to be applied to the normalized tensor.

beta : tvm.te.Tensor
Offset to be applied to the normalized tensor.

moving_mean : tvm.te.Tensor
Running mean of input.

moving_var : tvm.te.Tensor
Running variance of input.

axis : int, optional, default=1
Specify along which shape axis the normalization should occur.

epsilon : float, optional, default=1e-5
Small float added to variance to avoid dividing by zero.

center : bool, optional, default=True
If True, add offset of beta to normalized tensor, If False,
beta is ignored.

scale : bool, optional, defualt=True
If True, scale normalized tensor by gamma. If False, gamma
is ignored.

Returns
-------
output : list of tvm.te.Tensor
Normalized data with same shape as input

moving_mean : tvm.te.Tensor
Running mean of input.

moving_var : tvm.te.Tensor
Running variance of input.
"""
if axis is None:
axis = 1

if epsilon is None:
epsilon = 1e-5

if center is None:
center = True

if scale is None:
scale = True

shape = [1] * len(data.shape)
shape[axis] = data.shape[axis]

moving_mean_rs = topi.reshape(moving_mean, shape)
moving_var_rs = topi.reshape(moving_var, shape)

out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)

if scale:
out = out * topi.reshape(gamma, shape)
if center:
out = out + topi.reshape(beta, shape)

# Moving mean and var aren't updated during test. To avoid
# placeholder reuse, we multiply by 1 and return them.
return [out, moving_mean * 1, moving_var * 1]
1 change: 1 addition & 0 deletions python/tvm/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python, strided_set_python
from .batch_matmul import batch_matmul
from .batch_norm import batch_norm
from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask
from .poolnd_python import poolnd_python
Expand Down
89 changes: 89 additions & 0 deletions python/tvm/topi/testing/batch_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Batch Normalization implemented in Numpy."""
import numpy as np


def batch_norm(
x: np.ndarray,
gamma: np.ndarray,
beta: np.ndarray,
moving_mean: np.ndarray,
moving_var: np.ndarray,
axis: int,
epsilon: float,
center: bool,
scale: bool,
):
"""Batch Normalization operator implemented in Numpy.
Parameters
----------
data : np.ndarray
Input to be batch-normalized.
gamma : np.ndarray
Scale factor to be applied to the normalized tensor.
beta : np.ndarray
Offset to be applied to the normalized tensor.
moving_mean : np.ndarray
Running mean of input.
moving_var : np.ndarray
Running variance of input.
axis : int
Specify along which shape axis the normalization should occur.
epsilon : float
Small float added to variance to avoid dividing by zero.
center : bool
If True, add offset of beta to normalized tensor, If False,
beta is ignored.
scale : bool
If True, scale normalized tensor by gamma. If False, gamma
is ignored.
Returns
-------
output : np.ndarray
Normalized data with same shape as input
moving_mean : np.ndarray
Running mean of input.
moving_var : np.ndarray
Running variance of input.
"""
shape = [1] * len(x.shape)
shape[axis] = x.shape[axis]

moving_mean_rs = moving_mean.reshape(shape)
moving_var_rs = moving_var.reshape(shape)

out = (x - moving_mean_rs) / np.sqrt(moving_var_rs + epsilon)

if scale:
out = out * gamma.reshape(shape)
if center:
out = out + beta.reshape(shape)

return [out, moving_mean, moving_var]
3 changes: 2 additions & 1 deletion src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,8 @@ bool BatchNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
reporter->Assign(types[4], TensorType({axis_size}, data->dtype));

// output is a tuple of the normed data (same shape as input), new running mean,
// and new running average (the latter two are both vectors of length dim)
// new running variance, saved mean and saved variance (the latter are all
// vectors of length dim)
std::vector<Type> fields;
auto vec_ty = TensorType(Array<IndexExpr>({data->shape[axis]}), data->dtype);
fields.push_back(TensorType(data->shape, data->dtype));
Expand Down
3 changes: 3 additions & 0 deletions src/topi/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ TVM_REGISTER_GENERIC_FUNC(schedule_dense)
TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul)
.set_default(WrapSchedule(topi::generic::default_schedule));

TVM_REGISTER_GENERIC_FUNC(schedule_batch_norm)
.set_default(WrapSchedule(topi::generic::default_schedule));

TVM_REGISTER_GENERIC_FUNC(schedule_pool)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
Expand Down
48 changes: 48 additions & 0 deletions tests/python/relay/test_op_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def test_batch_norm():
)
)

# axis=1
beta = relay.var("beta", relay.TensorType((3,), dtype))
gamma = relay.var("gamma", relay.TensorType((3,), dtype))
moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype))
Expand Down Expand Up @@ -427,6 +428,53 @@ def test_batch_norm():
)


def test_batch_norm_fold_const():
axis = 1
dtype = "float32"
shape = [4, 5, 6]

data_np = np.random.random(shape).astype(dtype)
beta_np = np.random.random(shape[axis]).astype(dtype)
gamma_np = np.random.random(shape[axis]).astype(dtype)
moving_mean_np = np.random.random(shape[axis]).astype(dtype)
moving_var_np = np.random.random(shape[axis]).astype(dtype)

data = relay.var("data", relay.TensorType(shape, dtype))
beta = relay.var("beta", relay.TensorType((shape[1],), dtype))
gamma = relay.var("gamma", relay.TensorType((shape[1],), dtype))
moving_mean = relay.var("moving_mean", relay.TensorType((shape[1],), dtype))
moving_var = relay.var("moving_var", relay.TensorType((shape[1],), dtype))
out = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, axis=axis).astuple()
func = relay.Function([data, gamma, beta, moving_mean, moving_var], out)

out_const = relay.nn.batch_norm(
relay.const(data_np),
relay.const(gamma_np),
relay.const(beta_np),
relay.const(moving_mean_np),
relay.const(moving_var_np),
axis=axis,
).astuple()
func_const = relay.Function([], out_const)

# Build the module with constants to have FoldConstant transform batch_norm.
mod_const = tvm.IRModule.from_expr(func_const)
mod_const = relay.transform.FoldConstant()(mod_const)

const_data_out = mod_const["main"].body[0].data
const_moving_mean_out = mod_const["main"].body[1].data
const_moving_var_out = mod_const["main"].body[2].data

# Run the Relay func without constants. This will use SimplyInference instead.
vm_data_out, vm_moving_mean_out, vm_moving_var_out = relay.create_executor(
"vm", device=tvm.device("llvm"), target="llvm"
).evaluate(func)(data_np, gamma_np, beta_np, moving_mean_np, moving_var_np)

tvm.testing.assert_allclose(const_data_out.numpy(), vm_data_out.numpy())
tvm.testing.assert_allclose(const_moving_mean_out.numpy(), vm_moving_mean_out.numpy())
tvm.testing.assert_allclose(const_moving_var_out.numpy(), vm_moving_var_out.numpy())


@pytest.mark.xfail
def test_matmul_type_check():
dtype = "float16"
Expand Down
Loading