Skip to content

Commit

Permalink
Merge branch 'refactor' of https://github.com/zheng-da/incubator-mxnet
Browse files Browse the repository at this point in the history
…into refactor
  • Loading branch information
zheng-da committed Dec 14, 2017
2 parents 826592d + ad5dc2d commit a549065
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 98 deletions.
112 changes: 37 additions & 75 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,80 +160,6 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {

#if MXNET_USE_MKLDNN == 1

// Maps an MKLDNN memory descriptor to the "default" (plain, MXNet-native)
// layout for its rank: the format itself for 1-D, oi for 2-D (transposing
// io back), nchw for 4-D data layouts, oihw for 4-D weight layouts, and
// goihw for 5-D (grouped-convolution) weight layouts.
// Returns mkldnn_format_undef after LOG(FATAL) on unrecognized input.
static inline mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc) {
if (desc.data.ndims == 1) {
// A 1-D tensor has only one possible layout.
return desc.data.format;
} else if (desc.data.ndims == 2) {
// mkldnn_io is the transposed 2-D layout; normalize it to mkldnn_oi.
if (desc.data.format == mkldnn_io)
return mkldnn_oi;
else
return desc.data.format;
} else if (desc.data.ndims == 4) {
switch (desc.data.format) {
// 4-D activation layouts (plain and blocked) all normalize to nchw.
case mkldnn_nchw:
case mkldnn_nhwc:
case mkldnn_chwn:
case mkldnn_nChw8c:
case mkldnn_nChw16c:
return mkldnn_nchw;
// 4-D weight layouts (plain and blocked) all normalize to oihw.
case mkldnn_oihw:
case mkldnn_ihwo:
case mkldnn_hwio:
case mkldnn_OIhw8i8o:
case mkldnn_OIhw16i16o:
case mkldnn_OIhw8i16o2i:
case mkldnn_OIhw8o16i2o:
case mkldnn_OIhw8o8i:
case mkldnn_OIhw16o16i:
case mkldnn_IOhw16o16i:
case mkldnn_Oihw8o:
case mkldnn_Oihw16o:
case mkldnn_Ohwi8o:
case mkldnn_Ohwi16o:
case mkldnn_OhIw16o4i:
return mkldnn_oihw;
default:
LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format;
return mkldnn_format_undef;
}
} else if (desc.data.ndims == 5) {
switch (desc.data.format) {
// 5-D grouped-weight layouts all normalize to goihw.
case mkldnn_goihw:
case mkldnn_gOIhw8i8o:
case mkldnn_gOIhw16i16o:
case mkldnn_gOIhw8i16o2i:
case mkldnn_gOIhw8o16i2o:
case mkldnn_gOIhw8o8i:
case mkldnn_gOIhw16o16i:
case mkldnn_gIOhw16o16i:
case mkldnn_gOihw8o:
case mkldnn_gOihw16o:
case mkldnn_gOhwi8o:
case mkldnn_gOhwi16o:
case mkldnn_gOhIw16o4i:
return mkldnn_goihw;
default:
// NOTE(review): this is the 5-D branch but the message says "4 dimensions"
// — copy-paste typo from the branch above.
LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format;
return mkldnn_format_undef;
}
} else {
// 3-D and >5-D descriptors are not handled by this integration.
LOG(FATAL) << "Unsupported dimensions: " << desc.data.ndims;
return mkldnn_format_undef;
}
}

// Builds a new memory primitive descriptor that keeps pd's dimensions and
// data type but substitutes the requested memory format, on pd's engine.
// Used to describe the same buffer reinterpreted in another layout
// (e.g. the default layout from GetDefaultFormat) for reorders.
static inline mkldnn::memory::primitive_desc GetPrimitiveDesc(
mkldnn::memory::primitive_desc pd, mkldnn_memory_format_t format) {
mkldnn::memory::dims dims(pd.desc().data.ndims);
// Copy the raw dimension array into the C++ dims vector.
for (size_t i = 0; i < dims.size(); i++)
dims[i] = pd.desc().data.dims[i];
mkldnn::memory::format cpp_format = static_cast<mkldnn::memory::format>(format);
mkldnn::memory::data_type cpp_type = static_cast<mkldnn::memory::data_type>(
pd.desc().data.data_type);
mkldnn::memory::desc data_md(dims, cpp_type, cpp_format);
return mkldnn::memory::primitive_desc(data_md, pd.get_engine());
}

static inline mkldnn_mem_ptr Reorder2Default(mkldnn_mem_ptr mem,
bool submit_now = true) {
auto format = GetDefaultFormat(mem->get_primitive_desc().desc());
Expand Down Expand Up @@ -267,6 +193,8 @@ NDArray NDArray::ReshapeMKLDNN(const TShape &shape) const {
} else if (storage_type() == kMKLDNNStorage) {
NDArray ret(kMKLDNNStorage, shape, ctx(), ptr_->delay_alloc, dtype());
CHECK(ptr_->Mkl_mem_ != nullptr);
// This doesn't work on sliced NDArray yet.
CHECK_EQ(byte_offset_, 0);
// We shouldn't submit the reorder primitive here because submit will
// be called in operators.
auto format = GetDefaultFormat(ptr_->Mkl_mem_->get_primitive_desc().desc());
Expand Down Expand Up @@ -313,6 +241,8 @@ NDArray NDArray::Reshape(const TShape &shape) const {
} else {
ret.ptr_->Mkl_mem_ = this->ptr_->Mkl_mem_;
}
// We should make sure slice still works.
ret.byte_offset_ = this->byte_offset_;
}
}, ctx(), {this->var()}, {ret.var()},
FnProperty::kNormal, 0, PROFILER_MESSAGE("SyncMKLDNN2Default"));
Expand Down Expand Up @@ -559,6 +489,8 @@ const mkldnn::memory *NDArray::GetMKLDNNData(
if (ptr_->Mkl_mem_->get_primitive_desc() == desc
|| (desc1.data.format == GetDefaultFormat(desc1)
&& desc2.data.format == GetDefaultFormat(desc2))) {
// This doesn't work on sliced NDArray yet.
CHECK_EQ(byte_offset_, 0);
MKLDNNStream::Get()->RegisterMem(ptr_->Mkl_mem_);
return GetMKLDNNExact(ptr_->Mkl_mem_.get(), desc);
} else {
Expand All @@ -572,6 +504,8 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
return nullptr;
}
// This doesn't work on sliced NDArray yet.
CHECK_EQ(byte_offset_, 0);
if (ptr_->storage_type == kDefaultStorage) {
ptr_->SetMKLMem(shape_, dtype_);
}
Expand Down Expand Up @@ -609,7 +543,33 @@ const mkldnn::memory *NDArray::GetMKLDNNData() const {
ptr_->SetMKLMem(shape_, dtype_);
if (ptr_->Mkl_mem_) {
MKLDNNStream::Get()->RegisterMem(ptr_->Mkl_mem_);
return ptr_->Mkl_mem_.get();
if (byte_offset_ > 0) {
// Slice only works on the default layout and Slice() turns an array into
// the default layout.
auto pd = ptr_->Mkl_mem_->get_primitive_desc();
CHECK_EQ(GetDefaultFormat(pd.desc()), pd.desc().data.format);
void *off_addr = static_cast<char *>(ptr_->Mkl_mem_->get_data_handle())
+ byte_offset_;

// Create the primitive desc for the new mkldnn memory.
mkldnn::memory::dims dims(pd.desc().data.ndims);
// The first dimension has been sliced.
dims[0] = shape()[0];
for (size_t i = 1; i < dims.size(); i++)
dims[i] = pd.desc().data.dims[i];
mkldnn::memory::format cpp_format = static_cast<mkldnn::memory::format>(
pd.desc().data.format);
mkldnn::memory::data_type cpp_type = static_cast<mkldnn::memory::data_type>(
pd.desc().data.data_type);
mkldnn::memory::desc data_md(dims, cpp_type, cpp_format);
mkldnn::memory::primitive_desc new_pd(data_md, pd.get_engine());

std::shared_ptr<mkldnn::memory> ret(new mkldnn::memory(new_pd, off_addr));
MKLDNNStream::Get()->RegisterMem(ret);
return ret.get();
} else {
return ptr_->Mkl_mem_.get();
}
} else {
// We don't support converting sparse format.
return nullptr;
Expand All @@ -629,6 +589,8 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
return;
}

// This doesn't work on sliced NDArray yet.
CHECK_EQ(byte_offset_, 0);
MKLDNNStream *stream = MKLDNNStream::Get();
ptr_->SetMKLMem(shape_, dtype_);
stream->RegisterMem(ptr_->Mkl_mem_);
Expand Down
5 changes: 5 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ const mkldnn::memory *GetWeights(const NDArray &arr,
const mkldnn::engine &engine,
int num_groups = 1);

mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc);
mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd,
mkldnn_memory_format_t format);


} // namespace mxnet
#endif
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BASE_INL_H_
74 changes: 74 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,80 @@ const mkldnn::memory *GetWeights(const NDArray &arr,
}
}

// Maps an MKLDNN memory descriptor to the "default" (plain, MXNet-native)
// layout for its rank: the format itself for 1-D, oi for 2-D (transposing
// io back), nchw for 4-D data layouts, oihw for 4-D weight layouts, and
// goihw for 5-D (grouped-convolution) weight layouts.
// Returns mkldnn_format_undef after LOG(FATAL) on unrecognized input.
mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc) {
  if (desc.data.ndims == 1) {
    // A 1-D tensor has only one possible layout.
    return desc.data.format;
  } else if (desc.data.ndims == 2) {
    // mkldnn_io is the transposed 2-D layout; normalize it to mkldnn_oi.
    if (desc.data.format == mkldnn_io)
      return mkldnn_oi;
    else
      return desc.data.format;
  } else if (desc.data.ndims == 4) {
    switch (desc.data.format) {
      // 4-D activation layouts (plain and blocked) all normalize to nchw.
      case mkldnn_nchw:
      case mkldnn_nhwc:
      case mkldnn_chwn:
      case mkldnn_nChw8c:
      case mkldnn_nChw16c:
        return mkldnn_nchw;
      // 4-D weight layouts (plain and blocked) all normalize to oihw.
      case mkldnn_oihw:
      case mkldnn_ihwo:
      case mkldnn_hwio:
      case mkldnn_OIhw8i8o:
      case mkldnn_OIhw16i16o:
      case mkldnn_OIhw8i16o2i:
      case mkldnn_OIhw8o16i2o:
      case mkldnn_OIhw8o8i:
      case mkldnn_OIhw16o16i:
      case mkldnn_IOhw16o16i:
      case mkldnn_Oihw8o:
      case mkldnn_Oihw16o:
      case mkldnn_Ohwi8o:
      case mkldnn_Ohwi16o:
      case mkldnn_OhIw16o4i:
        return mkldnn_oihw;
      default:
        LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format;
        return mkldnn_format_undef;
    }
  } else if (desc.data.ndims == 5) {
    switch (desc.data.format) {
      // 5-D grouped-weight layouts all normalize to goihw.
      case mkldnn_goihw:
      case mkldnn_gOIhw8i8o:
      case mkldnn_gOIhw16i16o:
      case mkldnn_gOIhw8i16o2i:
      case mkldnn_gOIhw8o16i2o:
      case mkldnn_gOIhw8o8i:
      case mkldnn_gOIhw16o16i:
      case mkldnn_gIOhw16o16i:
      case mkldnn_gOihw8o:
      case mkldnn_gOihw16o:
      case mkldnn_gOhwi8o:
      case mkldnn_gOhwi16o:
      case mkldnn_gOhIw16o4i:
        return mkldnn_goihw;
      default:
        // Fixed: message previously said "4 dimensions" in this 5-D branch.
        LOG(FATAL) << "Unknown MKLDNN format for 5 dimensions: " << desc.data.format;
        return mkldnn_format_undef;
    }
  } else {
    // 3-D and >5-D descriptors are not handled by this integration.
    LOG(FATAL) << "Unsupported dimensions: " << desc.data.ndims;
    return mkldnn_format_undef;
  }
}

// Builds a new memory primitive descriptor that keeps pd's dimensions and
// data type but substitutes the requested memory format, on pd's engine.
// Used to describe the same buffer reinterpreted in another layout
// (e.g. the default layout from GetDefaultFormat) for reorders.
mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd,
                                                mkldnn_memory_format_t format) {
  auto desc = pd.desc();
  const int ndims = desc.data.ndims;
  // Copy the raw dimension array into the C++ dims vector.
  mkldnn::memory::dims dims(ndims);
  for (int i = 0; i < ndims; ++i)
    dims[i] = desc.data.dims[i];
  auto cpp_format = static_cast<mkldnn::memory::format>(format);
  auto cpp_type = static_cast<mkldnn::memory::data_type>(desc.data.data_type);
  mkldnn::memory::desc data_md(dims, cpp_type, cpp_format);
  return mkldnn::memory::primitive_desc(data_md, pd.get_engine());
}

} // namespace mxnet

#endif
34 changes: 11 additions & 23 deletions src/operator/tensor/cast_storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,33 +41,21 @@ void CastStorageMKLDnsImpl(const OpContext& ctx, const NDArray& src, TBlob* dns)
CHECK_EQ(ctx.run_ctx.ctx.dev_mask(), Context::kCPU);
CHECK(src.shape() == dns->shape_);
CHECK_EQ(src.dtype(), dns->type_flag_);
// This converts the source data to the default format and copy the data to
// the destination.
// This converts the source data to the default format and write the data to
// the destination directly.
std::vector<mkldnn::primitive> net;
auto src_mkldnn = src.GetMKLDNNData();
auto src_pd = src_mkldnn->get_primitive_desc();
mkldnn::memory::dims dims(dns->shape_.ndim());
for (size_t i = 0; i < dims.size(); i++)
dims[i] = dns->shape_[i];
mkldnn::memory::format format = mkldnn::memory::format::format_undef;
switch (dims.size()) {
case 1: format = mkldnn::memory::format::x; break;
case 2: format = mkldnn::memory::format::nc; break;
case 4: format = mkldnn::memory::format::nchw; break;
// This isn't the right layout when the data has 5 dimensions in MXNet.
// MXNet interprets 5 dimensions as ncdhw, but MKLDNN doesn't have
// a corresponding format.
case 5: format = mkldnn::memory::format::goihw; break;
auto def_format = GetDefaultFormat(src_pd.desc());
if (def_format != src_pd.desc().data.format) {
auto dst_pd = GetPrimitiveDesc(src_pd, def_format);
mkldnn::memory dst_mkldnn(dst_pd, dns->dptr_);
net.push_back(mkldnn::reorder(*src_mkldnn, dst_mkldnn));
mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
} else {
const TBlob &src_blob = src.data();
memcpy(dns->dptr_, src_blob.dptr_, src.shape().Size() * get_type_size(dns->type_flag_));
}
CHECK_NE(format, mkldnn::memory::format::format_undef);
mkldnn::memory::format cpp_format = static_cast<mkldnn::memory::format>(format);
mkldnn::memory::data_type cpp_type = static_cast<mkldnn::memory::data_type>(
src_pd.desc().data.data_type);
mkldnn::memory::desc data_md(dims, cpp_type, cpp_format);
mkldnn::memory::primitive_desc dst_pd(data_md, src_pd.get_engine());
mkldnn::memory dst_mkldnn(dst_pd, dns->dptr_);
net.push_back(mkldnn::reorder(*src_mkldnn, dst_mkldnn));
mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
}

void CastStorageDnsMKLImpl(const OpContext& ctx, const NDArray& src, const NDArray &dst) {
Expand Down

0 comments on commit a549065

Please sign in to comment.