From 70c5b0bd0c36757b07408a98cd4cef244a658006 Mon Sep 17 00:00:00 2001
From: Anna Karbownik
Date: Mon, 5 Oct 2020 10:13:55 +0200
Subject: [PATCH 1/3] Fix MKLDNN BatchNorm with even number of channels (#19150)

An even number of channels results in data reordering before the batch
norm operation. Therefore, if the BatchNorm data array is a view of
another array and the data is stored in the MKLDNN format, the data
needs to be converted to the default format first.
---
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 0a29a6d87de6..75c7c4dbf38a 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -145,13 +145,6 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
   return it->second;
 }
 
-template <typename DType>
-static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
-                                     const OpContext &ctx, const NDArray &in_data,
-                                     mkldnn::normalization_flags flags) {
-  return GetBNForward<DType>(param, ctx, in_data.GetMKLDNNData(), flags);
-}
-
 template <typename DType>
 void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                             const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
@@ -182,8 +175,11 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                 aux_states, ctx.is_train && !param.use_global_stats,
                                                 fuse_relu);
 
-  const NDArray &data = in_data[batchnorm::kData];
-  auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
+  NDArray &data = in_data[batchnorm::kData];
+  if (data.IsMKLDNNData() && data.IsView())
+    data = data.Reorder2Default();
+  auto data_mem = data.GetMKLDNNData();
+  auto &fwd = GetBNForward<DType>(param, ctx, data_mem, flags);
 
   // for output memory
   auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
@@ -221,7 +217,7 @@
   }
 
   mkldnn_args_map_t net_args;
-  net_args[MKLDNN_ARG_SRC] = *data.GetMKLDNNData();
+  net_args[MKLDNN_ARG_SRC] = *data_mem;
   net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem;
   net_args[MKLDNN_ARG_DST] = *out_mem;
   if (fuse_relu) {

From 62b58d8aad0bdea77534ef2952aa699f73f5c31d Mon Sep 17 00:00:00 2001
From: Anna Karbownik
Date: Thu, 22 Oct 2020 16:32:30 +0200
Subject: [PATCH 2/3] Add and update tests to verify BatchNorm with odd & even
 number of channels

---
 tests/python/mkl/test_mkldnn.py     |  2 +-
 tests/python/unittest/test_gluon.py | 35 +++++++++++++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index 213bcfb1edad..cf2ca13c161a 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -294,7 +294,7 @@ def test_mkldnn_sum_inplace_with_cpu_layout():
 @with_seed()
 def test_batchnorm():
     def check_batchnorm_training(stype):
-        for shape in [(2, 3), (2, 3, 2, 2)]:
+        for shape in [(2, 3), (2, 4), (2, 3, 2, 2), (2, 4, 2, 2)]:
             data_tmp = np.random.normal(-0.1, 0.1, size=shape)
             s = shape[1],
             gamma = np.ones(s)

diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 7bacb4f0b317..46b1e41e66b4 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -20,6 +20,7 @@
 
 import mxnet as mx
 from mxnet import gluon
+from mxnet import init
 from mxnet.gluon import nn
 from mxnet.base import py_str, MXNetError
 from mxnet.test_utils import assert_almost_equal, default_context
@@ -2179,6 +2180,40 @@ def hybrid_forward(self, F, x):
     check_layer_forward_withinput(net, x)
 
 
+@with_seed()
+def test_batchnorm_chnls():
+    chn_list = [1024, 512, 256, 128, 64, 45, 32, 16, 3]
+    class Net(gluon.HybridBlock):
+        def __init__(self,
+                     chn_num,
+                     norm_kwargs=None,
+                     in_channels=3,
+                     **kwargs):
+            super(Net, self).__init__(**kwargs)
+            self.in_channels = in_channels
+            self.conv1 = gluon.nn.Conv3D(
+                in_channels=self.in_channels,
+                channels=chn_num,
+                kernel_size=(1, 7, 7),
+                strides=(1, 2, 2),
+                padding=(0, 3, 3),
+                use_bias=False,
+            )
+            self.bn1 = gluon.nn.BatchNorm(in_channels=chn_num, **({} if norm_kwargs is None else norm_kwargs))
+
+        def hybrid_forward(self, F, x):
+            """Hybrid forward of R2+1D net"""
+            conv = self.conv1(x)
+            out = self.bn1(conv)
+            return out
+
+    for i in range(len(chn_list)):
+        net = Net(chn_list[i])
+        net.initialize(init=init.Constant(1))
+        x = mx.nd.zeros((1, 3, 8, 160, 160), ctx=mx.cpu())
+        net(x).asnumpy()
+
+
 @with_seed()
 def test_concat():
     chn_list = [16, 64]

From 2c11470b8336c1143e1c90dd9b39d7c90a6654ae Mon Sep 17 00:00:00 2001
From: Anna Karbownik
Date: Fri, 23 Oct 2020 17:44:24 +0200
Subject: [PATCH 3/3] Fix BatchNorm odd & even channels test to use the default
 context

---
 tests/python/unittest/test_gluon.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 46b1e41e66b4..49b84a2b9d68 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -2210,7 +2210,7 @@ def hybrid_forward(self, F, x):
     for i in range(len(chn_list)):
         net = Net(chn_list[i])
         net.initialize(init=init.Constant(1))
-        x = mx.nd.zeros((1, 3, 8, 160, 160), ctx=mx.cpu())
+        x = mx.nd.zeros((1, 3, 8, 160, 160))
         net(x).asnumpy()
 
 
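
A note on reproducing the failure: the first commit message describes the failing path (BatchNorm input that is both MKLDNN-formatted and a view, with an even channel count), and the new test_batchnorm_chnls covers it through a Conv3D + BatchNorm block. Below is a minimal standalone sketch of the same scenario for reviewers. It assumes an MKLDNN-enabled CPU build of MXNet 1.x; the HybridSequential stand-in for the test's Net class and channels=16 (one of the even values from chn_list) are illustrative choices, not part of the patches.

    import mxnet as mx
    from mxnet import gluon, init

    # Conv3D produces MKLDNN-layout output; feeding it into BatchNorm with an
    # even channel count exercises the Reorder2Default path added in patch 1/3.
    net = gluon.nn.HybridSequential()
    net.add(gluon.nn.Conv3D(channels=16, kernel_size=(1, 7, 7),
                            strides=(1, 2, 2), padding=(0, 3, 3),
                            use_bias=False, in_channels=3))
    net.add(gluon.nn.BatchNorm(in_channels=16))
    net.initialize(init=init.Constant(1))

    x = mx.nd.zeros((1, 3, 8, 160, 160))  # NCDHW input, as in the new test
    net(x).asnumpy()  # expected to fail on MKLDNN builds without patch 1/3 (#19150)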