From c7af4501b86e69b6da47ca65263320766f43152f Mon Sep 17 00:00:00 2001
From: Hao Jin
Date: Fri, 13 Jul 2018 01:31:59 -0700
Subject: [PATCH] fix batchnorm problem with sparse matrices when fix_gamma=True (#11656)

---
 src/operator/batch_norm_v1.cc                 |  3 +
 src/operator/nn/batch_norm.cc                 | 46 ++++++++----
 src/operator/nn/fully_connected.cc            |  4 +-
 tests/python/mkl/test_mkldnn.py               |  2 +-
 tests/python/unittest/test_operator.py        |  4 +-
 tests/python/unittest/test_sparse_operator.py | 75 ++++++++++++++++++-
 6 files changed, 111 insertions(+), 23 deletions(-)

diff --git a/src/operator/batch_norm_v1.cc b/src/operator/batch_norm_v1.cc
index 5da4af253681..2d19107eda1e 100644
--- a/src/operator/batch_norm_v1.cc
+++ b/src/operator/batch_norm_v1.cc
@@ -89,6 +89,9 @@ the output. It is often used during inference.
 Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
 then set ``gamma`` to 1 and its gradient to 0.
 
+There's no sparse support for this operator, and it will exhibit problematic behavior if used with
+sparse tensors.
+
 )code" ADD_FILELINE)
 .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
 .add_argument("gamma", "NDArray-or-Symbol", "gamma array")
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 1f9e8289f4a4..30fb665dd05a 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -321,6 +321,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
+  CHECK_EQ(out_shape->size(), 3U);
   const TShape &dshape = in_shape->at(batchnorm::kData);
 
   const size_t channelAxis = static_cast<size_t>(param.axis < 0
@@ -444,27 +445,37 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
   }
   FallBackCompute(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
 }
+#endif
 
 static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
                                         const int dev_mask,
                                         DispatchMode *dispatch_mode,
                                         std::vector<int> *in_attrs,
                                         std::vector<int> *out_attrs) {
-  CHECK_EQ(in_attrs->size(), 5);
-  CHECK_EQ(out_attrs->size(), 3);
-  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
-                           in_attrs, out_attrs);
-}
+  const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
 
-static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs,
-                                                 const int dev_mask,
-                                                 DispatchMode *dispatch_mode,
-                                                 std::vector<int> *in_attrs,
-                                                 std::vector<int> *out_attrs) {
-  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
-                           in_attrs, out_attrs);
-}
+  bool dispatched = false;
+#if MXNET_USE_MKLDNN == 1
+  if (!dispatched) {
+    dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
+                                   in_attrs, out_attrs);
+  }
+#else
+  for (int& v : *in_attrs)
+    if (v == - 1) v = kDefaultStorage;
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
 #endif
+  if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_gamma) {
+    LOG(FATAL) << "fix_gamma=True is not supported for sparse ndarrays. Tracked at #11647";
+  }
+  return dispatched;
+}
 
 std::vector<nnvm::NodeEntry> BatchNormGrad(const nnvm::NodePtr& n,
                                            const std::vector<nnvm::NodeEntry>& ograds) {
@@ -552,6 +563,11 @@ axis to be the last item in the input shape.
 Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
 then set ``gamma`` to 1 and its gradient to 0.
 
+Note::
+
+When fix_gamma is set to True, no sparse support is provided. If fix_gamma is set to False,
+sparse tensors will fall back to dense computation.
+
 )code" ADD_FILELINE)
 .set_num_inputs(5)
 .set_num_outputs(3)
@@ -574,9 +590,7 @@ then set ``gamma`` to 1 and its gradient to 0.
 })
 .set_attr<nnvm::FInferShape>("FInferShape", BatchNormShape)
 .set_attr<nnvm::FInferType>("FInferType", BatchNormType)
-#if MXNET_USE_MKLDNN == 1
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
-#endif
 .set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormComputeExCPU)
@@ -607,8 +621,8 @@ then set ``gamma`` to 1 and its gradient to 0.
 NNVM_REGISTER_OP(_backward_BatchNorm)
 .set_num_outputs(3)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
 #if MXNET_USE_MKLDNN == 1
-.set_attr<FInferStorageType>("FInferStorageType", backward_BatchNormStorageType)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index d9099cb57d4f..d1d84e975290 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -213,8 +213,8 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
   // TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
   // It seems there is a bug.
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
-    storage_type_assign(out_attrs, mxnet::kDefaultStorage,
-                        dispatch_mode, DispatchMode::kFCompute);
+    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
   }
   if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index a6d7743e9261..ff9ba538b95e 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -233,7 +233,7 @@ def check_batchnorm_training(stype):
                            mx.nd.array(beta).tostype(stype)]
             mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=True)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
     stypes = ['row_sparse', 'default']
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index f9dde2e6d245..6c6ff310519d 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1552,9 +1552,7 @@ def check_batchnorm_training(stype):
                 test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-    stypes = ['default']
-    for stype in stypes:
-        check_batchnorm_training(stype)
+    check_batchnorm_training('default')
 
 
 @with_seed()
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 95689b785db6..e51a49424c8a 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -16,7 +16,8 @@
 # under the License.
 
 from mxnet.test_utils import *
-from common import setup_module, with_seed, teardown
+from mxnet.base import MXNetError
+from common import setup_module, with_seed, teardown, assertRaises
 import random
 import warnings
 
@@ -2098,6 +2099,78 @@ def check_scatter_ops(name, shape, lhs_stype, rhs_stype, forward_mxnet_call, for
                       lambda l, r: l + r,
                       rhs_is_scalar=True, verbose=False, density=0.5)
+
+@with_seed()
+def test_batchnorm_fallback():
+    # same test as test_operator.test_batchnorm_training, but tests fallback logic of batchnorm
+    stype = 'row_sparse'
+    for shape in [(2, 3), (2, 3, 2, 2)]:
+        data_tmp = np.random.normal(-0.1, 0.1, size=shape)
+        s = shape[1],
+        gamma = np.ones(s)
+        beta = np.ones(s)
+        gamma[1] = 3
+        beta[0] = 3
+
+        rolling_mean = np.random.uniform(size=s)
+        rolling_std = np.random.uniform(size=s)
+
+        data = mx.symbol.Variable('data', stype=stype)
+        in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
+                       mx.nd.array(beta).tostype(stype)]
+        mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=True)
+        assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
+        assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=False)
+        check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
+        check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
+
+        # Test varying channel axis
+        dim = len(shape)
+        for chaxis in range(-dim, dim):
+            chaxis_true = chaxis
+            if chaxis < 0:
+                chaxis_true = dim + chaxis
+
+            shapex = shape
+
+            channel_count = shapex[chaxis_true]
+            data_tmp = np.random.normal(-0.1, 0.1, size=shapex)
+
+            gamma = np.ones(channel_count)
+            beta = np.ones(channel_count)
+            if channel_count > 1:
+                gamma[1] = 3
+            beta[0] = 3
+
+            in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
+                           mx.nd.array(beta).tostype(stype)]
+
+            xrolling_mean = np.random.uniform(size=channel_count)
+            xrolling_std = np.random.uniform(size=channel_count)
+            xmean_std = [mx.nd.array(xrolling_mean).tostype(stype),
+                         mx.nd.array(xrolling_std).tostype(stype)]
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
+            assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
+            assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
+            check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
+            check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
+
+
 
 @with_seed()
 def test_mkldnn_sparse():
     # This test is trying to create a race condition describedd in
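
For reference, a minimal sketch (not part of the diff) of the user-visible behavior this patch aims for, based on the new test above: with row_sparse inputs, BatchNorm with fix_gamma=True should raise an MXNetError, while fix_gamma=False should fall back to the dense implementation. The snippet assumes an MXNet build that includes this change; the exact error message may differ.

import mxnet as mx
from mxnet.base import MXNetError

# Row-sparse inputs for data/gamma/beta; the moving statistics stay dense.
data = mx.nd.ones((2, 3)).tostype('row_sparse')
gamma = mx.nd.ones((3,)).tostype('row_sparse')
beta = mx.nd.ones((3,)).tostype('row_sparse')
moving_mean = mx.nd.zeros((3,))
moving_var = mx.nd.ones((3,))

# fix_gamma=False: storage-type inference dispatches to the dense fallback.
out = mx.nd.BatchNorm(data, gamma, beta, moving_mean, moving_var, fix_gamma=False)
print(out.stype)  # expected: 'default'

# fix_gamma=True: rejected for sparse ndarrays (tracked at issue #11647).
try:
    mx.nd.BatchNorm(data, gamma, beta, moving_mean, moving_var, fix_gamma=True)
except MXNetError:
    print('MXNetError raised as expected')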