Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace custom IOHW -> OIHW reorder with build-in oneDNN reorder #37175

Merged
merged 10 commits into from
Nov 17, 2021
44 changes: 13 additions & 31 deletions paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class ConvTransposeMKLDNNHandlerT
PADDLE_ENFORCE_EQ(
filter->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"The filter tensor's laytout should be %d, but got %d.",
"The filter tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, filter->layout()));
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
Expand Down Expand Up @@ -229,32 +229,21 @@ class ConvTransposeMKLDNNHandlerT
weights_tz, platform::MKLDNNGetDataType<K>(),
(g == 1) ? filter->format() : MKLDNNMemoryFormat::goihw);

auto iohw_weights_tz = framework::vectorize(filter->dims());
// Custom Reorder from IOHW to OIHW
auto iohw2oihw_reorder =
[&iohw_weights_tz](const K* filter_data) -> std::shared_ptr<K> {
int o = iohw_weights_tz[1];
int c = iohw_weights_tz[0];
int h = iohw_weights_tz[2];
int w = iohw_weights_tz[3];
std::shared_ptr<K> reordered_filter_data(new K[o * c * h * w](),
std::default_delete<K[]>());
for (int i = 0; i < c; ++i) {
for (int j = 0; j < o; ++j) {
int in_offset = j * h * w + i * o * h * w;
int out_offset = j * c * h * w + i * h * w;
std::memcpy(&(reordered_filter_data.get())[out_offset],
&filter_data[in_offset], h * w * sizeof(K));
}
}

return reordered_filter_data;
};
platform::MKLDNNEngine engine = dev_ctx.GetEngine();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
platform::MKLDNNEngine engine = dev_ctx.GetEngine();
const platform::MKLDNNEngine& engine = dev_ctx.GetEngine();

platform::MKLDNNStream engine_stream(engine);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
platform::MKLDNNStream engine_stream(engine);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

auto source_data_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<K>(), MKLDNNMemoryFormat::iohw);
auto reordered_data_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<K>(), MKLDNNMemoryFormat::oihw);
auto source_data_mem = platform::MKLDNNMemory(
source_data_md, engine, platform::to_void_cast<K>(filter_data));
auto reordered_data_mem = platform::MKLDNNMemory(reordered_data_md, engine);
platform::Reorder(source_data_mem, reordered_data_mem, engine);
engine_stream.wait();

return this->template AcquireMemoryWithReorder<K>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why "AcquireMemoryWithReorder" is called just after doing a reorder?

dev_ctx, user_src_md, this->fwd_pd_->weights_desc(),
platform::to_void_cast<K>(filter_data), key, "@weights_mem_p", is_test_,
iohw2oihw_reorder);
reordered_data_mem.get_data_handle(), key, "@weights_mem_p", is_test_);
}

template <typename F = T>
Expand All @@ -263,7 +252,6 @@ class ConvTransposeMKLDNNHandlerT
const mkldnn::memory::desc& user_md,
const mkldnn::memory::desc& target_md, void* ptr, const std::string& key,
const std::string& suffix, bool is_persistent = false,
std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {},
const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
const auto target_key = key + suffix + "_target";
const auto key_reorder_p = key + suffix + "reorder_p";
Expand All @@ -273,12 +261,6 @@ class ConvTransposeMKLDNNHandlerT
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(target_key));

if (target_memory_p == nullptr) {
if (custom_reorder_func) {
auto reordered_data =
custom_reorder_func(reinterpret_cast<const F*>(ptr));
dev_ctx.SetBlob(key_reorder_p + "-custom_reorder", reordered_data);
ptr = reinterpret_cast<void*>(reordered_data.get());
}
auto user_memory_p =
std::make_shared<dnnl::memory>(user_md, this->engine_, ptr);
if (user_md != target_md) {
Expand Down