Skip to content

Commit

Permalink
[TOPI] Add instance_norm operator
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Mar 28, 2023
1 parent 30bf013 commit ce7dea1
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 0 deletions.
63 changes: 63 additions & 0 deletions include/tvm/topi/nn/instance_norm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief instance normalization op constructions
* \file nn/instance_norm.h
*/
#ifndef TVM_TOPI_NN_INSTANCE_NORM_H_
#define TVM_TOPI_NN_INSTANCE_NORM_H_

#include <tvm/te/operation.h>
#include <tvm/topi/nn/layer_norm.h>
#include <tvm/topi/tags.h>

#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*!
* \brief Instance normalization.
* \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
* \param gamma K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
* d_{axis_k} == r_k
* \param beta Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
* d_{axis_k} == r_k
* \param axis The axis to normalize over (the axis along which mean and variance are
* computed).
* \param epsilon The epsilon value to avoid division by zero.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
* \return The normalized tensor, with the same shape as data.
*/
inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta,
const Array<Integer>& axis, double epsilon,
std::string name = "T_instance_norm", std::string tag = kInjective) {
return layer_norm(data, gamma, beta, axis, epsilon, name, tag);
}

} // namespace nn
} // namespace topi
} // namespace tvm

#endif // TVM_TOPI_NN_INSTANCE_NORM_H_
1 change: 1 addition & 0 deletions python/tvm/topi/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .bnn import *
from .qnn import *
from .upsampling import *
from .instance_norm import instance_norm
from .layer_norm import layer_norm
from .group_norm import group_norm
from .local_response_norm import *
Expand Down
47 changes: 47 additions & 0 deletions python/tvm/topi/nn/instance_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Instance normalization operator."""
from .. import cpp


def instance_norm(data, gamma, beta, axis, epsilon=1e-5):
"""Instance normalization operator.
Parameters
----------
data : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})
gamma: tvm.te.Tensor
K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
beta: tvm.te.Tensor
Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
axis : list of int
Axis over the normalization applied (the axis along which the mean and variance are
computed)
epsilon : float
The epsilon value to avoid division by zero.
Returns
-------
result : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
return cpp.nn.instance_norm(data, gamma, beta, axis, epsilon)
1 change: 1 addition & 0 deletions python/tvm/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .reorg_python import reorg_python
from .roi_align_python import roi_align_nchw_python, roi_align_nhwc_python
from .roi_pool_python import roi_pool_nchw_python
from .instance_norm_python import instance_norm_python
from .layer_norm_python import layer_norm_python
from .group_norm_python import group_norm_python
from .lrn_python import lrn_python
Expand Down
53 changes: 53 additions & 0 deletions python/tvm/topi/testing/instance_norm_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Instance normalization in python"""
import numpy as np


def instance_norm_python(data, gamma, beta, axis, epsilon=1e-5):
"""Instance normalization operator in Python.
Parameters
----------
data : numpy.ndarray
N-D with shape (d_0, d_1, ..., d_{N-1})
gamma: numpy.ndarray
K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
beta: numpy.ndarray
Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
axis : int or tuple of ints
Axis over the normalization applied
epsilon : float
The epsilon value to avoid division by zero.
Returns
-------
result : np.ndarray
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
mean = np.mean(data, axis, keepdims=True)
var = np.var(data, axis, keepdims=True)
result = (data - mean) / np.sqrt(var + epsilon)
result *= gamma
if beta is not None:
result += beta
return result
6 changes: 6 additions & 0 deletions src/topi/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <tvm/topi/nn/dilate.h>
#include <tvm/topi/nn/flatten.h>
#include <tvm/topi/nn/group_norm.h>
#include <tvm/topi/nn/instance_norm.h>
#include <tvm/topi/nn/layer_norm.h>
#include <tvm/topi/nn/local_response_norm.h>
#include <tvm/topi/nn/mapping.h>
Expand Down Expand Up @@ -170,5 +171,10 @@ TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body([](TVMArgs args, TVMRetValue*
static_cast<int>(args[4]), args[5], static_cast<double>(args[6]));
});

/* Ops from nn/instance_norm.h */
TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::instance_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
});

} // namespace topi
} // namespace tvm
64 changes: 64 additions & 0 deletions tests/python/topi/python/test_topi_instance_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for instance_norm."""
import numpy as np
import pytest
import tvm
from tvm import te
from tvm import topi
from tvm.topi.utils import get_const_tuple
import tvm.topi.testing

import tvm.testing


_instance_norm_schedule = {
"generic": topi.generic.schedule_injective,
}


# only test on llvm because schedule is missing
@tvm.testing.parametrize_targets("llvm")
@pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))])
def test_instance_norm(
target, dev, shape, axis, episilon=1e-5, dtype="float32", rtol=1e-5, atol=1e-5
):
data = te.placeholder(shape, dtype=dtype, name="data")
scale_shape = [shape[dim] for dim in axis]
gamma = te.placeholder(scale_shape, dtype=dtype, name="gamma")
beta = te.placeholder(scale_shape, dtype=dtype, name="beta")
B = topi.nn.instance_norm(data, gamma, beta, axis, episilon)

data_np = np.random.uniform(size=shape).astype(dtype)
gamma_np = np.random.uniform(size=scale_shape).astype(dtype)
beta_np = np.random.uniform(size=scale_shape).astype(dtype)
b_np = tvm.topi.testing.instance_norm_python(data_np, gamma_np, beta_np, axis, episilon)

with tvm.target.Target(target):
s_func = tvm.topi.testing.dispatch(target, _instance_norm_schedule)
s = s_func([B])
data_tvm = tvm.nd.array(data_np, dev)
gamma_tvm = tvm.nd.array(gamma_np, dev)
beta_tvm = tvm.nd.array(beta_np, dev)
b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev)
f = tvm.build(s, [data, gamma, beta, B], target)
f(data_tvm, gamma_tvm, beta_tvm, b_tvm)
tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit ce7dea1

Please sign in to comment.