[mkldnn]Mkldnn bn opt backport from master to 1.7x (#18009)
* optimize backward batchnorm

* use memcpy instead of 'for' loops (see the sketch after the change summary below)

* remove unnecessary pointer casts and add const for some variables (illustrated after the diff)

* trigger CI
rongzha1 authored Apr 15, 2020
1 parent 6fa374b commit 50d6d7d
Showing 1 changed file with 27 additions and 23 deletions.
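
The change the commit message describes is mechanical: MKL-DNN's combined scale/shift buffer stores gamma in its first channels_ entries and beta in the next channels_, so both halves can be filled with one memcpy each instead of per-channel loops. A minimal standalone sketch of that pattern follows, assuming only the buffer layout visible in the diff; pack_scale_shift and its parameters are illustrative names, not MXNet symbols.

    // Sketch of the pack-with-memcpy pattern used by this commit (not MXNet code).
    // Layout assumption: weight_buf holds gamma in [0, channels) and beta in
    // [channels, 2 * channels), matching the scale/shift buffer in the diff.
    #include <cassert>
    #include <cstddef>
    #include <cstring>
    #include <vector>

    void pack_scale_shift(float* weight_buf, const float* gamma, const float* beta,
                          std::size_t channels, bool fix_gamma) {
      const std::size_t copy_size = sizeof(float) * channels;
      if (!fix_gamma) {
        std::memcpy(weight_buf, gamma, copy_size);            // scale (gamma)
        std::memcpy(weight_buf + channels, beta, copy_size);  // shift (beta)
      } else {
        for (std::size_t i = 0; i < channels; ++i)
          weight_buf[i] = 1.0f;                               // fixed gamma
        std::memcpy(weight_buf + channels, beta, copy_size);  // shift (beta)
      }
    }

    int main() {
      const std::size_t channels = 4;
      std::vector<float> gamma(channels, 2.0f), beta(channels, 0.5f), weight_buf(2 * channels);
      pack_scale_shift(weight_buf.data(), gamma.data(), beta.data(), channels, false);
      assert(weight_buf[0] == 2.0f && weight_buf[channels] == 0.5f);
      return 0;
    }

Replacing the element-wise loops with two bulk copies also drops the per-element fix_gamma branch, which is the bulk of the backward-pass cleanup in the diff below.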
50 changes: 27 additions & 23 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -180,9 +180,10 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
     CHECK(weight_mem.get_desc().get_size() == channels_ * sizeof(float) * 2);
     float* weight_ptr = gamma.data().dptr<float>();
     float* bias_ptr = beta.data().dptr<float>();
+    const size_t copy_size = sizeof(weight_buf[0]) * channels_;
     if (!param.fix_gamma) {
-      memcpy(weight_buf, weight_ptr, sizeof(weight_buf[0]) * channels_);
-      memcpy(&weight_buf[channels_], bias_ptr, sizeof(weight_buf[0]) * channels_);
+      memcpy(weight_buf, weight_ptr, copy_size);
+      memcpy(&weight_buf[channels_], bias_ptr, copy_size);
     } else if (IsBNWriting(req[batchnorm::kGamma])) {
       for (int i = 0; i < channels_; i++) {
         weight_buf[i] = 1.0f;
@@ -332,17 +333,18 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
     const NDArray &beta = in_data[batchnorm::kBeta];
     DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
     nnvm::dim_t channels_ = data.shape()[1];
-    for (int i = 0; i < channels_; i++) {
-      if (!param.fix_gamma)
-        weight_buf[i] = (gamma.data().dptr<DType>())[i];   // weight
-      else
-        weight_buf[i] = static_cast<DType>(1.0f);
-    }
-
-    for (int i = 0; i < channels_; i++) {
-      weight_buf[channels_ + i] = (beta.data().dptr<DType>())[i];  // bias
-    }
+    DType *weight_ptr = gamma.data().dptr<DType>();
+    DType* bias_ptr = beta.data().dptr<DType>();
+    const size_t copy_size = sizeof(DType) * channels_;
+    if (!param.fix_gamma) {
+      memcpy(weight_buf, weight_ptr, copy_size);
+      memcpy(&weight_buf[channels_], bias_ptr, copy_size);
+    } else {
+      for (int i = 0; i < channels_; i++) {
+        weight_buf[i] = static_cast<DType>(1.0f);
+      }
+      memcpy(&weight_buf[channels_], bias_ptr, copy_size);
+    }
 
     mkldnn_args_map_t net_args;
     net_args[MKLDNN_ARG_SRC] = *data_mem;
     net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem;
@@ -352,10 +354,10 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
 
     // training but no input mean and variance
     if (ctx.is_train && !param.use_global_stats) {
-      DType* moving_mean_ptr = reinterpret_cast<DType *>(moving_mean.data().dptr<DType>());
-      DType* moving_var_ptr = reinterpret_cast<DType *>(moving_var.data().dptr<DType>());
-      DType* out_mean_ptr = reinterpret_cast<DType *>(out_mean.data().dptr<DType>());
-      DType* out_var_ptr = reinterpret_cast<DType *>(out_var.data().dptr<DType>());
+      DType* moving_mean_ptr = moving_mean.data().dptr<DType>();
+      DType* moving_var_ptr = moving_var.data().dptr<DType>();
+      DType* out_mean_ptr = out_mean.data().dptr<DType>();
+      DType* out_var_ptr = out_var.data().dptr<DType>();
       mkldnn::memory var_mem(bwd.pd.variance_desc(), CpuEngine::Get()->get_engine());
       DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());
 
@@ -381,15 +383,17 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
 
       // copy data from gradw_mem to in_grad[1] and in_grad[2]
      DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
-      for (int i = 0; i < channels_; i++) {
-        if (!param.fix_gamma)
-          (in_grad[1].data().dptr<DType>())[i] = gw_buf[i];
-        else
-          (in_grad[1].data().dptr<DType>())[i] = 0.0f;
-      }
-
-      for (int i = 0; i < channels_; i++) {
-        (in_grad[2].data().dptr<DType>())[i] = gw_buf[i + channels_];
-      }
+      DType *w_grad_1 = in_grad[1].data().dptr<DType>();
+      DType *w_grad_2 = in_grad[2].data().dptr<DType>();
+
+      if (!param.fix_gamma) {
+        memcpy(w_grad_1, gw_buf, copy_size);
+        memcpy(w_grad_2, &gw_buf[channels_], copy_size);
+      } else {
+        for (int i = 0; i < channels_; i++) {
+          (in_grad[1].data().dptr<DType>())[i] = 0.0f;
+        }
+        memcpy(w_grad_2, &gw_buf[channels_], copy_size);
+      }
     } else {
       LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";
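The "rm unnecessary pointer cast" bullet corresponds to lines such as the moving_mean_ptr assignments above: dptr<DType>() already returns a DType*, so wrapping the call in reinterpret_cast<DType *> changed nothing. A standalone illustration of why such a cast is a no-op, assuming a hypothetical dptr helper standing in for the real accessor:

    // Not MXNet code: dptr below is a stand-in for a typed-pointer accessor.
    template <typename DType>
    DType* dptr(void* raw) {
      return static_cast<DType*>(raw);  // already yields a typed pointer
    }

    int main() {
      float storage[8] = {};
      void* raw = storage;
      float* a = dptr<float>(raw);                             // no cast needed
      float* b = reinterpret_cast<float*>(dptr<float>(raw));   // identity cast, same value
      return (a == b) ? 0 : 1;
    }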
