Skip to content

Commit

Permalink
Use memcpy if NDArray uses default format.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Dec 14, 2017
1 parent f00ced3 commit ad5dc2d
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/operator/tensor/cast_storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,15 @@ void CastStorageMKLDnsImpl(const OpContext& ctx, const NDArray& src, TBlob* dns)
auto src_mkldnn = src.GetMKLDNNData();
auto src_pd = src_mkldnn->get_primitive_desc();
auto def_format = GetDefaultFormat(src_pd.desc());
auto dst_pd = GetPrimitiveDesc(src_pd, def_format);
mkldnn::memory dst_mkldnn(dst_pd, dns->dptr_);
net.push_back(mkldnn::reorder(*src_mkldnn, dst_mkldnn));
mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
if (def_format != src_pd.desc().data.format) {
auto dst_pd = GetPrimitiveDesc(src_pd, def_format);
mkldnn::memory dst_mkldnn(dst_pd, dns->dptr_);
net.push_back(mkldnn::reorder(*src_mkldnn, dst_mkldnn));
mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
} else {
const TBlob &src_blob = src.data();
memcpy(dns->dptr_, src_blob.dptr_, src.shape().Size() * get_type_size(dns->type_flag_));
}
}

void CastStorageDnsMKLImpl(const OpContext& ctx, const NDArray& src, const NDArray &dst) {
Expand Down

0 comments on commit ad5dc2d

Please sign in to comment.