Replace custom IOHW -> OIHW reorder with built-in oneDNN reorder #37175

Merged on Nov 17, 2021 (10 commits)
paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc (13 additions & 31 deletions)

@@ -68,7 +68,7 @@ class ConvTransposeMKLDNNHandlerT
     PADDLE_ENFORCE_EQ(
         filter->layout(), DataLayout::kMKLDNN,
         platform::errors::InvalidArgument(
-            "The filter tensor's laytout should be %d, but got %d.",
+            "The filter tensor's layout should be %d, but got %d.",
             DataLayout::kMKLDNN, filter->layout()));
     PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
                       platform::errors::InvalidArgument(
@@ -229,32 +229,21 @@ class ConvTransposeMKLDNNHandlerT
         weights_tz, platform::MKLDNNGetDataType<K>(),
         (g == 1) ? filter->format() : MKLDNNMemoryFormat::goihw);

-    auto iohw_weights_tz = framework::vectorize(filter->dims());
-    // Custom Reorder from IOHW to OIHW
-    auto iohw2oihw_reorder =
-        [&iohw_weights_tz](const K* filter_data) -> std::shared_ptr<K> {
-      int o = iohw_weights_tz[1];
-      int c = iohw_weights_tz[0];
-      int h = iohw_weights_tz[2];
-      int w = iohw_weights_tz[3];
-      std::shared_ptr<K> reordered_filter_data(new K[o * c * h * w](),
-                                               std::default_delete<K[]>());
-      for (int i = 0; i < c; ++i) {
-        for (int j = 0; j < o; ++j) {
-          int in_offset = j * h * w + i * o * h * w;
-          int out_offset = j * c * h * w + i * h * w;
-          std::memcpy(&(reordered_filter_data.get())[out_offset],
-                      &filter_data[in_offset], h * w * sizeof(K));
-        }
-      }
-
-      return reordered_filter_data;
-    };
+    const platform::MKLDNNEngine& engine = dev_ctx.GetEngine();
Contributor review comment:

I actually think this block of code is not needed. What should be modified is the fragment filter->format() to dnnl::format_tag::iohw (guess), and then AcquireMemoryWithReorder will take care of reordering from IOHW -> blocked_optimized format. In your current changes you have two reorders: IOHW -> OIHW and then OIHW -> blocked_optimized_format. One issue here is that filter->format() is (probably) returning OIHW while in fact the data is IOHW.
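A minimal sketch of that suggestion, reusing the names already in this diff; whether MKLDNNMemoryFormat::iohw is the right tag here is the reviewer's guess, not something the PR verifies:

    // Sketch, not part of the PR: describe the user memory as IOHW and let
    // AcquireMemoryWithReorder do the single IOHW -> blocked-format reorder.
    auto user_src_md = platform::MKLDNNMemDesc(
        weights_tz, platform::MKLDNNGetDataType<K>(),
        MKLDNNMemoryFormat::iohw);

    return this->template AcquireMemoryWithReorder<K>(
        dev_ctx, user_src_md, this->fwd_pd_->weights_desc(),
        platform::to_void_cast<K>(filter_data), key, "@weights_mem_p",
        is_test_);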

+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+    auto source_data_md = platform::MKLDNNMemDesc(
+        weights_tz, platform::MKLDNNGetDataType<K>(), MKLDNNMemoryFormat::iohw);
+    auto reordered_data_md = platform::MKLDNNMemDesc(
+        weights_tz, platform::MKLDNNGetDataType<K>(), MKLDNNMemoryFormat::oihw);
+    auto source_data_mem = platform::MKLDNNMemory(
+        source_data_md, engine, platform::to_void_cast<K>(filter_data));
+    auto reordered_data_mem = platform::MKLDNNMemory(reordered_data_md, engine);
+    platform::Reorder(source_data_mem, reordered_data_mem, engine);
+    astream.wait();

     return this->template AcquireMemoryWithReorder<K>(
Contributor review comment:

Why is AcquireMemoryWithReorder called just after doing a reorder?

         dev_ctx, user_src_md, this->fwd_pd_->weights_desc(),
-        platform::to_void_cast<K>(filter_data), key, "@weights_mem_p", is_test_,
-        iohw2oihw_reorder);
+        reordered_data_mem.get_data_handle(), key, "@weights_mem_p", is_test_);
   }
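The question above points at a double data movement: every oneDNN reorder primitive physically copies the tensor, so the explicit IOHW -> OIHW reorder plus the OIHW -> blocked reorder inside AcquireMemoryWithReorder touches the weights twice. A standalone illustration in the raw dnnl C++ API (assuming the dnnl 2.x API that Paddle's platform wrappers sit on; the function and variable names here are illustrative, not from the PR):

    #include "dnnl.hpp"

    // Two reorder primitives mean two full copies of the weights.
    void two_step_reorder(dnnl::engine& eng, dnnl::stream& strm,
                          float* iohw_data, dnnl::memory::dims weights_tz,
                          const dnnl::memory::desc& blocked_md) {
      using tag = dnnl::memory::format_tag;
      using dt = dnnl::memory::data_type;
      dnnl::memory::desc iohw_md(weights_tz, dt::f32, tag::iohw);
      dnnl::memory::desc oihw_md(weights_tz, dt::f32, tag::oihw);

      dnnl::memory src(iohw_md, eng, iohw_data);  // user weights, IOHW layout
      dnnl::memory mid(oihw_md, eng);             // first copy: IOHW -> OIHW
      dnnl::memory dst(blocked_md, eng);          // second copy: OIHW -> blocked

      dnnl::reorder(src, mid).execute(strm, src, mid);
      dnnl::reorder(mid, dst).execute(strm, mid, dst);
      strm.wait();
      // A single dnnl::reorder(src, dst) would convert IOHW -> blocked
      // in one copy, which is what the review comments suggest.
    }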

  template <typename F = T>
@@ -263,7 +252,6 @@ class ConvTransposeMKLDNNHandlerT
       const mkldnn::memory::desc& user_md,
       const mkldnn::memory::desc& target_md, void* ptr, const std::string& key,
       const std::string& suffix, bool is_persistent = false,
-      std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {},
       const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
     const auto target_key = key + suffix + "_target";
     const auto key_reorder_p = key + suffix + "reorder_p";
@@ -273,12 +261,6 @@ class ConvTransposeMKLDNNHandlerT
         std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(target_key));

     if (target_memory_p == nullptr) {
-      if (custom_reorder_func) {
-        auto reordered_data =
-            custom_reorder_func(reinterpret_cast<const F*>(ptr));
-        dev_ctx.SetBlob(key_reorder_p + "-custom_reorder", reordered_data);
-        ptr = reinterpret_cast<void*>(reordered_data.get());
-      }
       auto user_memory_p =
           std::make_shared<dnnl::memory>(user_md, this->engine_, ptr);
       if (user_md != target_md) {
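The rest of the helper is truncated above, but the visible lines show its pattern: look up a cached target memory by key; on a miss, wrap the user pointer in a memory with user_md and, when user_md differs from target_md, reorder into the target layout. A hypothetical standalone sketch of that acquire-with-reorder caching pattern, with an unordered_map standing in for the dev_ctx blob store (assumed dnnl 2.x API; none of these names are from the PR):

    #include <memory>
    #include <string>
    #include <unordered_map>
    #include "dnnl.hpp"

    std::shared_ptr<dnnl::memory> AcquireWithReorder(
        std::unordered_map<std::string, std::shared_ptr<dnnl::memory>>& cache,
        const std::string& key, const dnnl::memory::desc& user_md,
        const dnnl::memory::desc& target_md, void* ptr, dnnl::engine& eng,
        dnnl::stream& strm) {
      auto it = cache.find(key);
      if (it != cache.end()) return it->second;  // cache hit: reuse as-is

      auto user_mem = std::make_shared<dnnl::memory>(user_md, eng, ptr);
      auto target_mem = user_mem;  // no reorder needed if formats match
      if (user_md != target_md) {
        target_mem = std::make_shared<dnnl::memory>(target_md, eng);
        dnnl::reorder(*user_mem, *target_mem)
            .execute(strm, *user_mem, *target_mem);
        strm.wait();
      }
      cache[key] = target_mem;
      return target_mem;
    }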