diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h index 17a16db5adcd..485b3b33f6a8 100644 --- a/src/operator/nn/batch_norm-inl.h +++ b/src/operator/nn/batch_norm-inl.h @@ -259,6 +259,7 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param, const std::vector &outputs) { CHECK_EQ(inputs.size(), 8U); CHECK_EQ(outputs.size(), 3U); + std::vector out_grad(1); std::vector out_data(3); std::vector in_data(3); diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index df0357369fed..b8961df1b782 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -85,6 +85,31 @@ static inline void ForEachFast(const BNTensor3 &in_data, } } +template +static inline void ForEachFast(const BNTensor3 &in_data, + const BNTensor3 &in_data2, + const BNTensor3 &out_data, + const size_t channel, + OnData onData) { + const size_t num = in_data.OuterSize(); + const size_t matrixSize = in_data.InnerSize(); + const size_t skipLength = in_data.SkipLengthToNextSameChannelData(); + const size_t startOffset = in_data.StartOffset(channel); + + DType1 *data = in_data.dptr_ + startOffset; + DType2 *data2 = in_data2.dptr_ + startOffset; + DType3 *odata = out_data.dptr_ + startOffset; + + for (size_t outer = 0; outer < num; ++outer) { + for (size_t i = 0; i < matrixSize; ++i) { + onData(data++, data2++, odata++); + } + data += skipLength; + data2 += skipLength; + odata += skipLength; + } +} + } // namespace batchnorm /*! \brief Forward CPU */ @@ -264,7 +289,7 @@ void BatchNormBackwardImpl(mshadow::Stream *, dotp += (*thisInputData - mean) * (*gradOut_data); }); - if (!gradIn.IsEmpty() && IsBNWriting(req[batchnorm::kData])) { // if there's a grad input + if (!gradIn.IsEmpty() && req[batchnorm::kData] != kNullOp) { // if there's a grad input if (is_train_and_not_global_stats) { // when in training mode // Q(X) = X - E[x] ; i.e. input centered to zero mean @@ -273,44 +298,60 @@ void BatchNormBackwardImpl(mshadow::Stream *, // projection of gradOutput on to output scaled by std const AccReal k = dotp * invstd * invstd / itemCount; - ForEachFast(inputData, gradIn, static_cast(channel), - [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) { - *gradIn_data = (*inputDataPtr - mean) * k; - }); - const AccReal iw = invstd * w; const AccReal gradMean = sumGradOut / itemCount; - ForEachFast(gradOut, gradIn, static_cast(channel), - [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) { - *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw; - }); + if (req[batchnorm::kData] != kAddTo) { + ForEachFast(inputData, gradIn, static_cast(channel), + [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) { + *gradIn_data = (*inputDataPtr - mean) * k; + }); + + ForEachFast(gradOut, gradIn, static_cast(channel), + [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) { + *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw; + }); + } else { + ForEachFast(inputData, gradOut, gradIn, static_cast(channel), + [&mean, &k, iw, gradMean](const DType *inputDataPtr, + const DType *gradOut_data, + DType *gradIn_data) { + DType normal_val = (*inputDataPtr - mean) * k; + *gradIn_data += (*gradOut_data - gradMean - + normal_val) * iw; + }); + } } else { // when in evaluation mode // Q(X) = X - running_mean ; i.e. input centered to zero mean // Y = Q(X) / running_std ; i.e. BN output before weight and bias // dL/dX = w / running_std const AccReal iw = invstd * w; - ForEachFast(gradOut, gradIn, static_cast(channel), - [iw](const DType *gradOut_data, DType *gradIn_data) { - *gradIn_data = *gradOut_data * iw; - }); + if (req[batchnorm::kData] != kAddTo) { + ForEachFast(gradOut, gradIn, static_cast(channel), + [iw](const DType *gradOut_data, DType *gradIn_data) { + *gradIn_data = *gradOut_data * iw; + }); + } else { + ForEachFast(gradOut, gradIn, static_cast(channel), + [iw](const DType *gradOut_data, DType *gradIn_data) { + *gradIn_data += *gradOut_data * iw; + }); + } } } // May want to make this a param eventually const AccReal scale = 1.0f; - if (IsBNWriting(req[batchnorm::kGamma])) { - if (!param_.fix_gamma) { - gradWeightData[channel] = scale * dotp * invstd; - } else { + if (!param_.fix_gamma) { + KERNEL_ASSIGN(gradWeightData[channel], req[batchnorm::kGamma], scale * dotp * invstd); + } else { + if (IsBNWriting(req[batchnorm::kGamma])) { gradWeightData[channel] = AccReal(0); } } - if (IsBNWriting(req[batchnorm::kBeta])) { - gradBiasData[channel] = scale * sumGradOut; - } + KERNEL_ASSIGN(gradBiasData[channel], req[batchnorm::kBeta], scale * sumGradOut); } } diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index be9309c8bfb1..7b36d25e7496 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -34,6 +34,9 @@ #define FIX_GAMMA_FLAG 8 #define IS_TRAINING_FLAG 16 #define USE_GLOBAL_STATS_FLAG 32 +#define ADDTO_DATA_FLAG (1 << 6) +#define ADDTO_GAMMA_FLAG (1 << 7) +#define ADDTO_BETA_FLAG (1 << 8) #if MXNET_USE_CUDNN == 1 #include "./cudnn/cudnn_batch_norm-inl.h" @@ -362,33 +365,60 @@ static __global__ void BatchNormalizationBackwardKernel( * momentum + localVariance * (AccReal(1) - momentum); } - if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) { - for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { - for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { - const DType gradOut = gradOutput.get_ref(batch, plane, x); - if (is_train_and_not_global_stats) { - const DType inp = input.get_ref(batch, plane, x); - const AccReal proj = (inp - mean) * projScale; - gradInput.get_ref(batch, plane, x) = - ScalarConvert::to((gradOut - proj - gradMean) * gradScale); - } else { - gradInput.get_ref(batch, plane, x) = ScalarConvert::to( - gradOut * gradScale); + if (gradInput.Size() > 0 && (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) != 0) { + const bool grad_write = flags & WRITE_DATA_FLAG; + if (grad_write) { + for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { + for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { + const DType gradOut = gradOutput.get_ref(batch, plane, x); + if (is_train_and_not_global_stats) { + const DType inp = input.get_ref(batch, plane, x); + const AccReal proj = (inp - mean) * projScale; + gradInput.get_ref(batch, plane, x) = + ScalarConvert::to((gradOut - proj - gradMean) * gradScale); + } else { + gradInput.get_ref(batch, plane, x) = ScalarConvert::to( + gradOut * gradScale); + } + } + } + } else { + // grad addto + for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) { + for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) { + const DType gradOut = gradOutput.get_ref(batch, plane, x); + if (is_train_and_not_global_stats) { + const DType inp = input.get_ref(batch, plane, x); + const AccReal proj = (inp - mean) * projScale; + gradInput.get_ref(batch, plane, x) += + ScalarConvert::to((gradOut - proj - gradMean) * gradScale); + } else { + gradInput.get_ref(batch, plane, x) += ScalarConvert::to( + gradOut * gradScale); + } } } } } - if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) { + if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && + (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) { if ((flags & FIX_GAMMA_FLAG) == 0) { - tensors.gradWeight[plane] = ScalarConvert::to(dotP * invstd); + if (flags & WRITE_GAMMA_FLAG) + tensors.gradWeight[plane] = ScalarConvert::to(dotP * invstd); + else + tensors.gradWeight[plane] += ScalarConvert::to(dotP * invstd); } else { tensors.gradWeight[plane] = DType(0); } } - if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) { - tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum); + if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && + (flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) { + if (flags & WRITE_BETA_FLAG) + tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum); + else + tensors.gradBias[plane] += ScalarConvert::to(gradOutputSum); } } @@ -585,12 +615,18 @@ static inline uint32_t SetupFlags(const OpContext &ctx, flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0; if (IsBNWriting(req[batchnorm::kData])) { flags |= WRITE_DATA_FLAG; + } else if (req[batchnorm::kData] == kAddTo) { + flags |= ADDTO_DATA_FLAG; } if (IsBNWriting(req[batchnorm::kGamma])) { flags |= WRITE_GAMMA_FLAG; + } else if (req[batchnorm::kGamma] == kAddTo) { + flags |= ADDTO_GAMMA_FLAG; } if (IsBNWriting(req[batchnorm::kBeta])) { flags |= WRITE_BETA_FLAG; + } else if (req[batchnorm::kBeta] == kAddTo) { + flags |= ADDTO_BETA_FLAG; } return flags; } diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 881d3d2247da..fc91212fab37 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -220,13 +220,24 @@ class CuDNNBatchNormOp { if (param_.fix_gamma) gamma = 1.f; + bool grad_add_gamma_beta = (req[cudnnbatchnorm::kGamma] == kAddTo) || + (req[cudnnbatchnorm::kBeta] == kAddTo); + if (grad_add_gamma_beta) { + if (IsBNWriting(req[cudnnbatchnorm::kGamma])) { + dgamma = 0.f; + } + if (IsBNWriting(req[cudnnbatchnorm::kBeta])) { + dbeta = 0.f; + } + } + CUDNN_CALL(cudnnBatchNormalizationBackward( s->dnn_handle_, mode, &a, - &b, + req[cudnnbatchnorm::kData] == kAddTo ? &b_add : &b, &a, - req[cudnnbatchnorm::kGamma] == kWriteTo ? &b: &b_add, + grad_add_gamma_beta ? &b_add : &b, // gamma and beta io_desc_, x.dptr_, io_desc_, diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index d407d941a03b..2021ba02c144 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -326,7 +326,8 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, else if (diff.IsDefaultData()) diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc()); auto &bwd = GetBNBackward(param, ctx, data, *data_mem, diff, *diff_mem, flags); - auto gradi_mem = const_cast(gradIn).CreateMKLDNNData(data_mem->get_desc()); + auto gradi_mem = CreateMKLDNNMem(const_cast(gradIn), + bwd.pd.diff_src_desc(), req[batchnorm::kData]); if (static_cast(flags) & static_cast(mkldnn::normalization_flags::use_scale_shift)) { const NDArray &gamma = in_data[batchnorm::kGamma]; @@ -347,7 +348,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, } mkldnn_args_map_t net_args; net_args[MKLDNN_ARG_SRC] = *data_mem; - net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem; + net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem.second; net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight(); net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw(); net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem; @@ -372,28 +373,46 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, } net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData()); net_args[MKLDNN_ARG_VARIANCE] = var_mem; - MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); - MKLDNNStream::Get()->Submit(); } else { net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData()); net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData()); - MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); - MKLDNNStream::Get()->Submit(); } + MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args); + CommitOutput(gradIn, gradi_mem); + MKLDNNStream::Get()->Submit(); // copy data from gradw_mem to in_grad[1] and in_grad[2] DType *gw_buf = reinterpret_cast(bwd.GetGradw().get_data_handle()); - DType *w_grad_1 = in_grad[1].data().dptr(); - DType *w_grad_2 = in_grad[2].data().dptr(); + DType *w_grad_1 = in_grad[batchnorm::kGamma].data().dptr(); + DType *w_grad_2 = in_grad[batchnorm::kBeta].data().dptr(); + // the gradient of gamma if (!param.fix_gamma) { - memcpy(w_grad_1, gw_buf, copy_size); - memcpy(w_grad_2, &gw_buf[channels_], copy_size); + if (req[batchnorm::kGamma] != kNullOp) { + if (req[batchnorm::kGamma] != kAddTo) { + memcpy(w_grad_1, gw_buf, copy_size); + } else { + for (int i = 0; i < channels_; i++) { + w_grad_1[i] += gw_buf[i]; + } + } + } } else { for (int i = 0; i < channels_; i++) { (in_grad[1].data().dptr())[i] = 0.0f; } - memcpy(w_grad_2, &gw_buf[channels_], copy_size); + } + + // the gradient of beta + if (req[batchnorm::kBeta] != kNullOp) { + if (req[batchnorm::kBeta] != kAddTo) { + memcpy(w_grad_2, &gw_buf[channels_], copy_size); + } else { + DType *grad_beta = &gw_buf[channels_]; + for (int i = 0; i < channels_; i++) { + w_grad_2[i] += grad_beta[i]; + } + } } } else { LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ..."; diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 52d52893a9a7..15e9bd4e4b9a 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1402,6 +1402,163 @@ def gt_grad_batch_dot_numpy(lhs, rhs, ograd, transpose_a, transpose_b, lhs_req, transpose_b=transpose_b)) +@with_seed() +@use_np +def test_npx_batch_norm(): + momentum = 0.9 + epsilon = 1e-5 + class TestBatchNorm(HybridBlock): + def __init__(self, eps=1e-5, fix_gamma=False, momentum=0.9, **kwargs): + super().__init__() + self.eps = eps + self.fix_gamma = fix_gamma + self.momentum = momentum + self.kwargs = kwargs + def hybrid_forward(self, F, data, bn_gamma, bn_beta, + bn_running_mean, bn_running_var): + op = F.npx.batch_norm + output = op(data, bn_gamma, bn_beta, + bn_running_mean, bn_running_var, + momentum=self.momentum, eps=self.eps, + fix_gamma=self.fix_gamma, **self.kwargs) + return output + + def _test_batchnorm_impl(shape, fix_gamma, cudnn_off, output_mean_var, + axis, + data_grad_req, gamma_grad_req, beta_grad_req): + kwargs = dict(output_mean_var=output_mean_var) + kwargs.update(dict(axis=axis, cudnn_off=cudnn_off)) + op = TestBatchNorm(eps=epsilon, fix_gamma=fix_gamma, momentum=momentum, **kwargs) + nch = shape[axis] + + if not fix_gamma: + bn_gamma = np.random.uniform(size=(nch,)) + bn_gamma.attach_grad(grad_req=gamma_grad_req) + else: + bn_gamma = np.ones((nch,)) + + bn_beta = np.random.uniform(size=(nch,)) + bn_beta.attach_grad(grad_req=beta_grad_req) + + bn_running_mean = np.zeros(nch) + bn_running_var = np.ones(nch) + + running_mean = np.zeros(nch) + running_var = np.ones(nch) + num_iters = 10 + expand_shape = [1] * len(shape) + expand_shape[axis] = shape[axis] + expand_shape = tuple(expand_shape) + data = np.random.uniform(size=shape) + data.attach_grad(grad_req=data_grad_req) + adX, adW, adb = 0, 0, 0 + is_train = data_grad_req != 'null' or \ + (not fix_gamma and gamma_grad_req != 'null') or \ + beta_grad_req != 'null' + for _ in range(num_iters): + if data_grad_req != 'add': + data = np.random.uniform(size=shape) + data.attach_grad(grad_req=data_grad_req) + ograd = np.random.uniform(size=shape) + with mx.autograd.record(): + output = op(data, bn_gamma, bn_beta, + bn_running_mean, bn_running_var) + if output_mean_var: + output, output_mean, output_std = output + if is_train: + output.backward(ograd) + mx.nd.waitall() + + assert 0 <= axis < data.ndim + reduce_axis = tuple(i for i in range(data.ndim) if i != axis) + assert len(reduce_axis) == data.ndim - 1 + data_mean = data.mean( + axis=reduce_axis, keepdims=True) + data_var = ((data - data_mean) ** 2).mean(axis=reduce_axis, + keepdims=True) + + target_output = (data - data_mean) / \ + np.sqrt(data_var + epsilon) * \ + bn_gamma.reshape(expand_shape) + \ + bn_beta.reshape(expand_shape) + + # squeeze data_mean and data_var + data_mean_flat = data_mean.squeeze() + data_var_flat = data_var.squeeze() + + running_mean = running_mean * momentum + \ + data_mean_flat * (1 - momentum) + running_var = running_var * momentum + \ + data_var_flat * (1 - momentum) + + W = bn_gamma.reshape(expand_shape) + dnx = ograd * W + xsm = data - data_mean + nd = 1.0 / np.sqrt(data_var + epsilon) + nx = xsm * nd + m = _np.prod(shape) / shape[axis] + dvar = np.sum(dnx * xsm, axis=reduce_axis, keepdims=True, + ) * (-0.5) * np.power(nd, 3) + dmean = -nd * np.sum(dnx, axis=reduce_axis, keepdims=True) - \ + dvar * xsm.mean(axis=reduce_axis, keepdims=True, + ) * 2.0 + dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m) + dW = np.sum(ograd * nx, axis=reduce_axis) + db = np.sum(ograd, axis=reduce_axis) + adX = dX if data_grad_req != 'add' else adX + dX + adW = dW if gamma_grad_req != 'add' else adW + dW + adb = db if beta_grad_req != 'add' else adb + db + + atol, rtol = 5e-2, 5e-2 + + if output_mean_var: + assert_almost_equal(output_mean.asnumpy(), + data_mean_flat.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(output_std.asnumpy(), + (1.0 / np.sqrt(data_var_flat + + epsilon)).asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(output.asnumpy(), target_output.asnumpy(), + atol=atol, rtol=rtol) + if is_train: + assert_almost_equal(bn_running_mean.asnumpy( + ), running_mean.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal(bn_running_var.asnumpy( + ), running_var.asnumpy(), atol=atol, rtol=rtol) + + if data_grad_req != 'null': + assert_almost_equal(data.grad.asnumpy(), + adX.asnumpy(), atol=atol, rtol=rtol) + if not fix_gamma: + if gamma_grad_req != 'null': + assert_almost_equal( + bn_gamma.grad.asnumpy(), adW.asnumpy(), + atol=atol, rtol=rtol) + else: + assert((bn_gamma.asnumpy() == 1).all()) + if beta_grad_req != 'null': + assert_almost_equal( + bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) + + shapes = [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)] + bools = [False, True] + for shape, fix_gamma, cudnn_off, output_mean_var in itertools.product( + shapes, bools, bools, bools): + grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add'] + for data_grad_req in grad_reqs: + for gamma_grad_req in grad_reqs: + if fix_gamma and gamma_grad_req != 'null': + continue + for beta_grad_req in grad_reqs: + for axis in range(len(shape)): + _test_batchnorm_impl( + shape, fix_gamma, cudnn_off, output_mean_var, + axis, + data_grad_req, + gamma_grad_req, beta_grad_req) + + @with_seed() @use_np def test_npi_boolean_assign(): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index c73b8456240b..a2739a361034 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1827,11 +1827,18 @@ def test_batchnorm(): momentum = 0.9 epsilon = 1e-5 - def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): - print(str((op, shape, axis, cudnn_off))) - + def _test_batchnorm_impl(op_name, shape, fix_gamma, cudnn_off, output_mean_var, + axis, + data_grad_req, gamma_grad_req, beta_grad_req): + + if op_name == 'BatchNorm': + op = mx.nd.BatchNorm + elif op_name == 'SyncBatchNorm': + op = mx.nd.contrib.SyncBatchNorm + else: + raise ValueError('Not supported {}'.format(op_name)) kwargs = dict(output_mean_var=output_mean_var) - if op == mx.nd.contrib.SyncBatchNorm: + if op_name == 'SyncBatchNorm': if axis != 1: return key = str(op) + str(shape) + str(axis) @@ -1842,11 +1849,14 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): kwargs.update(dict(axis=axis, cudnn_off=cudnn_off)) nch = shape[axis] - bn_gamma = mx.nd.random.uniform(shape=(nch,)) - bn_gamma.attach_grad() + if not fix_gamma: + bn_gamma = mx.nd.random.uniform(shape=(nch,)) + bn_gamma.attach_grad(grad_req=gamma_grad_req) + else: + bn_gamma = mx.nd.ones(shape=(nch,)) bn_beta = mx.nd.random.uniform(shape=(nch,)) - bn_beta.attach_grad() + bn_beta.attach_grad(grad_req=beta_grad_req) bn_running_mean = mx.nd.zeros(nch) bn_running_var = mx.nd.ones(nch) @@ -1856,18 +1866,26 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): num_iters = 10 expand_shape = [1] * len(shape) expand_shape[axis] = shape[axis] + data = mx.nd.random.uniform(shape=shape) + data.attach_grad(grad_req=data_grad_req) + adX, adW, adb = 0, 0, 0 + is_train = data_grad_req != 'null' or \ + (not fix_gamma and gamma_grad_req != 'null') or \ + beta_grad_req != 'null' for _ in range(num_iters): - data = mx.nd.random.uniform(shape=shape) - data.attach_grad() + if data_grad_req != 'add': + data = mx.nd.random.uniform(shape=shape) + data.attach_grad(grad_req=data_grad_req) ograd = mx.nd.random.uniform(shape=shape) with mx.autograd.record(): output = op(data, bn_gamma, bn_beta, bn_running_mean, bn_running_var, momentum=momentum, eps=epsilon, - fix_gamma=False, **kwargs) + fix_gamma=fix_gamma, **kwargs) if output_mean_var: output, output_mean, output_std = output - output.backward(ograd) + if is_train: + output.backward(ograd) mx.nd.waitall() data_mean = data.mean( @@ -1904,9 +1922,11 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m) dW = (ograd * nx).sum(axis=axis, exclude=True) db = ograd.sum(axis=axis, exclude=True) + adX = dX if data_grad_req != 'add' else adX + dX + adW = dW if gamma_grad_req != 'add' else adW + dW + adb = db if beta_grad_req != 'add' else adb + db - atol = 1e-2 - rtol = 1e-2 + atol, rtol = 5e-2, 5e-2 if output_mean_var: assert_almost_equal(output_mean.asnumpy(), @@ -1923,26 +1943,43 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): atol=atol, rtol=rtol) assert_almost_equal(output.asnumpy(), target_output.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal(bn_running_mean.asnumpy( - ), running_mean.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal(bn_running_var.asnumpy( - ), running_var.asnumpy(), atol=atol, rtol=rtol) - - assert_almost_equal(data.grad.asnumpy(), - dX.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal( - bn_gamma.grad.asnumpy(), dW.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal( - bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol) - - for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]: - for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 8, 4, 4), (24, 5, 6, 4, 4)]: - for axis in range(len(shape)): - for cudnn_off in [False, True]: - for output_mean_var in [False, True]: - _test_batchnorm_impl(op, shape, axis, - cudnn_off, output_mean_var) - + if is_train: + assert_almost_equal(bn_running_mean.asnumpy( + ), running_mean.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal(bn_running_var.asnumpy( + ), running_var.asnumpy(), atol=atol, rtol=rtol) + + if data_grad_req != 'null': + assert_almost_equal(data.grad.asnumpy(), + adX.asnumpy(), atol=atol, rtol=rtol) + if not fix_gamma: + if gamma_grad_req != 'null': + assert_almost_equal( + bn_gamma.grad.asnumpy(), adW.asnumpy(), + atol=atol, rtol=rtol) + else: + assert((bn_gamma.asnumpy() == 1).all()) + if beta_grad_req != 'null': + assert_almost_equal( + bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) + + op_names = ['BatchNorm', 'SyncBatchNorm'] + shapes = [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)] + bools = [False, True] + for op_name, shape, fix_gamma, cudnn_off, output_mean_var in itertools.product( + op_names, shapes, bools, bools, bools): + grad_reqs = ['write'] if len(shape) != 4 else ['null', 'write', 'add'] + for data_grad_req in grad_reqs: + for gamma_grad_req in grad_reqs: + if fix_gamma and gamma_grad_req != 'null': + continue + for beta_grad_req in grad_reqs: + for axis in range(len(shape)): + _test_batchnorm_impl( + op_name, shape, fix_gamma, cudnn_off, output_mean_var, + axis, + data_grad_req, + gamma_grad_req, beta_grad_req) @with_seed() def test_groupnorm():