Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
wuxun-zhang committed Jul 1, 2019
1 parent 436ffa4 commit 2c3472f
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 45 deletions.
6 changes: 6 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
const OpReqType &req,
const NDArray &output);

// Forward computation of Reshape with MKL-DNN enabled: copies/reorders `input`
// into `output` according to `req` (see MKLDNNReshapeFwd in mkldnn_reshape.cc).
void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &input,
const OpReqType &req,
const NDArray &output);

void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &input,
Expand Down
6 changes: 1 addition & 5 deletions src/operator/nn/mkldnn/mkldnn_reshape-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,7 @@ MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
const OpReqType &req,
const NDArray &input,
const NDArray &output);
void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &input,
const OpReqType &req,
const NDArray &output);

} // namespace op
} // namespace mxnet

Expand Down
77 changes: 38 additions & 39 deletions src/operator/nn/mkldnn/mkldnn_reshape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#if MXNET_USE_MKLDNN == 1

#include <mkldnn.hpp>
#include "../../tensor/matrix_op-inl.h"
#include "mkldnn_reshape-inl.h"

namespace mxnet {
Expand All @@ -35,46 +34,46 @@ namespace op {
// Builds the list of MKL-DNN primitives needed to realize a reshape for the
// given write request type.
//
// The memory objects are created with a nullptr data handle; the real buffers
// are bound later, at execution time.
//   data_ : view of the input in its current (possibly MKL-DNN-internal) layout
//   temp_ : staging buffer with the same dims/type but the default (plain) format
//   out_  : destination buffer, also in the default format
//
// req == kWriteInplace : output aliases input; if the input carries an
//                        MKL-DNN-internal layout it must be bounced through
//                        temp_ (reorder to default, then copy back).
// req == kWriteTo      : reorder through temp_ when the input is in an
//                        MKL-DNN-internal layout, otherwise copy directly.
// any other req        : unsupported — fatal error.
MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
                                   const NDArray &input,
                                   const NDArray &output) {
  auto engine = CpuEngine::Get()->get_engine();

  // data_
  auto in_mem = input.GetMKLDNNData();
  auto in_pd = in_mem->get_primitive_desc();
  data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);

  // temp_: same dims and data type as the input, but in the default format.
  auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
  auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
  auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
  auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
  auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
  temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);

  // destination (shares the default-format primitive descriptor with temp_)
  out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);

  if (req == kWriteInplace) {
    // If the input has MKL-DNN internal layout, we need to reorder it to a
    // temporary buffer with default layout and copy from the temporary buffer
    // back to the output buffer, which has the same address as the input
    // buffer. If the input already has the default layout, nothing needs
    // to be done.
    if (input.IsMKLDNNData()) {
      prims_.push_back(mkldnn::reorder(*data_, *temp_));  // reorder to default
      prims_.push_back(mkldnn::reorder(*temp_, *out_));   // copy back
      needInvalidateInput = true;
    }
  } else if (req == kWriteTo) {
    if (input.IsMKLDNNData()) {
      prims_.push_back(mkldnn::reorder(*data_, *temp_));  // reorder to default
      prims_.push_back(mkldnn::reorder(*temp_, *out_));   // copy to the output buffer
      needInvalidateInput = false;
    } else {
      prims_.push_back(mkldnn::reorder(*data_, *out_));   // copy directly from input to output
      needInvalidateInput = false;
    }
  } else {
    LOG(FATAL) << "not supported req type: " << req;
  }
}

int MKLDNNReshapeFwd::GetWorkspaceSize() {
Expand Down
1 change: 0 additions & 1 deletion src/operator/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "../nn/mkldnn/mkldnn_ops-inl.h"
#include "../nn/mkldnn/mkldnn_base-inl.h"
#include "../nn/mkldnn/mkldnn_slice-inl.h"
#include "../nn/mkldnn/mkldnn_reshape-inl.h"

namespace mxnet {
namespace op {
Expand Down

0 comments on commit 2c3472f

Please sign in to comment.