diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index 1b2449b137a6..ce5d0b1e6917 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -91,6 +91,11 @@ void MKLDNNPoolingFwd::Execute(const NDArray &in_data, if (this->with_workspace_) { auto engine = CpuEngine::Get()->get_engine(); + + if (workspace == nullptr) { + LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input"; + } + auto ws = std::make_shared((*(this->fwd_pd_)).workspace_desc(), engine, workspace->GetMKLDNNData()->get_data_handle()); args[MKLDNN_ARG_WORKSPACE] = *ws; diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_act.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_act.cc index bc69cb5e9bf7..9c2097ba7cf1 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_act.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_act.cc @@ -22,7 +22,7 @@ * \brief MKLDNN(Quantized) Activation operator based on subgraph * /author Zhiyuan Huang */ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "../../nn/mkldnn/mkldnn_act-inl.h" #include "../quantization_utils.h" diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc index 31da936915e6..2416c128eddd 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc @@ -23,7 +23,7 @@ * \brief */ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "../../nn/mkldnn/mkldnn_flatten-inl.h" #include "../quantization_utils.h" diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc index 07e14412618d..4c62c97f5b39 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc @@ -23,7 +23,7 @@ * \author Tao Lv, Xinyu Chen */ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "../../nn/mkldnn/mkldnn_pooling-inl.h" @@ -38,9 +38,7 @@ static void MKLDNNQuantizedPoolingForward(const nnvm::NodeAttrs& attrs, const Op || in_data[0].dtype() == mshadow::kInt8) << "mkldnn_quantized_pooling op only supports uint8 and int8 as input type"; const PoolingParam& param = nnvm::get(attrs.parsed); - auto fwd = GetPoolingFwd(param, ctx.is_train, in_data[0], out_data[0]); - fwd.SetNewMem(in_data[0], out_data[0], req[0]); - fwd.Execute(out_data[0]); + MKLDNNPoolingCompute(ctx, param, in_data[0], req[0], out_data[0], nullptr); out_data[1].data().dptr()[0] = in_data[1].data().dptr()[0]; out_data[2].data().dptr()[0] = in_data[2].data().dptr()[0]; } diff --git a/src/operator/quantization/quantized_activation.cc b/src/operator/quantization/quantized_activation.cc index 40a28d6bb018..f054bf0e1858 100644 --- a/src/operator/quantization/quantized_activation.cc +++ b/src/operator/quantization/quantized_activation.cc @@ -68,7 +68,7 @@ inline static bool QuantizedActivationStorageType(const nnvm::NodeAttrs &attrs, CHECK_EQ(in_attrs->size(), 3); *dispatch_mode = DispatchMode::kFCompute; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 const ActivationParam ¶m = nnvm::get(attrs.parsed); if (dev_mask == mshadow::cpu::kDevMask && param.act_type == activation::kReLU) { *dispatch_mode = DispatchMode::kFComputeEx; diff --git a/src/operator/quantization/quantized_pooling.cc b/src/operator/quantization/quantized_pooling.cc index 1839e2a29d77..8934ddb80d2a 100644 --- a/src/operator/quantization/quantized_pooling.cc +++ b/src/operator/quantization/quantized_pooling.cc @@ -23,7 +23,7 @@ */ #include #include "../nn/pooling-inl.h" -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "../nn/mkldnn/mkldnn_pooling-inl.h" #endif @@ -98,7 +98,7 @@ bool QuantizedPoolingType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_type->size(), 3U); CHECK_EQ(out_type->size(), 3U); if (param.pool_type == pool_enum::kMaxPooling || param.pool_type == pool_enum::kAvgPooling) { -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 TYPE_ASSIGN_CHECK(*out_type, 0, (*in_type)[0]); #else TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8); @@ -122,7 +122,7 @@ inline static bool QuantizedPoolingStorageType(const nnvm::NodeAttrs &attrs, CHECK_EQ(in_attrs->size(), 3); *dispatch_mode = DispatchMode::kFCompute; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 const PoolingParam ¶m = nnvm::get(attrs.parsed); if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNPooling(param)) { *dispatch_mode = DispatchMode::kFComputeEx;