[mkldnn-v1.0] Add MKL-DNN Pooling (#16272)
* add mkldnn pooling

* add workaround for mkldnn v1.0 pooling fwd && bwd workspace mismatch

* code clean

* fix lint error

* trigger CI

* trigger CI

* add extra work_space check and fix some typo

* trigger CI
rongzha1 authored and TaoLv committed Oct 10, 2019
1 parent 3706ece commit 458bb73
Showing 3 changed files with 100 additions and 145 deletions.
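The substance of the patch is a port from the MKL-DNN v0.x execution model, where a primitive was constructed over pre-bound memory objects and re-pointed with set_data_handle(), to the v1.0 model, where a primitive is built from its primitive_desc alone and the memory arrives at execution time in an argument map. A minimal standalone sketch of the v1.0 pattern the diff below adopts (plain MKL-DNN v1.0 API, no MXNet types; shapes and names are illustrative):

```cpp
// Standalone sketch of the MKL-DNN v1.0 execution model this commit adopts.
#include <mkldnn.hpp>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream s(eng);

  // NCHW source and destination descriptors for 2x2 max pooling, stride 2.
  mkldnn::memory::desc src_md({1, 16, 8, 8}, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::nchw);
  mkldnn::memory::desc dst_md({1, 16, 4, 4}, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::any);

  mkldnn::pooling_forward::desc fwd_d(
      mkldnn::prop_kind::forward_training, mkldnn::algorithm::pooling_max,
      src_md, dst_md, {2, 2}, {2, 2}, {0, 0}, {0, 0});
  mkldnn::pooling_forward::primitive_desc fwd_pd(fwd_d, eng);

  // v1.0: the primitive is built from the primitive_desc alone ...
  mkldnn::pooling_forward fwd(fwd_pd);

  mkldnn::memory src(fwd_pd.src_desc(), eng);
  mkldnn::memory dst(fwd_pd.dst_desc(), eng);
  mkldnn::memory ws(fwd_pd.workspace_desc(), eng);  // max pooling in training needs a workspace

  // ... and memory is supplied at execution time via an argument map.
  fwd.execute(s, {{MKLDNN_ARG_SRC, src},
                  {MKLDNN_ARG_DST, dst},
                  {MKLDNN_ARG_WORKSPACE, ws}});
  s.wait();
  return 0;
}
```

This is what lets the patch delete the cached data_/out_/workspace_ members and the whole SetNewMem() step: per-call buffers travel in the argument map rather than living inside the primitive.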
40 changes: 14 additions & 26 deletions src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -24,7 +24,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include <utility>
 #include <mkldnn.hpp>
@@ -43,60 +43,48 @@ class MKLDNNPoolingFwd {
                    const int padding_t, const int padding_b,
                    const int padding_l, const int padding_r,
                    const mkldnn::algorithm alg_kind,
-                   const bool with_workspace, const bool is_train) :
-                   is_train_(is_train),
+                   const bool with_workspace, const bool is_train):
                    with_workspace_(with_workspace),
                    alg_kind_(alg_kind),
-                   fwd_(nullptr), data_(nullptr), out_(nullptr), workspace_(nullptr) {
+                   fwd_(nullptr) {
     Init(input, output,
          kernel_h, kernel_w, stride_h, stride_w,
-         padding_t, padding_b, padding_l, padding_r);
+         padding_t, padding_b, padding_l, padding_r,
+         is_train, alg_kind);
   }
 
   ~MKLDNNPoolingFwd() {}
-  void SetNewMem(const NDArray& in_data,
-                 const NDArray& out_data,
-                 const OpReqType& req,
-                 const mxnet::NDArray *workspace = nullptr);
-  void Execute(const NDArray& out_data);
+  void Execute(const NDArray &in_data,
+               const OpReqType req,
+               const NDArray& out_data,
+               const NDArray *workspace);
 
  private:
-  bool is_train_;
   bool with_workspace_;
   mkldnn::algorithm alg_kind_;
 
   std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd_;
   std::shared_ptr<mkldnn::pooling_forward> fwd_;
-  std::shared_ptr<mkldnn::memory> data_;
-  std::shared_ptr<mkldnn::memory> out_;
-  std::shared_ptr<mkldnn::memory> workspace_;
-  mkldnn_output_t output_mem_t_;
 
  private:
   void Init(const mxnet::NDArray &input,
             const mxnet::NDArray &output,
             const int kernel_h, const int kernel_w,
             const int stride_h, const int stride_w,
             const int padding_t, const int padding_b,
-            const int padding_l, const int padding_r);
+            const int padding_l, const int padding_r,
+            const bool is_train, const mkldnn::algorithm alg_kind);
 };
 
 class MKLDNNPoolingBwd {
   std::shared_ptr<const mkldnn::pooling_backward> bwd;
-  std::shared_ptr<mkldnn::memory> diff_dst;
-  std::shared_ptr<mkldnn::memory> diff_src;
-  std::shared_ptr<mkldnn::memory> ws;
  bool with_workspace;
 
  public:
   const mkldnn::pooling_backward::primitive_desc pd;
 
-  MKLDNNPoolingBwd(const pooling_backward::primitive_desc &pdesc,
+  MKLDNNPoolingBwd(const mkldnn::pooling_backward::primitive_desc &pdesc,
                    bool with_ws);
 
   ~MKLDNNPoolingBwd() {}
-  void SetNewMem(const mxnet::NDArray *workspace,
-                 const mxnet::NDArray &out_grad,
-                 const mkldnn::memory *diff_src_mem);
   const mkldnn::pooling_backward &GetBwd();
   const mkldnn::pooling_backward::primitive_desc &GetPd();
 };
@@ -141,5 +129,5 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
                                 const NDArray &output);
 }  // namespace op
 }  // namespace mxnet
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
 #endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_
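With the class slimmed down this way, a forward call is two steps: look up the cached primitive wrapper, then hand it the current call's arrays. A sketch of the call site, matching the updated MKLDNNPoolingCompute() in the next file (param, ctx, req, and the NDArrays come from the surrounding MXNet operator):

```cpp
// Hedged sketch of the new calling convention; see MKLDNNPoolingCompute() below.
auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data);
// No SetNewMem() step anymore: Execute() binds this call's buffers itself.
fwd.Execute(in_data, req, out_data, workspace);
```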
175 changes: 71 additions & 104 deletions src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -23,7 +23,7 @@
  * \author Tao Lv
 */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include "./mkldnn_pooling-inl.h"
 
@@ -34,18 +34,17 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &output,
                             const int kernel_h, const int kernel_w,
                             const int stride_h, const int stride_w,
                             const int padding_t, const int padding_b,
-                            const int padding_l, const int padding_r) {
-  // mkldnn::memory::desc
-  auto src_md = input.GetMKLDNNData()->get_primitive_desc().desc();
+                            const int padding_l, const int padding_r,
+                            const bool is_train, const mkldnn::algorithm alg_kind) {
+  auto src_md = input.GetMKLDNNData()->get_desc();
   mkldnn::memory::dims dims = {src_md.data.dims[0],
                                src_md.data.dims[1],
                                static_cast<int>(output.shape()[2]),
                                static_cast<int>(output.shape()[3])};
   auto dst_md = mkldnn::memory::desc({dims},
                                      static_cast<mkldnn::memory::data_type>(src_md.data.data_type),
-                                     static_cast<mkldnn::memory::format>(src_md.data.format));
+                                     mkldnn::memory::format_tag::any);
   const mkldnn::engine engine = CpuEngine::Get()->get_engine();
-  const mkldnn::algorithm alg_kind = this->alg_kind_;
   if (alg_kind != mkldnn::algorithm::pooling_max &&
       alg_kind != mkldnn::algorithm::pooling_avg &&
       alg_kind != mkldnn::algorithm::pooling_avg_include_padding &&
@@ -54,10 +53,10 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &output,
   }
 
   mkldnn::prop_kind prop = mkldnn::prop_kind::forward_scoring;
-  if (this->is_train_ && alg_kind != mkldnn::algorithm::pooling_avg) {
+  if (is_train && alg_kind != mkldnn::algorithm::pooling_avg) {
     prop = mkldnn::prop_kind::forward_training;
   }
-  if (this->is_train_ && prop == mkldnn::prop_kind::forward_scoring) {
+  if (is_train && prop == mkldnn::prop_kind::forward_scoring) {
     LOG(INFO) << "MKLDNN Pooling: training with prop_kind is forward_scoring";
   }
 
@@ -67,49 +66,38 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &output,
   const mkldnn::memory::dims kernel = {kernel_h, kernel_w };
   // mkldnn::pooling_forward::desc
   const auto fwd_desc = mkldnn::pooling_forward::desc(prop, alg_kind, src_md, dst_md,
-                                                      strides, kernel, pad_l, pad_r,
-                                                      mkldnn::padding_kind::zero);
+                                                      strides, kernel, pad_l, pad_r);
   this->fwd_pd_.reset(new mkldnn::pooling_forward::primitive_desc(fwd_desc, engine));
-  this->data_.reset(new mkldnn::memory(input.GetMKLDNNData()->get_primitive_desc()));
-  this->out_.reset(new mkldnn::memory(this->fwd_pd_->dst_primitive_desc()));
-  if (this->with_workspace_) {
-    this->workspace_.reset(new mkldnn::memory(this->fwd_pd_->workspace_primitive_desc()));
-    this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_),
-                                                 mkldnn::primitive::at(*(this->data_)),
-                                                 *(this->out_),
-                                                 *(this->workspace_)));
-  } else {
-    this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_),
-                                                 mkldnn::primitive::at(*(this->data_)),
-                                                 *(this->out_)));
-  }
+  this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_)));
 
   return;
 }
 
-void MKLDNNPoolingFwd::SetNewMem(const NDArray& in_data,
-                                 const NDArray& out_data,
-                                 const OpReqType& req,
-                                 const mxnet::NDArray *workspace) {
-  auto input_mem = in_data.GetMKLDNNData();
-  output_mem_t_ = CreateMKLDNNMem(out_data, fwd_pd_->dst_primitive_desc(), req);
-  // mkldnn::memory
-  this->data_->set_data_handle(input_mem->get_data_handle());
-  this->out_->set_data_handle(output_mem_t_.second->get_data_handle());
-  if (this->with_workspace_ && workspace == nullptr) {
-    LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input";
-  }
+void MKLDNNPoolingFwd::Execute(const NDArray &in_data,
+                               const OpReqType req,
+                               const NDArray& out_data,
+                               const NDArray *workspace) {
+  NDArray in_buffer = in_data;
+  if (in_data.IsView() && in_data.IsMKLDNNData())
+    in_buffer = in_data.Reorder2Default();
+
+  auto input_mem = in_buffer.GetMKLDNNData();
+  auto output_mem_t_ = CreateMKLDNNMem(out_data, this->fwd_pd_->dst_desc(), req);
+
+  mkldnn_args_map_t args = {
+    {MKLDNN_ARG_SRC, *input_mem },
+    {MKLDNN_ARG_DST, *(output_mem_t_.second) },
+  };
 
   if (this->with_workspace_) {
-    // mkldnn::memory
-    auto ws_mem = workspace->GetMKLDNNData();
-    this->workspace_->set_data_handle(ws_mem->get_data_handle());
+    auto engine = CpuEngine::Get()->get_engine();
+    auto ws = std::make_shared<mkldnn::memory>((*(this->fwd_pd_)).workspace_desc(),
+                                               engine, workspace->GetMKLDNNData()->get_data_handle());
+    args[MKLDNN_ARG_WORKSPACE] = *ws;
   }
-}
-
-void MKLDNNPoolingFwd::Execute(const NDArray& out_data) {
+
   if (this->fwd_) {
-    MKLDNNStream::Get()->RegisterPrim(*(this->fwd_));
-    CommitOutput(out_data, this->output_mem_t_);
+    MKLDNNStream::Get()->RegisterPrimArgs(*(this->fwd_), args);
+    CommitOutput(out_data, output_mem_t_);
     MKLDNNStream::Get()->Submit();
   } else {
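The output side of the new Execute() goes through MXNet's CreateMKLDNNMem()/CommitOutput() pair instead of writing into out_data directly, which is what lets a single code path honor every OpReqType (write vs. accumulate). A condensed sketch of that contract, assuming the helpers behave as this file uses them (fwd, args, and fwd_pd stand in for the locals above):

```cpp
// CreateMKLDNNMem() returns a {kind, memory*} pair: the memory either aliases
// out_data directly or is a temporary (e.g. for kAddTo or a layout mismatch).
mkldnn_output_t out_mem = CreateMKLDNNMem(out_data, fwd_pd->dst_desc(), req);
args[MKLDNN_ARG_DST] = *out_mem.second;
MKLDNNStream::Get()->RegisterPrimArgs(*fwd, args);
CommitOutput(out_data, out_mem);  // copies or accumulates a temporary back per req
MKLDNNStream::Get()->Submit();
```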
LOG(FATAL) << "MKLDNN Pooling: forward primitive is nullptr";
Expand Down Expand Up @@ -143,8 +131,8 @@ static inline int GetPaddingSizeFull(int x, int padl, int padr, int k, int s) {
}

mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
const PoolingParam &param, const bool is_train, const memory::desc &data_md,
const memory::desc &out_md) {
const PoolingParam &param, const bool is_train, const mkldnn::memory::desc &data_md,
const mkldnn::memory::desc &out_md) {
CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented";
int kernel_h_, kernel_w_;
if (param.global_pool) {
@@ -183,19 +171,18 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
 
   const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
   mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring;
-  if (is_train && alg != algorithm::pooling_avg) {
+  if (is_train && alg != mkldnn::algorithm::pooling_avg) {
     kind = mkldnn::prop_kind::forward_training;
   }
 
-  const pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md,
+  const mkldnn::pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md,
                                               {static_cast<int>(stride_h_),
                                                static_cast<int>(stride_w_)},
                                               {kernel_h_, kernel_w_},
                                               {static_cast<int>(pad_t_),
                                                static_cast<int>(pad_l_)},
                                               {static_cast<int>(pad_b_),
-                                               static_cast<int>(pad_r_)},
-                                              mkldnn::padding_kind::zero);
+                                               static_cast<int>(pad_r_)});
   return mkldnn::pooling_forward::primitive_desc(poolingFwd_desc, engine);
 }
 
@@ -223,7 +210,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
   auto it = pooling_fwds.find(key);
   if (it == pooling_fwds.end()) {
     CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented";
-    auto data_md = data.GetMKLDNNData()->get_primitive_desc().desc();
+    auto data_md = data.GetMKLDNNData()->get_desc();
     int kernel_h_, kernel_w_;
     if (param.global_pool) {
       kernel_h_ = data_md.data.dims[2];
@@ -270,42 +257,14 @@ void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam &param,
                           const NDArray &in_data, const OpReqType req,
                           const NDArray &out_data, const NDArray *workspace) {
   auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data);
-  fwd.SetNewMem(in_data, out_data, req, workspace);
-  fwd.Execute(out_data);
+  fwd.Execute(in_data, req, out_data, workspace);
 }
 
 MKLDNNPoolingBwd::MKLDNNPoolingBwd(
-    const pooling_backward::primitive_desc &pdesc, bool with_ws)
-    : with_workspace(with_ws), pd(pdesc) {}
-
-void MKLDNNPoolingBwd::SetNewMem(const mxnet::NDArray *workspace,
-                                 const mxnet::NDArray &out_grad,
-                                 const mkldnn::memory *diff_src_mem) {
-  if (bwd == nullptr) {
-    diff_dst.reset(
-        new mkldnn::memory(out_grad.GetMKLDNNData()->get_primitive_desc(),
-                           out_grad.GetMKLDNNData()->get_data_handle()));
-    diff_src.reset(new mkldnn::memory(pd.diff_src_primitive_desc(),
-                                      diff_src_mem->get_data_handle()));
-    if (with_workspace) {
-      CHECK(workspace != nullptr);
-      ws.reset(
-          new mkldnn::memory(workspace->GetMKLDNNData()->get_primitive_desc(),
-                             workspace->GetMKLDNNData()->get_data_handle()));
-      bwd.reset(
-          new pooling_backward(pd, *diff_dst, primitive::at(*ws), *diff_src));
-    } else {
-      bwd.reset(new pooling_backward(pd, *diff_dst, *diff_src));
-    }
-  } else {
-    diff_dst->set_data_handle(out_grad.GetMKLDNNData()->get_data_handle());
-    diff_src->set_data_handle(diff_src_mem->get_data_handle());
-    if (with_workspace) {
-      CHECK(workspace != nullptr);
-      ws->set_data_handle(workspace->GetMKLDNNData()->get_data_handle());
-    }
-  }
-}
+    const mkldnn::pooling_backward::primitive_desc &pdesc, bool with_ws)
+    : with_workspace(with_ws), pd(pdesc) {
+  bwd = std::make_shared<mkldnn::pooling_backward>(pd);
+}
 
 const mkldnn::pooling_backward &MKLDNNPoolingBwd::GetBwd() {
   return *this->bwd;
@@ -333,27 +292,29 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
 
   auto it = pooling_bwds.find(key);
   if (it == pooling_bwds.end()) {
-    auto diff_dst_mem = out_grad.GetMKLDNNData();
+    NDArray diff_dst_buff = out_grad;
+    if (in_data.IsMKLDNNData() == false && diff_dst_buff.IsMKLDNNData() == true) {
+      diff_dst_buff = out_grad.Reorder2Default();
+    }
+    auto diff_dst_mem = diff_dst_buff.GetMKLDNNData();
     auto input_mem = in_data.GetMKLDNNData();
-    mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
-    const mkldnn::memory::desc data_md = data_mpd.desc();
-    const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
+    const mkldnn::memory::desc data_md = input_mem->get_desc();
+    const mkldnn::memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
                                static_cast<int>(out_grad.shape()[2]),
                                static_cast<int>(out_grad.shape()[3])};
-    const memory::desc out_md(
-        {dims}, static_cast<memory::data_type>(data_md.data.data_type),
-        static_cast<memory::format>(data_md.data.format));
+    const mkldnn::memory::desc out_md(
+        {dims}, static_cast<mkldnn::memory::data_type>(data_md.data.data_type),
+        mkldnn::memory::format_tag::any);
     auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md);
 
     const mkldnn::memory::desc diff_md =
-        diff_dst_mem->get_primitive_desc().desc();
-    const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
+        diff_dst_mem->get_desc();
+    const mkldnn::memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
                                static_cast<int>(in_grad.shape()[2]),
                                static_cast<int>(in_grad.shape()[3])};
-    const memory::desc diff_in_md(
-        {dims1}, static_cast<memory::data_type>(diff_md.data.data_type),
-        static_cast<memory::format>(diff_md.data.format));
-    const mkldnn::engine cpu_engine = data_mpd.get_engine();
+    const mkldnn::memory::desc diff_in_md(
+        {dims1}, static_cast<mkldnn::memory::data_type>(diff_md.data.data_type),
+        mkldnn::memory::format_tag::any);
+    const mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();
     const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
 
     int kernel_h_, kernel_w_;
@@ -379,11 +340,10 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
       stride_h_ = stride_w_ = 1;
     }
 
-    const pooling_backward::desc desc(
+    const mkldnn::pooling_backward::desc desc(
         alg, diff_in_md, diff_md, {stride_h_, stride_w_},
-        {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_},
-        mkldnn::padding_kind::zero);
-    const auto pdesc = pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
+        {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_});
+    const auto pdesc = mkldnn::pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
     MKLDNNPoolingBwd bwd(pdesc, with_workspace);
     it = AddToCache(&pooling_bwds, key, bwd);
   }
@@ -401,14 +361,21 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
 
   auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad);
   auto diff_src_mem =
-      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
+      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
+
+  mkldnn_args_map_t args = {
+    {MKLDNN_ARG_DIFF_DST, *(out_grad.GetMKLDNNData())},
+    {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second },
+  };
+  if (MKLDNNRequireWorkspace(param) && workspace != nullptr) {
+    args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData());
+  }
 
-  bwd.SetNewMem(workspace, out_grad, diff_src_mem.second);
-  MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
+  MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), args);
   CommitOutput(in_grad, diff_src_mem);
   MKLDNNStream::Get()->Submit();
 }
 
 }  // namespace op
 }  // namespace mxnet
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
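On the backward side, GetPoolingBwd() builds a forward primitive_desc first and passes it as the hint argument of the backward primitive_desc; for pooling_max that hint is also what keeps the forward and backward workspace descriptors in agreement, the "fwd && bwd workspace mismatch" named in the commit message. A self-contained sketch of the pairing in the plain MKL-DNN v1.0 API (no MXNet types; sizes illustrative, buffers left uninitialized since only the wiring is shown):

```cpp
// Sketch of the fwd/bwd workspace pairing the backward path relies on.
#include <mkldnn.hpp>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream s(eng);
  using tag = mkldnn::memory::format_tag;
  using dt = mkldnn::memory::data_type;

  mkldnn::memory::desc src_md({1, 8, 4, 4}, dt::f32, tag::nchw);
  mkldnn::memory::desc dst_md({1, 8, 2, 2}, dt::f32, tag::nchw);

  // Forward pd first: backward pooling takes it as a hint, and its
  // workspace_desc() defines the workspace layout both directions share.
  mkldnn::pooling_forward::desc fwd_d(mkldnn::prop_kind::forward_training,
                                      mkldnn::algorithm::pooling_max,
                                      src_md, dst_md, {2, 2}, {2, 2}, {0, 0}, {0, 0});
  mkldnn::pooling_forward::primitive_desc fwd_pd(fwd_d, eng);

  mkldnn::pooling_backward::desc bwd_d(mkldnn::algorithm::pooling_max,
                                       src_md, dst_md, {2, 2}, {2, 2}, {0, 0}, {0, 0});
  mkldnn::pooling_backward::primitive_desc bwd_pd(bwd_d, eng, fwd_pd);  // fwd_pd is the hint

  mkldnn::memory src(fwd_pd.src_desc(), eng), dst(fwd_pd.dst_desc(), eng);
  mkldnn::memory ws(fwd_pd.workspace_desc(), eng);
  mkldnn::memory diff_dst(bwd_pd.diff_dst_desc(), eng), diff_src(bwd_pd.diff_src_desc(), eng);

  mkldnn::pooling_forward(fwd_pd).execute(
      s, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DST, dst}, {MKLDNN_ARG_WORKSPACE, ws}});
  // The same ws memory produced by the forward pass is consumed by backward.
  mkldnn::pooling_backward(bwd_pd).execute(
      s, {{MKLDNN_ARG_DIFF_DST, diff_dst}, {MKLDNN_ARG_DIFF_SRC, diff_src},
          {MKLDNN_ARG_WORKSPACE, ws}});
  s.wait();
  return 0;
}
```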
[diff of the third changed file not captured on this page]
