This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[mkldnn-v1.0] Add MKL-DNN LRN #16223

Merged: 1 commit, merged Sep 24, 2019
12 changes: 6 additions & 6 deletions src/operator/nn/lrn.cc
@@ -26,7 +26,7 @@

#include "./lrn-inl.h"
#include "../operator_common.h"
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include "./mkldnn/mkldnn_lrn-inl.h"
#include "./mkldnn/mkldnn_base-inl.h"
#endif
@@ -82,7 +82,7 @@ struct LRNGrad {
}
};

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
bool LRNForwardInferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
@@ -169,7 +169,7 @@ number of kernels in the layer.
.set_attr_parser(ParamParser<LRNParam>)
.set_attr<mxnet::FInferShape>("FInferShape", LRNShape)
.set_attr<nnvm::FInferType>("FInferType", LRNType)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<FInferStorageType>("FInferStorageType", LRNForwardInferStorageType)
#endif
.set_attr<nnvm::FListInputNames>("FListInputNames",
@@ -181,7 +181,7 @@ number of kernels in the layer.
return std::vector<std::string>{"output", "tmp_norm"};
})
.set_attr<FCompute>("FCompute<cpu>", LRNCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", LRNComputeExCPU)
#endif
@@ -192,11 +192,11 @@ number of kernels in the layer.
NNVM_REGISTER_OP(_backward_LRN)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LRNParam>)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<FInferStorageType>("FInferStorageType", LRNBackwardInferStorageType)
#endif
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", LRNGradComputeExCPU)
// Native compute requires norm while MKLDNN does not so cannot be compared in debug mode
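For reference (not part of the diff): the across-channel local response normalization this operator computes, written with the parameters visible in the code above (alpha, beta, nsize as n, knorm as k), following the standard LRN definition in the MXNet operator docs:

\[
\mathrm{out}_{i} = \frac{\mathrm{data}_{i}}{\left(k + \frac{\alpha}{n} \sum_{j \in \mathcal{N}(i)} \mathrm{data}_{j}^{2}\right)^{\beta}}
\]

where \(\mathcal{N}(i)\) is the window of n neighboring channels centered on channel i.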
192 changes: 70 additions & 122 deletions src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -25,7 +25,7 @@
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include <utility>
#include <mkldnn.hpp>
#include "../lrn-inl.h"
@@ -34,27 +34,27 @@
namespace mxnet {
namespace op {

inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
inline mkldnn::algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
// TODO(Patric): lrn_within_channel will cause core dump in MKLDNN backward
// Need to confirm with MKLDNN team and fix later
return algorithm::lrn_across_channels;
return mkldnn::algorithm::lrn_across_channels;
}

inline mkldnn::lrn_forward::primitive_desc GetLRNFwdDesc(
const LRNParam &param, const bool is_train, const memory::desc &src_md) {
const LRNParam &param, const bool is_train, const mkldnn::memory::desc &src_md) {
mkldnn::engine &engine = CpuEngine::Get()->get_engine();
const algorithm alg = GetMKLDNNLRNAlgo(param);
const mkldnn::algorithm alg = GetMKLDNNLRNAlgo(param);
const float alpha = param.alpha;
const float beta = param.beta;
const int nsize = param.nsize;
const float k = param.knorm;
auto kind = prop_kind::forward_training;
auto kind = mkldnn::prop_kind::forward_training;
if (is_train) {
kind = prop_kind::forward_training;
kind = mkldnn::prop_kind::forward_training;
} else {
kind = prop_kind::forward_scoring;
kind = mkldnn::prop_kind::forward_scoring;
}
lrn_forward::desc fwd_desc(kind, alg, src_md, nsize, alpha, beta, k);
mkldnn::lrn_forward::desc fwd_desc(kind, alg, src_md, nsize, alpha, beta, k);
return mkldnn::lrn_forward::primitive_desc(fwd_desc, engine);
}

@@ -63,13 +63,13 @@ inline mkldnn::lrn_backward::primitive_desc GetLRNBwdDesc(
const mkldnn::memory::desc &diff_md,
const mkldnn::lrn_forward::primitive_desc &lrnFwd_desc) {
mkldnn::engine &engine = CpuEngine::Get()->get_engine();
const algorithm alg = GetMKLDNNLRNAlgo(param);
const mkldnn::algorithm alg = GetMKLDNNLRNAlgo(param);
const float alpha = param.alpha;
const float beta = param.beta;
const int nsize = param.nsize;
const float k = param.knorm;

lrn_backward::desc lrnBwd_desc(alg, data_in_md,
mkldnn::lrn_backward::desc lrnBwd_desc(alg, data_in_md,
diff_md, nsize, alpha, beta, k);
return mkldnn::lrn_backward::primitive_desc(lrnBwd_desc,
engine, lrnFwd_desc);
@@ -83,33 +83,24 @@ class MKLDNNLRNFwd {
public:
MKLDNNLRNFwd(const LRNParam& param,
bool is_train,
const NDArray &in_data):
is_train(is_train) {
const NDArray &in_data) {
_Init(param, is_train, in_data);
}

~MKLDNNLRNFwd() {}

void SetNewMem(const NDArray &data,
const NDArray &output,
const OpReqType req);

void SetNewMem(const NDArray &in_data,
const mkldnn::memory *out_mem);

void Execute(const NDArray &out_data);
void Execute(const OpContext &ctx,
const NDArray &in_data,
const OpReqType req,
const NDArray &out_data);

mkldnn::lrn_forward &GetFwd();

const mkldnn::memory *GetWs();
mkldnn::lrn_forward::primitive_desc &GetFwdPd();

private:
std::shared_ptr<mkldnn::lrn_forward> fwd;
std::shared_ptr<mkldnn::memory> in_mem;
std::shared_ptr<mkldnn::memory> out_mem;
std::shared_ptr<mkldnn::memory> ws_mem;
mkldnn_output_t output_mem_t;
bool is_train;
mkldnn::lrn_forward::primitive_desc fwd_pd;

private:
void _Init(const LRNParam &param, bool is_train, const NDArray &in_data);
@@ -119,52 +110,37 @@ void MKLDNNLRNFwd::_Init(const LRNParam &param,
bool is_train,
const NDArray &in_data) {
mkldnn::memory::desc in_data_md =
in_data.GetMKLDNNData()->get_primitive_desc().desc();
mkldnn::lrn_forward::primitive_desc fwd_pd =
in_data.GetMKLDNNData()->get_desc();
this->fwd_pd =
GetLRNFwdDesc(param, is_train, in_data_md);

this->in_mem.reset(new mkldnn::memory(in_data.GetMKLDNNData()
->get_primitive_desc()));
this->out_mem.reset(new mkldnn::memory(fwd_pd.dst_primitive_desc()));
if (is_train) {
// If it's training, we have to create a workspace memory. Otherwise,
// MKLDNN will have segmentation fault.
ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc()));
this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*this->in_mem),
*this->ws_mem, *this->out_mem));
} else {
this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*(this->in_mem)),
*(this->out_mem)));
}
}

void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
const NDArray &out_data,
const OpReqType req) {
const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData();
output_mem_t = CreateMKLDNNMem(out_data, this->out_mem->get_primitive_desc(), req);
this->in_mem->set_data_handle(in_data_mem->get_data_handle());
this->out_mem->set_data_handle(output_mem_t.second->get_data_handle());
this->fwd = std::shared_ptr<mkldnn::lrn_forward>(new mkldnn::lrn_forward(this->fwd_pd));
}

void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
const mkldnn::memory *out_mem) {
const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData();
this->in_mem->set_data_handle(in_data_mem->get_data_handle());
this->out_mem->set_data_handle(out_mem->get_data_handle());
}

void MKLDNNLRNFwd::Execute(const NDArray &out_data) {
MKLDNNStream::Get()->RegisterPrim(*(this->fwd));
void MKLDNNLRNFwd::Execute(const OpContext &ctx,
const NDArray &in_data,
const OpReqType req,
const NDArray &out_data) {
auto output_mem_t = CreateMKLDNNMem(out_data, (this->fwd_pd).dst_desc(), req);

mkldnn_args_map_t args = {
{ MKLDNN_ARG_SRC, *in_data.GetMKLDNNData()},
{ MKLDNN_ARG_DST, *output_mem_t.second },
};
std::shared_ptr<mkldnn::memory> workspace;
if (ctx.is_train) {
auto engine = CpuEngine::Get()->get_engine();
workspace = std::make_shared<mkldnn::memory>((this->fwd_pd).workspace_desc(), engine);
args[MKLDNN_ARG_WORKSPACE] = *(workspace);
}
MKLDNNStream::Get()->RegisterPrimArgs(*(this->fwd), args);
CommitOutput(out_data, output_mem_t);
MKLDNNStream::Get()->Submit();
}

mkldnn::lrn_forward &MKLDNNLRNFwd::GetFwd() { return *this->fwd; }
mkldnn::lrn_forward::primitive_desc &MKLDNNLRNFwd::GetFwdPd() { return this->fwd_pd; }

const mkldnn::memory *MKLDNNLRNFwd::GetWs() { return this->ws_mem.get(); }
// End of LRN Class and its functions

static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
@@ -180,10 +156,11 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
OpHash> lrn_fwds;
#endif
auto kind_ =
ctx.is_train ? prop_kind::forward_training : prop_kind::forward_scoring;
ctx.is_train ? mkldnn::prop_kind::forward_training
: mkldnn::prop_kind::forward_scoring;

MKLDNNLRNSignature key(param);
key.AddSign(kind_);
key.AddSign(static_cast<int>(kind_));
key.AddSign(in_data);

auto it = lrn_fwds.find(key);
@@ -201,17 +178,12 @@ void MKLDNNLRNForward(const OpContext &ctx, const LRNParam &param,
if (in_buffer.IsView() && in_buffer.IsMKLDNNData())
in_buffer = in_buffer.Reorder2Default();
MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer);
fwd.SetNewMem(in_buffer, out_data, req);
fwd.Execute(out_data);
fwd.Execute(ctx, in_buffer, req, out_data);
}
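For context, a minimal standalone sketch of the MKL-DNN 1.0 execution model that the forward path above relies on: the primitive is constructed from its primitive_desc alone, and every memory object is supplied at execution time through an argument map keyed by the same MKLDNN_ARG_* constants used in the diff. This is illustrative plain MKL-DNN code, not MXNet code; the input shape and the LRN parameter values are assumptions made for the sketch.

#include <mkldnn.hpp>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream strm(eng);

  // Arbitrary NCHW fp32 input chosen for the sketch.
  mkldnn::memory::desc src_md({1, 8, 4, 4}, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::nchw);
  mkldnn::memory src(src_md, eng);

  // Counterparts of LRNParam::nsize / alpha / beta / knorm (values assumed).
  const int nsize = 5;
  const float alpha = 1e-4f, beta = 0.75f, k = 2.0f;

  mkldnn::lrn_forward::desc fwd_d(mkldnn::prop_kind::forward_training,
                                  mkldnn::algorithm::lrn_across_channels,
                                  src_md, nsize, alpha, beta, k);
  mkldnn::lrn_forward::primitive_desc fwd_pd(fwd_d, eng);

  // Stateless primitive: no memory is bound at construction time.
  mkldnn::lrn_forward lrn(fwd_pd);

  mkldnn::memory dst(fwd_pd.dst_desc(), eng);
  mkldnn::memory ws(fwd_pd.workspace_desc(), eng);  // required for forward_training

  lrn.execute(strm, {{MKLDNN_ARG_SRC, src},
                     {MKLDNN_ARG_DST, dst},
                     {MKLDNN_ARG_WORKSPACE, ws}});
  strm.wait();
  return 0;
}

Because the primitive no longer owns its buffers, one cached primitive per (prop_kind, shape, param) key, as in GetLRNFwd above, can be reused across calls simply by passing different memory objects in the argument map.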

// LRN Backward Class
class MKLDNNLRNBwd {
std::shared_ptr<mkldnn::lrn_backward> bwd;
std::shared_ptr<mkldnn::memory> in_data_mem;
std::shared_ptr<mkldnn::memory> diff_dst_mem;
std::shared_ptr<mkldnn::memory> ws_mem;
std::shared_ptr<mkldnn::memory> diff_src_mem;

public:
const mkldnn::lrn_forward::primitive_desc fwd_pd;
@@ -222,40 +194,26 @@
MKLDNNLRNBwd(const LRNParam &param, const mkldnn::memory::desc in_data_md,
const mkldnn::memory::desc diff_md)
: fwd_pd(GetLRNFwdDesc(param, true, in_data_md)),
bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {}

void SetNewMem(const NDArray &in_data, const NDArray &out_grad,
const mkldnn::memory *ws, const mkldnn::memory *diff_src_mem) {
if (bwd == nullptr) {
this->in_data_mem.reset(
new mkldnn::memory(this->fwd_pd.src_primitive_desc(),
in_data.GetMKLDNNData()->get_data_handle()));
this->diff_dst_mem.reset(
new mkldnn::memory(this->fwd_pd.dst_primitive_desc(),
out_grad.GetMKLDNNData()->get_data_handle()));
this->ws_mem.reset(
new mkldnn::memory(this->fwd_pd.workspace_primitive_desc(),
ws->get_data_handle()));
this->diff_src_mem.reset(
new mkldnn::memory(this->bwd_pd.diff_src_primitive_desc(),
diff_src_mem->get_data_handle()));
this->bwd.reset(new mkldnn::lrn_backward(
this->bwd_pd, mkldnn::primitive::at(*this->in_data_mem),
mkldnn::primitive::at(*this->diff_dst_mem), *this->ws_mem,
*this->diff_src_mem));
} else {
this->in_data_mem->set_data_handle(
in_data.GetMKLDNNData()->get_data_handle());
this->diff_dst_mem->set_data_handle(
out_grad.GetMKLDNNData()->get_data_handle());
this->ws_mem->set_data_handle(ws->get_data_handle());
this->diff_src_mem->set_data_handle(diff_src_mem->get_data_handle());
}
}

void Execute(const NDArray &in_grad, const mkldnn_output_t &diff_src_mem_) {
MKLDNNStream::Get()->RegisterPrim(*(this->bwd));
CommitOutput(in_grad, diff_src_mem_);
bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {
bwd = std::make_shared<mkldnn::lrn_backward>(bwd_pd);
}

const mkldnn::lrn_backward &GetBwd() const { return *bwd; }

void Execute(const NDArray &out_grad,
const NDArray &in_data,
const NDArray &in_grad,
const mkldnn_output_t &diff_src_mem) {
auto engine = CpuEngine::Get()->get_engine();
auto workspace = std::make_shared<mkldnn::memory>((this->fwd_pd).workspace_desc(), engine);
mkldnn_args_map_t args = {
{ MKLDNN_ARG_SRC, *in_data.GetMKLDNNData() },
{ MKLDNN_ARG_DIFF_DST, *out_grad.GetMKLDNNData()},
{ MKLDNN_ARG_WORKSPACE, *workspace },
{ MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second }
};
MKLDNNStream::Get()->RegisterPrimArgs(*(this->bwd), args);
CommitOutput(in_grad, diff_src_mem);
MKLDNNStream::Get()->Submit();
}
}; // End of LRN Class
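A matching sketch for the backward side, again against the plain MKL-DNN 1.0 API rather than the MXNet wrappers; the helper function, its signature, and the parameter values are hypothetical and only illustrate the pattern used by MKLDNNLRNBwd. The backward primitive_desc is created with the forward primitive_desc as a hint, and the workspace produced by a forward_training run is supplied again at execution time.

#include <mkldnn.hpp>

// Hypothetical helper: computes diff_src from the forward pd and the
// memories produced/consumed by the surrounding framework.
mkldnn::memory LRNBackwardSketch(mkldnn::engine &eng, mkldnn::stream &strm,
                                 const mkldnn::lrn_forward::primitive_desc &fwd_pd,
                                 const mkldnn::memory &src,
                                 const mkldnn::memory &diff_dst,
                                 const mkldnn::memory &workspace) {
  // Must match the parameters used to build fwd_pd (values assumed).
  const int nsize = 5;
  const float alpha = 1e-4f, beta = 0.75f, k = 2.0f;

  mkldnn::lrn_backward::desc bwd_d(mkldnn::algorithm::lrn_across_channels,
                                   src.get_desc(), diff_dst.get_desc(),
                                   nsize, alpha, beta, k);
  mkldnn::lrn_backward::primitive_desc bwd_pd(bwd_d, eng, fwd_pd);

  mkldnn::memory diff_src(bwd_pd.diff_src_desc(), eng);
  mkldnn::lrn_backward(bwd_pd).execute(strm,
      {{MKLDNN_ARG_SRC, src},
       {MKLDNN_ARG_DIFF_DST, diff_dst},
       {MKLDNN_ARG_WORKSPACE, workspace},
       {MKLDNN_ARG_DIFF_SRC, diff_src}});
  strm.wait();
  return diff_src;
}

The forward primitive_desc only acts as a hint when building the backward descriptor; the workspace contents themselves have to come from an earlier forward_training execution.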
@@ -277,9 +235,9 @@ static MKLDNNLRNBwd &GetLRNBwd(const LRNParam &param, const NDArray &in_data,
auto it = lrn_bwds.find(key);
if (it == lrn_bwds.end()) {
const mkldnn::memory::desc in_data_md =
in_data.GetMKLDNNData()->get_primitive_desc().desc();
in_data.GetMKLDNNData()->get_desc();
const mkldnn::memory::desc diff_md =
out_grad.GetMKLDNNData()->get_primitive_desc().desc();
out_grad.GetMKLDNNData()->get_desc();
MKLDNNLRNBwd bwd(param, in_data_md, diff_md);
it = AddToCache(&lrn_bwds, key, bwd);
}
@@ -300,23 +258,13 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
in_buffer = in_data.Reorder2Default();
}
MKLDNNLRNBwd &bwd = GetLRNBwd(param, in_buffer, in_grad, out_grad);
// Repeat FW for getting workspace
// TODO(Patric): To keep the function stateless, we can't pass workspace
// from LRN forward to backward. We have to re-compute
// LRN forward to get the workspace.
// Will refine this code later.
MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer);
std::shared_ptr<const mkldnn::memory> dst_temp(
new mkldnn::memory(bwd.fwd_pd.dst_primitive_desc()));
fwd.SetNewMem(in_buffer, dst_temp.get());
MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());

mkldnn_output_t diff_src_mem =
CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_primitive_desc(), req);
bwd.SetNewMem(in_buffer, out_grad, fwd.GetWs(), diff_src_mem.second);
bwd.Execute(in_grad, diff_src_mem);
CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req);

bwd.Execute(out_grad, in_buffer, in_grad, diff_src_mem);
}
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_USE_MKLDNN == 100
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H__