[MXNET-483] C++ tests for mkldnn convolution/deconvolution operator (#11778)

* add conv test

* remove pool type

* GetTestInputArrays can receive scale input

* create kernels and bias arrays

* fix shape of kernel / bias

* fix format

* bias is 1dim

* fix output shape

* fix backwards input

* fix var name to backwards_ex_outputs

* filter inputs with diff memory dims

* fix lint

* remove extra spaces

* fix lint

* add deconv test

* add calc deconv size

* remove bias from deconv input

* create deconv kernel

* fix num outputs for deconv

* fix lint

* use random inputs for deconv

* can init random mkldnn array

* round scale

* update for loops with size_t instead of ints

* remove comment

* fix merge

* use bounded random inputs

* fix merge issue

* conv op uses filter

* reorder if view

* reorder backwards

* rename to out_grad

* fix lint

* filter pooling types

* reorder

* add bias

* fix typo

* fix ref

* filter arrays

* reorder deconv inputs

* reorder deconv forward inputs

* fix missing var

* remove unused var

* reorder inputs for deconv forward

* remove const

* avoid reorder

* set bias

* fix typo

* set bias with string

* set bias with string

* remove use bias

* add bias

* add bias shape

* cannot use reshaped non n*** format

* add spatial filter

* fix conv

* fix missing conv

* fix input

* fix merge

* add missing header

* add inline

* fix input

* add spatial filter in test

* fix get test input params

* fix get test input params

* fix num inputs

* fix input num of backwards

* fix bias

* add missing bias

* fix output num for backwards

* fix num outputs for deconv

* fix test input

* remove comments

* use deconv param

* use template

* filter out incompatible widths

* fix lint

* update comment

* reorder weights in deconv

* remove const from set_data handle

* fix lint

* retrigger

* remove data format

* retrigger

* retrigger

* use check_eq

* retrigger

* refactor deconv if/else block
azai91 authored and anirudh2290 committed Nov 20, 2018
1 parent 5a83b6b commit 91c536d
Showing 5 changed files with 378 additions and 52 deletions.
49 changes: 35 additions & 14 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -353,9 +353,18 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]);
NDArray weight = in_data[conv::kWeight];

auto data = in_data[conv::kData];
if (data.IsView() && data.IsMKLDNNData())
data = data.Reorder2Default();

auto weight = in_data[conv::kWeight];
if (weight.IsView() && weight.IsMKLDNNData())
weight = weight.Reorder2Default();

bool no_bias = param.conv_param.no_bias && !param.mkldnn_param.with_bn;
auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(

auto data_mem = data.GetMKLDNNDataReorder(
fwd->fwd_pd.src_primitive_desc());
const mkldnn::memory *weight_mem;
if (ctx.is_train) {
@@ -577,19 +586,32 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
MKLDNNConvFullParam full_param;
full_param.conv_param = nnvm::get<ConvolutionParam>(attrs.parsed);
full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());

auto data = inputs[conv::kData + 1];
if (data.IsView() && data.IsMKLDNNData())
data = data.Reorder2Default();

auto weight = inputs[conv::kWeight + 1];
if (weight.IsView() && weight.IsMKLDNNData())
weight = weight.Reorder2Default();

const NDArray* bias = full_param.conv_param.no_bias ? nullptr : &inputs[conv::kBias + 1];

auto out_grad = inputs[conv::kOut];
if (out_grad.IsView() && out_grad.IsMKLDNNData())
out_grad = out_grad.Reorder2Default();

mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl(
full_param, ctx.is_train, inputs[conv::kData + 1], inputs[conv::kWeight + 1],
full_param.conv_param.no_bias ? nullptr : &inputs[conv::kBias + 1],
inputs[conv::kOut]);
full_param, ctx.is_train, data, weight, bias, out_grad);
const ConvolutionParam &param = full_param.conv_param;

CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace";
MKLDNNConvBackward &convBwd = GetConvBwd(attrs, inputs[conv::kData + 1],
inputs[conv::kWeight + 1], nullptr, inputs[conv::kOut], fwd_pd);
auto out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
MKLDNNConvBackward &convBwd = GetConvBwd(attrs, data,
weight, bias, out_grad, fwd_pd);
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
convBwd.bwdData_pd.diff_dst_primitive_desc());
if (req[conv::kData]) {
auto weight_mem = GetWeights(inputs[conv::kWeight + 1],
auto weight_mem = GetWeights(weight,
convBwd.bwdData_pd.weights_primitive_desc(), param.num_group);
auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData],
convBwd.bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
@@ -598,14 +620,13 @@
CommitOutput(in_grad[conv::kData], in_grad_mem);
}
if (req[conv::kWeight]) {
MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, inputs[conv::kData + 1],
inputs[conv::kWeight + 1], param.no_bias ? nullptr : &inputs[conv::kBias + 1],
inputs[conv::kOut], fwd_pd);
MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, data,
weight, bias, out_grad, fwd_pd);
if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() !=
convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc())
out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
out_grad_mem = out_grad.GetMKLDNNDataReorder(
convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc());
auto data_mem = inputs[conv::kData + 1].GetMKLDNNDataReorder(
auto data_mem = data.GetMKLDNNDataReorder(
convBwdWeight.bwdWeights_pd.src_primitive_desc());
auto in_grad_weight = CreateMKLDNNWeightGrad(
in_grad[conv::kWeight],
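Both hunks above apply one recurring guard: an input NDArray that is a view and holds an MKL-DNN layout is reordered to the default layout before any primitive is built from it. A minimal sketch of that guard as a helper — the helper itself is hypothetical and not part of this commit, but IsView(), IsMKLDNNData(), and Reorder2Default() are exactly the NDArray methods the diff uses:

// Hypothetical helper (not in the commit) distilling the guard repeated
// above for data, weight, and out_grad.
inline NDArray EnsureDefaultLayoutIfView(const NDArray &arr) {
  // A view of an MKL-DNN-formatted array falls back to the default layout
  // before primitive descriptors are derived from it.
  if (arr.IsView() && arr.IsMKLDNNData())
    return arr.Reorder2Default();
  return arr;
}
// Usage sketch: auto data = EnsureDefaultLayoutIfView(in_data[conv::kData]);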
61 changes: 38 additions & 23 deletions src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -212,7 +212,8 @@ class MKLDNNDeconvForward {
const NDArray &output);
void SetDataHandle(const DeconvolutionParam& param,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
const NDArray &in_data,
const NDArray &weight,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);

@@ -243,32 +244,30 @@ MKLDNNDeconvForward::MKLDNNDeconvForward(const DeconvolutionParam& param,

void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
const NDArray &in_data,
const NDArray &weight,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
auto data_mem = in_data[deconv::kData].GetMKLDNNDataReorder(
auto data_mem = in_data.GetMKLDNNDataReorder(
fwd_pd.diff_dst_primitive_desc());
NDArray weight = in_data[deconv::kWeight];
const mkldnn::memory *weight_mem;
if (ctx.is_train) {
// TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
// to the default format for now.
if (weight.IsMKLDNNData())
// This asks the engine to reorder data after the weight array is used.
weight.Reorder2DefaultAsync();
const_cast<NDArray&>(weight).Reorder2DefaultAsync();
weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
} else {
// For inference, we want to reorder the weight array so we don't need to
// reorder data every time.
if (weight.IsDefaultData()) {
weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
// We also need to modify the layout on the original weight array. The
// data conversion happens after the weight array is used.
weight.MKLDNNDataReorderAsync(fwd_pd.weights_primitive_desc());
} else {
weight_mem = weight.GetMKLDNNData();
CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc());
const_cast<NDArray&>(weight).MKLDNNDataReorderAsync(fwd_pd.weights_primitive_desc());
}
weight_mem = weight.GetMKLDNNData();
CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc());
}
auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut],
fwd_pd.diff_src_primitive_desc(), req[deconv::kOut]);
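Restating the layout policy the rewritten SetDataHandle enforces, as summary comments rather than new code:

// Training:  if the weight array holds an MKL-DNN layout, queue an async
//            reorder back to the default layout (the TODO notes that
//            kvstore cannot handle MKL-DNN layouts) and hand the primitive
//            a converted copy obtained via GetWeights().
// Inference: if the weight array is still in the default layout, queue an
//            async reorder into fwd_pd.weights_primitive_desc() so that
//            later forward calls skip the conversion; the primitive then
//            reads the array's own memory, CHECKed against that descriptor.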
@@ -287,19 +286,19 @@ void MKLDNNDeconvForward::Execute(const std::vector<NDArray> &out_data) {

static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
const NDArray &bias,
const std::vector<NDArray> &out_data) {
// add bias, broadcast bias to dim 1: channel
if (!param.no_bias) {
// MKLDNN only supports float right now.
typedef float DType;
Stream<cpu> *s = ctx.get_stream<cpu>();
Tensor<cpu, 1, DType> bias = in_data[deconv::kBias].data().get<cpu, 1, DType>(s);
Tensor<cpu, 1, DType> b = bias.data().get<cpu, 1, DType>(s);
// If the output data is stored in a special MKLDNN format, data()
// automatically converts its format to the default format.
// Unfortunately, MKLDNN doesn't support broadcast.
Tensor<cpu, 4, DType> out_cpu = out_data[deconv::kOut].data().get<cpu, 4, DType>(s);
out_cpu += mshadow::expr::broadcast<1>(bias, out_cpu.shape_);
out_cpu += mshadow::expr::broadcast<1>(b, out_cpu.shape_);
}
}
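Because MKL-DNN does not support broadcast, the bias is applied after the primitive runs: data() converts the output to the default format and mshadow broadcasts the 1-D per-channel bias over it. A dependency-free sketch of the same arithmetic on an NCHW buffer (the function and its names are illustrative, not from the commit):

#include <cstddef>
#include <vector>

// Add a per-channel bias to an NCHW-layout buffer; requires bias.size() == c.
void AddBiasNCHW(std::vector<float> *out, const std::vector<float> &bias,
                 std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
  for (std::size_t i = 0; i < n; ++i)          // batch
    for (std::size_t j = 0; j < c; ++j)        // channel selects the bias
      for (std::size_t k = 0; k < h * w; ++k)  // every spatial position
        (*out)[(i * c + j) * h * w + k] += bias[j];
}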

@@ -344,15 +343,24 @@ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &c
TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);

auto data = in_data[deconv::kData];
if (data.IsView() && data.IsMKLDNNData())
data = data.Reorder2Default();

auto weight = in_data[deconv::kWeight];
if (weight.IsView() && weight.IsMKLDNNData())
weight = weight.Reorder2Default();

const NDArray* bias = param.no_bias ? nullptr : &in_data[deconv::kBias];

MKLDNNDeconvForward &deconvFwd = GetDeconvFwd(
attrs, in_data[deconv::kData], in_data[deconv::kWeight],
param.no_bias ? nullptr : &in_data[deconv::kBias], out_data[deconv::kOut]);
attrs, data, weight, bias, out_data[deconv::kOut]);

deconvFwd.SetDataHandle(param, ctx, in_data, req, out_data);
deconvFwd.SetDataHandle(param, ctx, data, weight, req, out_data);

deconvFwd.Execute(out_data);

MKLDNNDeconvFwdBiasPostProcess(param, ctx, in_data, out_data);
MKLDNNDeconvFwdBiasPostProcess(param, ctx, *bias, out_data);
}

class MKLDNNDeconvBackwardData {
@@ -506,17 +514,24 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs,
TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
const std::vector<NDArray> &in_grad = outputs;
const DeconvolutionParam &param = nnvm::get<DeconvolutionParam>(attrs.parsed);

auto data = inputs[deconv::kData + 1];
if (data.IsView() && data.IsMKLDNNData())
data = data.Reorder2Default();

auto weight = inputs[deconv::kWeight + 1];
if (weight.IsView() && weight.IsMKLDNNData())
weight = weight.Reorder2Default();

CHECK_NE(req[deconv::kWeight], kWriteInplace)
<< "cannot write weight inplace";
MKLDNNDeconvBackwardData &bwd_data =
GetDeconvBwdData(param, inputs[deconv::kData + 1],
inputs[deconv::kWeight + 1], inputs[deconv::kOut]);
GetDeconvBwdData(param, data, weight, inputs[deconv::kOut]);
auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
bwd_data.pd.src_primitive_desc());
if (req[deconv::kData]) {
auto weight_mem =
GetWeights(inputs[deconv::kWeight + 1],
bwd_data.pd.weights_primitive_desc(), param.num_group);
GetWeights(weight, bwd_data.pd.weights_primitive_desc(), param.num_group);
auto in_grad_mem =
CreateMKLDNNMem(in_grad[deconv::kData],
bwd_data.pd.dst_primitive_desc(), req[deconv::kData]);
Expand All @@ -526,12 +541,12 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs,
}
if (req[deconv::kWeight]) {
MKLDNNDeconvBackwardWeights &bwd_weights = GetDeconvBwdWeights(
param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1],
param, data, weight,
inputs[deconv::kOut], bwd_data.pd);
if (bwd_data.pd.src_primitive_desc() != bwd_weights.pd.src_primitive_desc())
out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
bwd_weights.pd.src_primitive_desc());
auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder(
auto data_mem = data.GetMKLDNNDataReorder(
bwd_weights.pd.diff_dst_primitive_desc());
auto in_grad_weight = CreateMKLDNNWeightGrad(
in_grad[deconv::kWeight], bwd_weights.pd.diff_weights_primitive_desc(),
62 changes: 49 additions & 13 deletions tests/cpp/include/test_mkldnn.h
@@ -159,7 +159,7 @@ inline static std::vector<mkldnn::memory::format> GetMKLDNNFormat(size_t num_dim
}
}

inline static TestArrayShapes GetTestArrayShapes() {
inline static TestArrayShapes GetTestArrayShapes(bool spatial_data_format = false) {
int dtype = mshadow::DataType<mshadow::default_real_t>::kFlag;
std::vector<TShape> shapes;
std::vector<mkldnn::memory::primitive_desc> pds;
@@ -198,7 +198,9 @@ inline static TestArrayShapes GetTestArrayShapes() {
pds.push_back(GetMemPD(s2, dtype, mkldnn::memory::format::oihw));

std::vector<mkldnn::memory::format> formats = GetMKLDNNFormat(4, dtype);
pds.push_back(GetMemPD(s1, dtype, formats[0]));
if (!spatial_data_format) {
pds.push_back(GetMemPD(s1, dtype, formats[0]));
}
}
{
// 5D
@@ -208,7 +210,9 @@
pds.push_back(GetMemPD(s, dtype, mkldnn::memory::format::goihw));

std::vector<mkldnn::memory::format> formats = GetMKLDNNFormat(5, dtype);
pds.push_back(GetMemPD(s, dtype, formats[0]));
if (!spatial_data_format) {
pds.push_back(GetMemPD(s, dtype, formats[0]));
}
}

TestArrayShapes ret;
@@ -250,6 +254,38 @@ enum ArrayTypes {
All = 8191,
};


inline NDArray CreateKernelNDArray(TShape kernel, int num_filters, TShape input,
bool is_deconv = false) {
CHECK_EQ(kernel.ndim(), 2) << "mkldnn only supports 2d filters on 4d inputs";
TShape target_shape(4);
target_shape[0] = is_deconv ? input[1] : num_filters;
target_shape[1] = is_deconv ? num_filters : input[1];
target_shape[2] = kernel[0];
target_shape[3] = kernel[1];
int dtype = mshadow::DataType<mshadow::default_real_t>::kFlag;
NDArray arr(target_shape, Context());
auto pd = GetMemPD(target_shape, dtype, mkldnn::memory::format::nchw);
InitMKLDNNArray(&arr, pd);
return arr;
}

inline NDArray CreateBiasNDArray(TShape target_shape) {
int dtype = mshadow::DataType<mshadow::default_real_t>::kFlag;
NDArray arr(target_shape, Context());
auto pd = GetMemPD(target_shape, dtype, mkldnn::memory::format::x);
InitMKLDNNArray(&arr, pd);
return arr;
}

inline int CalculateWidthConvOutput(int width, int kernel, int padding, int stride) {
return (width - kernel + 2 * padding) / stride + 1;
}

inline int CalculateWidthDeconvOutput(int width, int kernel, int padding, int stride) {
return stride * (width - 1) + kernel - 2 * padding;
}

inline std::string CreateShapeString(int value, int dim) {
std::stringstream ss;
ss << "(";
Expand Down Expand Up @@ -293,21 +329,21 @@ inline void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) {
*/
inline std::vector<NDArrayAttrs> GetTestInputArrays(
int types = ArrayTypes::All, bool rand = false,
int num_inputs = 1, int dim = 0) {
TestArrayShapes tas = GetTestArrayShapes();
std::vector<float> scale = {1}, bool spatial_data_format = false) {
TestArrayShapes tas = GetTestArrayShapes(spatial_data_format);
std::vector<nnvm::TShape> shapes = tas.shapes;
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

std::vector<NDArrayAttrs> in_arrs;
std::string desc;

int slice_amount = 1;
if (dim == 0)
slice_amount = num_inputs;
int slice_amount = scale[0];
for (auto shape : shapes) {
if (dim >= shape.ndim())
if (scale.size() > shape.ndim())
continue;
shape[dim] = shape[dim] * num_inputs;

for (size_t dim = 0; dim < scale.size(); ++dim)
shape[dim] = static_cast<int>(round(shape[dim] * scale[dim]));

// Type 1.
NDArray arr(shape, Context());
Expand All @@ -326,12 +362,12 @@ inline std::vector<NDArrayAttrs> GetTestInputArrays(


for (auto pd : pds) {
if (num_inputs > 1) {
for (size_t dim = 0; dim < scale.size(); ++dim) {
// preserve if matching layout else just expand on 0 dim
if (shape.ndim() == pd.desc().data.ndims)
pd = GetExpandedMemPD(pd, num_inputs, dim);
pd = GetExpandedMemPD(pd, scale[dim], dim);
else
pd = GetExpandedMemPD(pd, num_inputs);
pd = GetExpandedMemPD(pd, scale[dim]);
}

if (shape.Size() != pd.get_size() / sizeof(mshadow::default_real_t))
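The reworked GetTestInputArrays replaces the old num_inputs/dim pair with a per-dimension scale vector: scale[dim] multiplies shape[dim] (rounded to an integer), scale[0] doubles as the slice amount, and shapes with fewer dimensions than scale entries are skipped. A hedged usage sketch — the argument values are illustrative and the enclosing test fixture is omitted:

// Request random test inputs with dim 1 (channels) doubled, restricted to
// spatial data formats; ArrayTypes::All is the enum value shown above.
std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(
    ArrayTypes::All, /*rand=*/true,
    /*scale=*/{1, 2}, /*spatial_data_format=*/true);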