Enable MKL-DNN FullyConnected backward (apache#17318)
* fix mkldnn fc backward bug caused by in-place data (illustrated in the sketch after the changed-files summary below)

* enable mkldnn fc bwd

* fix cpp tests

* try: fix random seed

* fix cpp test

* loosen rtol for fc cpp test

* improve error message

* limit max value for mkldnn tensors

* limit the max value of test tensors

* fix lint

* remove fixed random seed

* address review comments

* Revert "address review comments"

This reverts commit 56d873f.

Co-authored-by: rongzha1 <[email protected]>
2 people authored and anirudh2290 committed May 29, 2020
1 parent 12b46cb commit 2b78831
Showing 4 changed files with 35 additions and 38 deletions.
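
The first commit bullet refers to a backward bug triggered when an output buffer aliases an input ("data inplace"). The sketch below is a minimal, framework-free illustration of why such aliasing silently corrupts a gradient; the function, the names grad_out/weight/grad_in, and the 2x2 sizes are invented for the example and are not the actual MXNet code path.

// Minimal illustration of the in-place hazard: an output that shares storage
// with an input overwrites values that are still needed, so the result is
// silently wrong even though the kernel itself is correct.
#include <cstdio>
#include <vector>

// 1x2 row vector times 2x2 matrix: y = x * W (W stored row-major).
static void VecMat(const float* x, const float* w, float* y) {
  for (int j = 0; j < 2; ++j)
    y[j] = x[0] * w[0 * 2 + j] + x[1] * w[1 * 2 + j];
}

int main() {
  std::vector<float> grad_out = {1.f, 2.f};             // stands in for dY
  std::vector<float> weight   = {1.f, 0.f, 0.f, 1.f};   // identity W
  std::vector<float> grad_in(2);                         // stands in for dX

  // Out-of-place: correct; with the identity matrix dX == dY.
  VecMat(grad_out.data(), weight.data(), grad_in.data());
  std::printf("out-of-place dX = %.1f %.1f\n", grad_in[0], grad_in[1]);

  // "In-place": dX shares storage with dY. Writing y[0] overwrites
  // grad_out[0] before it is read again for y[1], so the answer is wrong.
  std::vector<float> w2 = {2.f, 3.f, 0.f, 1.f};
  VecMat(grad_out.data(), w2.data(), grad_out.data());
  std::printf("in-place   dX = %.1f %.1f (expected 2.0 5.0)\n",
              grad_out[0], grad_out[1]);
  return 0;
}
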
src/operator/nn/fully_connected.cc (9 changes: 2 additions & 7 deletions)

@@ -148,10 +148,7 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
- // TODO(rongzha1): disable due to flakiness in cpp test IMPERATIVE.FullyConnectedOp
- // Will be fixed when we decide to enable the backward of FC.
- bool mkldnn_fc_backward_enable = false;
- if (mkldnn_fc_backward_enable && SupportMKLDNNFC(inputs[0])) {
+ if (SupportMKLDNNFC(inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
MKLDNNRun(MKLDNNFCBackward, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute<cpu>, attrs, ctx, inputs, req,
@@ -233,12 +230,10 @@ static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
uint32_t out_expected = param.no_bias ? 2 : 3;
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), out_expected);
- // TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
- // It seems there is a bug.
bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
- dispatch_mode, DispatchMode::kFCompute);
+ dispatch_mode, DispatchMode::kFComputeEx);
}
if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
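
The second hunk is what actually activates the path above: for dense inputs, BackwardFCStorageType now assigns DispatchMode::kFComputeEx instead of kFCompute, and the kFComputeEx mode is what dispatches to the FullyConnectedGradComputeExCPU function shown in the first hunk. A simplified, self-contained sketch of that decision follows; the enum and helper names are stand-ins for illustration, not MXNet's real types.

#include <algorithm>
#include <vector>

enum class Storage { kDefault, kRowSparse };
enum class Dispatch { kFCompute, kFComputeEx, kFallback };

// Mirrors the shape of BackwardFCStorageType after this commit: when every
// input uses default (dense) storage, pick the *Ex* dispatch mode so the
// MKL-DNN backward kernel can be reached; otherwise fall back.
static Dispatch ChooseBackwardDispatch(const std::vector<Storage>& in_storage) {
  bool all_default = std::all_of(in_storage.begin(), in_storage.end(),
                                 [](Storage s) { return s == Storage::kDefault; });
  if (all_default)
    return Dispatch::kFComputeEx;  // was Dispatch::kFCompute before this change
  return Dispatch::kFallback;
}

int main() {
  std::vector<Storage> dense = {Storage::kDefault, Storage::kDefault, Storage::kDefault};
  return ChooseBackwardDispatch(dense) == Dispatch::kFComputeEx ? 0 : 1;
}
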
tests/cpp/include/test_mkldnn.h (44 changes: 22 additions & 22 deletions)

@@ -63,24 +63,24 @@ struct TestArrayShapes {
};

// Init arrays with the default layout.
- inline static void InitDefaultArray(NDArray *arr, bool is_rand = false) {
+ inline static void InitDefaultArray(NDArray *arr, bool is_rand = false, int max = 50) {
const TBlob &blob = arr->data();
mshadow::default_real_t *data = blob.dptr<mshadow::default_real_t>();
int size = blob.Size();

for (int i = 0; i < size; i++)
if (is_rand) {
- data[i] = (std::rand() % 100) - 50;
+ data[i] = (std::rand() % (max * 2)) - max;
} else {
- data[i] = i % 100 - 50;
+ data[i] = i % (max * 2) - max;
}
}


// Init arrays with the specified layout.
inline static void InitMKLDNNArray(NDArray *arr, const mkldnn::memory::desc &desc,
- bool is_rand = false) {
- InitDefaultArray(arr, is_rand);
+ bool is_rand = false, int max = 50) {
+ InitDefaultArray(arr, is_rand, max);
arr->MKLDNNDataReorderAsync(desc);
arr->WaitToRead();
}
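
The new max argument threads a value bound through all of the test initializers: generated elements fall in the half-open range [-max, max), and the FullyConnected test passes max = 1 so every element is either -1 or 0. A standalone sketch of the same bounding arithmetic; the container and function name are invented for the example, and std::rand is kept unseeded to match the original helper.

#include <cstdlib>
#include <vector>

// Same value-bounding scheme as the patched InitDefaultArray: with the
// default max = 50 the data lies in [-50, 49]; the FC test passes max = 1,
// so every element is either -1 or 0.
static void FillBounded(std::vector<float>* data, bool is_rand, int max = 50) {
  for (size_t i = 0; i < data->size(); ++i) {
    if (is_rand)
      (*data)[i] = static_cast<float>(std::rand() % (max * 2) - max);
    else
      (*data)[i] = static_cast<float>(static_cast<int>(i) % (max * 2) - max);
  }
}

int main() {
  std::vector<float> buf(16);
  FillBounded(&buf, /*is_rand=*/true, /*max=*/1);  // values in {-1, 0}
  return 0;
}
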
@@ -330,7 +330,7 @@ inline void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) {
*/
inline std::vector<NDArrayAttrs> GetTestInputArrays(
int types = ArrayTypes::All, bool rand = false,
- std::vector<float> scale = {1}, bool spatial_data_format = false) {
+ std::vector<float> scale = {1}, bool spatial_data_format = false, int max = 50) {
TestArrayShapes tas = GetTestArrayShapes(spatial_data_format);
std::vector<mxnet::TShape> shapes = tas.shapes;
std::vector<mkldnn::memory::desc> mds = tas.mds;
@@ -349,14 +349,14 @@ inline std::vector<NDArrayAttrs> GetTestInputArrays(
// Type 1.
NDArray arr(shape, Context());
if (types & ArrayTypes::Normal) {
- InitDefaultArray(&arr, rand);
+ InitDefaultArray(&arr, rand, max);
in_arrs.emplace_back(arr, "Normal NDArray");
}

// Type 4
arr = NDArray(shape, Context());
if (types & ArrayTypes::NormalReshaped) {
- InitDefaultArray(&arr, rand);
+ InitDefaultArray(&arr, rand, max);
in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount),
"Reshaped Normal NDArray");
}
@@ -379,19 +379,19 @@ inline std::vector<NDArrayAttrs> GetTestInputArrays(
if (shape.ndim() == md.data.ndims && IsSameShape(md, shape)
&& types & ArrayTypes::MKLDNN) {
desc_str = "MKLDNN NDArray";
- InitMKLDNNArray(&arr, md, rand);
+ InitMKLDNNArray(&arr, md, rand, max);
in_arrs.emplace_back(arr, desc_str);
} else if (shape.ndim() == md.data.ndims && !IsSameShape(md, shape)
&& types & ArrayTypes::MKLDNNDiffShape) {
desc_str = "MKLDNN NDArray with different shape";
- InitMKLDNNArray(&arr, md, rand);
+ InitMKLDNNArray(&arr, md, rand, max);
in_arrs.emplace_back(arr, desc_str);
} else if (shape.ndim() != md.data.ndims && types & ArrayTypes::MKLDNNDiffDim) {
std::stringstream ss;
ss << "MKLDNN NDArray with different dim " <<
shape.ndim() << "/" << md.data.ndims;
desc_str = ss.str();
- InitMKLDNNArray(&arr, md, rand);
+ InitMKLDNNArray(&arr, md, rand, max);
in_arrs.emplace_back(arr, desc_str);
}

@@ -401,20 +401,20 @@ inline std::vector<NDArrayAttrs> GetTestInputArrays(
if (shape.ndim() == md.data.ndims && IsSameShape(md, shape)
&& types & ArrayTypes::MKLDNNReshaped) {
desc_str = "Reshaped MKLDNN NDArray";
- InitMKLDNNArray(&arr, md, rand);
+ InitMKLDNNArray(&arr, md, rand, max);
in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc_str);
} else if (shape.ndim() == md.data.ndims && !IsSameShape(md, shape)
&& types & ArrayTypes::MKLDNNReshapedDiffShape) {
desc_str = "Reshaped MKLDNN NDArray with different shape";
- InitMKLDNNArray(&arr, md, rand);
+ InitMKLDNNArray(&arr, md, rand, max);
in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc_str);
} else if (shape.ndim() != md.data.ndims
&& types & ArrayTypes::MKLDNNReshapedDiffDim) {
std::stringstream ss;
ss << "MKLDNN NDArray with different dim " <<
shape.ndim() << "/" << md.data.ndims;
desc_str = ss.str();
- InitMKLDNNArray(&arr, md, rand);
+ InitMKLDNNArray(&arr, md, rand, max);
in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc_str);
}
}
@@ -445,7 +445,7 @@ inline std::vector<NDArrayAttrs> GetTestInputArrays(
inline std::vector<NDArrayAttrs> GetTestOutputArrays(
const mxnet::TShape &shp,
const std::vector<mkldnn::memory::desc> &mds,
- std::vector<float>scale = {1}, bool rand = true, int types = ArrayTypes::All) {
+ std::vector<float>scale = {1}, bool rand = true, int types = ArrayTypes::All, int max = 50) {
mxnet::TShape shape = shp;

for (int dim = 0; dim < scale.size(); dim++)
@@ -458,15 +458,15 @@ inline std::vector<NDArrayAttrs> GetTestOutputArrays(

if (types & ArrayTypes::Normal) {
in_arrs.emplace_back(arr, "Normal NDArray");
- InitDefaultArray(&in_arrs.back().arr, rand);
+ InitDefaultArray(&in_arrs.back().arr, rand, max);
}

mxnet::TShape tmp_shape = shape;
if (types & ArrayTypes::NormalReshaped) {
// Type 4.
tmp_shape[0] = shape[0] * 2;
NDArray arr0(tmp_shape, Context());
- InitDefaultArray(&arr0, rand);
+ InitDefaultArray(&arr0, rand, max);
in_arrs.emplace_back(arr0.Slice(1, shape[0] + 1), "Reshaped NDArray");
}

@@ -477,7 +477,7 @@ inline std::vector<NDArrayAttrs> GetTestOutputArrays(
s[0] = shape.Size();
NDArray arr1(s, Context());
arr1 = arr1.AsArray(shape, arr1.dtype());
- InitDefaultArray(&arr1, rand);
+ InitDefaultArray(&arr1, rand, max);
in_arrs.emplace_back(arr1, "Reused NDArray");
}

@@ -486,7 +486,7 @@ inline std::vector<NDArrayAttrs> GetTestOutputArrays(
s[0] = shape.Size() * GetTypeSize(mshadow::default_type_flag);
NDArray arr2(s, Context(), true, mshadow::kUint8);
arr2 = arr2.AsArray(shape, mshadow::default_type_flag);
- InitDefaultArray(&arr2, rand);
+ InitDefaultArray(&arr2, rand, max);
in_arrs.emplace_back(arr2, "Reused NDArray with diff data type");
}

@@ -496,7 +496,7 @@ inline std::vector<NDArrayAttrs> GetTestOutputArrays(
NDArray arr3(s, Context(), true, mshadow::kUint8);
tmp_shape[0] = shape[0] * 2;
arr3 = arr3.AsArray(tmp_shape, mshadow::default_type_flag);
- InitDefaultArray(&arr3, rand);
+ InitDefaultArray(&arr3, rand, max);
in_arrs.emplace_back(arr3.Slice(1, shape[0] + 1), "Reused+Reshaped NDArray");
}

@@ -523,7 +523,7 @@ inline std::vector<NDArrayAttrs> GetTestOutputArrays(
if ((types & ArrayTypes::MKLDNN && shape.ndim() == md.data.ndims) ||
(types & ArrayTypes::MKLDNNDiffDim && shape.ndim() != md.data.ndims)) {
in_arrs.emplace_back(arr, desc_str);
- InitMKLDNNArray(&in_arrs.back().arr, md, rand);
+ InitMKLDNNArray(&in_arrs.back().arr, md, rand, max);
}

// Type 8, 9.
@@ -532,7 +532,7 @@ inline std::vector<NDArrayAttrs> GetTestOutputArrays(
s[0] = shape.Size();
NDArray arr = NDArray(s, Context());
arr = arr.AsArray(shape, arr.dtype());
- InitMKLDNNArray(&arr, md, rand);
+ InitMKLDNNArray(&arr, md, rand, max);
desc_str = "Reused MKLDNN NDArray";
if (shape.ndim() != md.data.ndims) {
std::stringstream ss;
tests/cpp/include/test_util.h (3 changes: 2 additions & 1 deletion)

@@ -820,7 +820,8 @@ static void AssertEqual(const std::vector<NDArray *> &in_arrs,
static_cast<mshadow::default_real_t *>(blob2.dptr_);
for (int i = 0; i < tmp1.shape().Size(); i++) {
float abs_err = fabs((d1[i]) - (d2[i]));
- ASSERT_LE(abs_err, (atol + rtol * fabs(d2[i])));
+ ASSERT_LE(abs_err, (atol + rtol * fabs(d2[i])))
+     << "index: " << i << ", " << d1[i] << " vs " << d2[i];
}
}
}
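
The only change here is diagnostic: a failing element-wise comparison now reports the offending index and the two values. A framework-free sketch of the same check, with gtest's streamed ASSERT_LE message replaced by a plain printf so the snippet stands alone; AllClose and the sample vectors are invented for the example.

#include <cmath>
#include <cstdio>
#include <vector>

// Element-wise |a - b| <= atol + rtol * |b|, printing the offending index and
// values on the first failure, like the message added to AssertEqual.
static bool AllClose(const std::vector<float>& a, const std::vector<float>& b,
                     float atol = 1e-6f, float rtol = 1e-6f) {
  for (size_t i = 0; i < a.size(); ++i) {
    float abs_err = std::fabs(a[i] - b[i]);
    if (abs_err > atol + rtol * std::fabs(b[i])) {
      std::printf("index: %zu, %g vs %g\n", i, a[i], b[i]);
      return false;
    }
  }
  return true;
}

int main() {
  std::vector<float> got = {1.0f, 2.0f, 3.0f};
  std::vector<float> ref = {1.0f, 2.0f, 3.5f};
  return AllClose(got, ref) ? 0 : 1;  // prints "index: 2, 3 vs 3.5" and fails
}
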
tests/cpp/operator/mkldnn_operator_test.cc (17 changes: 9 additions & 8 deletions)

@@ -683,9 +683,9 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {

for (int i = 0; i < forward_attrs.num_outputs; i++) {
out_arrs[i] =
- GetTestOutputArrays(in_arr.arr.shape(), mds, {1}, forward_attrs.output_types);
+ GetTestOutputArrays(in_arr.arr.shape(), mds, {1}, false, forward_attrs.output_types);
ex_out_arrs[i] =
- GetTestOutputArrays(in_arr.arr.shape(), mds, {1}, forward_attrs.output_types);
+ GetTestOutputArrays(in_arr.arr.shape(), mds, {1}, false, forward_attrs.output_types);
}

for (int i = 0; i < forward_attrs.num_inputs; i++)
@@ -897,7 +897,8 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards
TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::desc> mds = tas.mds;

- std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, true);
+ std::vector<NDArrayAttrs> in_arrs =
+     GetTestInputArrays(forward_attrs.input_types, true, {1}, false, 1);
std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs);
std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(forward_attrs.num_outputs);

@@ -932,9 +933,9 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards

for (int i = 0; i < forward_attrs.num_outputs; i++) {
out_arrs[i] =
- GetTestOutputArrays(out_shape, mds, {1}, forward_attrs.output_types);
+ GetTestOutputArrays(out_shape, mds, {1}, false, forward_attrs.output_types, 1);
ex_out_arrs[i] =
- GetTestOutputArrays(out_shape, mds, {1}, forward_attrs.output_types);
+ GetTestOutputArrays(out_shape, mds, {1}, false, forward_attrs.output_types, 1);
}

for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
Expand All @@ -960,14 +961,14 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards
backwards_input[1] = inputs[0]; // input
backwards_input[2] = inputs[1]; // weights

- auto tmp_output = GetTestInputArrays(forward_attrs.input_types, true)[i1];
+ auto tmp_output = GetTestInputArrays(forward_attrs.input_types, true, {1}, false, 1)[i1];
NDArray back_weights(wt_shape, Context());
NDArray back_bias(bias_shape, Context());
backwards_outputs[0] = &tmp_output.arr;
backwards_outputs[1] = &back_weights;
backwards_outputs[2] = &back_bias;

- auto tmp_output2 = GetTestInputArrays(forward_attrs.input_types, true)[i1];
+ auto tmp_output2 = GetTestInputArrays(forward_attrs.input_types, true, {1}, false, 1)[i1];
NDArray back_ex_weights(wt_shape, Context());
NDArray back_ex_bias(bias_shape, Context());
backwards_ex_outputs[0] = &tmp_output2.arr;
@@ -986,7 +987,7 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards
Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
Engine::Get()->WaitForAll();
- AssertEqual(backwards_outputs, backwards_ex_outputs);
+ AssertEqual(backwards_outputs, backwards_ex_outputs, 1e-6, 1e-6);
}
}
}
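
These test changes turn off large random values, cap the tensor magnitude at 1, and tighten AssertEqual to atol = rtol = 1e-6. The likely connection: with every input equal to -1 or 0, the reductions inside both the reference and the MKL-DNN backward kernels stay within exactly representable integers in float, so differing summation orders cannot make the two results drift apart. The standalone sketch below illustrates that effect; the summation functions and constants are invented for the example and are not the test's actual kernels.

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

// Sum in forward order vs reverse order, mimicking two implementations that
// reduce in different orders (e.g. the dense reference and MKL-DNN paths).
static float SumForward(const std::vector<float>& v) {
  float s = 0.0f;
  for (size_t i = 0; i < v.size(); ++i) s += v[i];
  return s;
}
static float SumReverse(const std::vector<float>& v) {
  float s = 0.0f;
  for (size_t i = v.size(); i > 0; --i) s += v[i - 1];
  return s;
}

int main() {
  const int n = 1 << 20;
  std::vector<float> small(n), large(n);
  for (int i = 0; i < n; ++i) {
    small[i] = static_cast<float>(std::rand() % 2 - 1);  // {-1, 0}: exact in float
    large[i] = (std::rand() % 100 - 50) * 0.37f;         // non-integer, magnitude up to ~18.5
  }
  // Integer-valued data: both orders give bit-identical sums, difference is 0.
  std::printf("small-value order sensitivity: %g\n",
              std::fabs(SumForward(small) - SumReverse(small)));
  // Non-integer data of larger magnitude: the two orders typically disagree.
  std::printf("large-value order sensitivity: %g\n",
              std::fabs(SumForward(large) - SumReverse(large)));
  return 0;
}
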