diff --git a/python/tvm/contrib/dnnl.py b/python/tvm/contrib/dnnl.py index b9b77a2d20ae..a72219994330 100644 --- a/python/tvm/contrib/dnnl.py +++ b/python/tvm/contrib/dnnl.py @@ -113,6 +113,9 @@ def dnnl_conv2d( else: dilation_h, dilation_w = dilation + pre_cast = src.dtype == "float32" + post_cast = out_dtype == "float32" + if channel_last: batch, in_height, in_width, _ = src.shape kernel_h, kernel_w, _, num_filter = weights.shape @@ -150,6 +153,8 @@ def dnnl_conv2d( stride[1], groups, channel_last, + pre_cast, + post_cast, ), name="C", dtype=out_dtype, diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index f3c3e9d0ea21..9ef464064753 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -81,7 +81,7 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) { void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_, int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph0_, int p_Pw0_, int p_Ph1_, int p_Pw1_, int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_, primitive_attr attr, - bool channel_last) { + bool channel_last, bool pre_cast, bool post_cast) { using tag = memory::format_tag; using dt = memory::data_type; engine eng(engine::kind::cpu, 0); @@ -98,20 +98,31 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in memory::dims conv2d_padding1 = {p_Ph1_, p_Pw1_}; auto user_src_memory = - memory({{conv2d_src_tz}, dt::f32, channel_last ? tag::nhwc : tag::nchw}, eng, data); - auto user_weights_memory = - memory({{conv2d_weights_tz}, dt::f32, channel_last ? tag::hwio : tag::oihw}, eng, weights); + memory({{conv2d_src_tz}, pre_cast ? dt::f32 : dt::bf16, channel_last ? tag::nhwc : tag::nchw}, + eng, data); + auto user_weights_memory = memory({{conv2d_weights_tz}, + (pre_cast && post_cast) ? dt::f32 : dt::bf16, + channel_last ? tag::hwio : tag::oihw}, + eng, weights); if (p_G_ > 1) - user_weights_memory = memory( - {{conv2d_weights_tz}, dt::f32, channel_last ? tag::ghwio : tag::goihw}, eng, weights); - auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias); - auto user_dst_memory = - memory({{conv2d_dst_tz}, dt::f32, channel_last ? tag::nhwc : tag::nchw}, eng, out); - - auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any); - auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any); - auto conv2d_weights_md = memory::desc({conv2d_weights_tz}, dt::f32, tag::any); - auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::any); + user_weights_memory = memory({{conv2d_weights_tz}, + (pre_cast && post_cast) ? dt::f32 : dt::bf16, + channel_last ? tag::ghwio : tag::goihw}, + eng, weights); + auto conv2d_user_bias_memory = + memory({{conv2d_bias_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::x}, eng, bias); + auto user_dst_memory = memory( + {{conv2d_dst_tz}, post_cast ? dt::f32 : dt::bf16, channel_last ? tag::nhwc : tag::nchw}, eng, + out); + + auto conv2d_src_md = + memory::desc({conv2d_src_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::any); + auto conv2d_bias_md = + memory::desc({conv2d_bias_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::any); + auto conv2d_weights_md = + memory::desc({conv2d_weights_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::any); + auto conv2d_dst_md = + memory::desc({conv2d_dst_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::any); auto conv2d_desc = convolution_forward::desc( prop_kind::forward_inference, algorithm::convolution_direct, conv2d_src_md, conv2d_weights_md, @@ -161,8 +172,8 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, i primitive_attr attr; std::vector bias(p_O_, 0); return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, - p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr, - false); + p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr, false, + true, true); } primitive_attr create_attr_with_relu_post_op() { @@ -182,7 +193,7 @@ extern "C" void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, std::vector bias(p_O_, 0); return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, - create_attr_with_relu_post_op(), false); + create_attr_with_relu_post_op(), false, true, true); } extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out, @@ -192,7 +203,7 @@ extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* int p_Sw_) { return dnnl_conv2d_common(data, weights, bias, out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, - create_attr_with_relu_post_op(), false); + create_attr_with_relu_post_op(), false, true, true); } extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_) { @@ -345,6 +356,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d").set_body([](TVMArgs args, TVMRetV int p_Ph0_ = args[3], p_Pw0_ = args[4], p_Ph1_ = args[5], p_Pw1_ = args[6], p_Sh_ = args[7], p_Sw_ = args[8], p_G_ = args[9]; bool channel_last = args[10]; + bool pre_cast = args[11]; + bool post_cast = args[12]; int p_N_ = input->shape[0], p_C_ = input->shape[1], p_H_ = input->shape[2], p_W_ = input->shape[3], p_O_ = output->shape[1], p_Kh_ = weights->shape[2], @@ -365,7 +378,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d").set_body([](TVMArgs args, TVMRetV return dnnl_conv2d_common(static_cast(input->data), static_cast(weights->data), bias.data(), static_cast(output->data), p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, - attr, channel_last); + attr, channel_last, pre_cast, post_cast); }); } // namespace contrib diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index dd6a54b959cc..84b72e4cffd2 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -2003,6 +2003,22 @@ def test_conv2d_rocm_sdot4(): np.testing.assert_equal(out, ref) +def np_float2tvm_bf16(arr): + """Convert a numpy array of float to a TVM array + of bf16""" + orig = arr.view("