From 2c3472fe77aa0249c7a2d716ee89e6946d01ebf1 Mon Sep 17 00:00:00 2001
From: wuxun-zhang
Date: Mon, 1 Jul 2019 09:35:08 +0800
Subject: [PATCH] address comments

---
 src/operator/nn/mkldnn/mkldnn_ops-inl.h     |  6 ++
 src/operator/nn/mkldnn/mkldnn_reshape-inl.h |  6 +-
 src/operator/nn/mkldnn/mkldnn_reshape.cc    | 77 ++++++++++-----------
 src/operator/tensor/matrix_op.cc            |  1 -
 4 files changed, 45 insertions(+), 45 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 74133e58432c..502abff6231b 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -119,6 +119,12 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
                             const OpReqType &req,
                             const NDArray &output);
 
+void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
+                          const OpContext &ctx,
+                          const NDArray &input,
+                          const OpReqType &req,
+                          const NDArray &output);
+
 void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
                           const OpContext &ctx,
                           const NDArray &input,
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
index 73a6108fefd0..1f052714e79a 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
@@ -72,11 +72,7 @@ MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
                                     const OpReqType &req,
                                     const NDArray &input,
                                     const NDArray &output);
-void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
-                          const OpContext &ctx,
-                          const NDArray &input,
-                          const OpReqType &req,
-                          const NDArray &output);
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc
index f226640274d0..3e6e060ac969 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape.cc
+++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc
@@ -26,7 +26,6 @@
 #if MXNET_USE_MKLDNN == 1
 
 #include <mkldnn.hpp>
-#include "../../tensor/matrix_op-inl.h"
 #include "mkldnn_reshape-inl.h"
 
 namespace mxnet {
@@ -35,46 +34,46 @@ namespace op {
 MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
                                    const NDArray &input,
                                    const NDArray &output) {
-    auto engine = CpuEngine::Get()->get_engine();
-
-    // data_
-    auto in_mem = input.GetMKLDNNData();
-    auto in_pd = in_mem->get_primitive_desc();
-    data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);
-
-    // temp_
-    auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
-    auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
-    auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
-    auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
-    auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
-    temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
-
-    // destination
-    out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
-
-    if (req == kWriteInplace) {
-      // If the input has MKL-DNN internal layout, we need reorder it to a temporal buffer with
-      // default layout and copy from the temporal buffer back to output buffer which has the same
-      // address with input buffer.
-      // If the input has default layout, then nothing need to do.
-      if (input.IsMKLDNNData()) {
-        prims_.push_back(mkldnn::reorder(*data_, *temp_));  // reorder to default
-        prims_.push_back(mkldnn::reorder(*temp_, *out_));   // copy back
-        needInvalidateInput = true;
-      }
-    } else if (req == kWriteTo) {
-      if (input.IsMKLDNNData()) {
-        prims_.push_back(mkldnn::reorder(*data_, *temp_));  // reorder to default
-        prims_.push_back(mkldnn::reorder(*temp_, *out_));   // copy to the output buffer
-        needInvalidateInput = false;
-      } else {
-        prims_.push_back(mkldnn::reorder(*data_, *out_));  // copy directly from input to output
-        needInvalidateInput = false;
-      }
+  auto engine = CpuEngine::Get()->get_engine();
+
+  // data_
+  auto in_mem = input.GetMKLDNNData();
+  auto in_pd = in_mem->get_primitive_desc();
+  data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);
+
+  // temp_
+  auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
+  auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
+  auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
+  auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
+  auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
+  temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+
+  // destination
+  out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+
+  if (req == kWriteInplace) {
+    // If the input has an MKL-DNN internal layout, we need to reorder it to a temporary buffer
+    // with the default layout, then copy from that buffer back to the output buffer, which
+    // shares its address with the input buffer.
+    // If the input already has the default layout, nothing needs to be done.
+    if (input.IsMKLDNNData()) {
+      prims_.push_back(mkldnn::reorder(*data_, *temp_));  // reorder to default
+      prims_.push_back(mkldnn::reorder(*temp_, *out_));   // copy back
+      needInvalidateInput = true;
+    }
+  } else if (req == kWriteTo) {
+    if (input.IsMKLDNNData()) {
+      prims_.push_back(mkldnn::reorder(*data_, *temp_));  // reorder to default
+      prims_.push_back(mkldnn::reorder(*temp_, *out_));   // copy to the output buffer
+      needInvalidateInput = false;
     } else {
-      LOG(FATAL) << "not supported req type: " << req;
+      prims_.push_back(mkldnn::reorder(*data_, *out_));  // copy directly from input to output
+      needInvalidateInput = false;
     }
+  } else {
+    LOG(FATAL) << "not supported req type: " << req;
+  }
 }
 
 int MKLDNNReshapeFwd::GetWorkspaceSize() {
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index e96834ffbe84..ffae00ce278b 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -28,7 +28,6 @@
 #include "../nn/mkldnn/mkldnn_ops-inl.h"
 #include "../nn/mkldnn/mkldnn_base-inl.h"
 #include "../nn/mkldnn/mkldnn_slice-inl.h"
-#include "../nn/mkldnn/mkldnn_reshape-inl.h"
 
 namespace mxnet {
 namespace op {
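
Reviewer note (illustrative, not part of the patch): the kWriteInplace branch above caches two
reorder primitives -- internal layout to a default-layout temporary, then temporary to the
destination buffer. The standalone sketch below shows the same two-step pattern against the plain
MKL-DNN 0.x C++ API; the tensor shape, the nChw8c source format, and every variable name are
assumptions made for this example only, not code taken from this patch.

    // Two-step reorder: blocked (internal) layout -> plain layout via a temporary buffer.
    #include <vector>
    #include <mkldnn.hpp>

    int main() {
      using namespace mkldnn;
      engine eng(engine::cpu, 0);

      memory::dims dims = {8, 16, 7, 7};  // assumed NCHW logical shape
      auto blocked = memory::desc(dims, memory::data_type::f32, memory::format::nChw8c);
      auto plain   = memory::desc(dims, memory::data_type::f32, memory::format::nchw);

      std::vector<float> src_buf(8 * 16 * 7 * 7), tmp_buf(src_buf.size()), dst_buf(src_buf.size());
      memory src(memory::primitive_desc(blocked, eng), src_buf.data());  // MKL-DNN internal layout
      memory tmp(memory::primitive_desc(plain, eng), tmp_buf.data());    // default-layout temporary
      memory dst(memory::primitive_desc(plain, eng), dst_buf.data());    // destination buffer

      std::vector<primitive> prims;
      prims.push_back(reorder(src, tmp));  // reorder to the default layout
      prims.push_back(reorder(tmp, dst));  // plain copy into the destination
      stream(stream::kind::eager).submit(prims).wait();
      return 0;
    }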