[DNNL] Add bfloat16 type support for dnnl conv2d kernel (apache#11902)
Qianshui-Jiang authored and blackkker committed Jul 7, 2022
1 parent 575c249 commit addfe2c
Showing 3 changed files with 125 additions and 71 deletions.
5 changes: 5 additions & 0 deletions python/tvm/contrib/dnnl.py
@@ -113,6 +113,9 @@ def dnnl_conv2d(
else:
dilation_h, dilation_w = dilation

pre_cast = src.dtype == "float32"
post_cast = out_dtype == "float32"

if channel_last:
batch, in_height, in_width, _ = src.shape
kernel_h, kernel_w, _, num_filter = weights.shape
@@ -150,6 +153,8 @@ def dnnl_conv2d(
stride[1],
groups,
channel_last,
pre_cast,
post_cast,
),
name="C",
dtype=out_dtype,
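For reference, a minimal sketch of what the two new flags encode, using an illustrative helper name (cast_flags is not part of the patch): dnnl_conv2d derives both flags from the tensor dtypes, so the runtime side knows whether it is on the established float32 path or the new bfloat16 path.

    # Sketch of the flag derivation used by dnnl_conv2d, shown in isolation.
    def cast_flags(src_dtype, out_dtype):
        pre_cast = src_dtype == "float32"   # source tensor arrives as f32
        post_cast = out_dtype == "float32"  # caller expects an f32 result back
        return pre_cast, post_cast

    assert cast_flags("float32", "float32") == (True, True)      # existing f32 path
    assert cast_flags("bfloat16", "bfloat16") == (False, False)  # new bf16 path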
51 changes: 32 additions & 19 deletions src/runtime/contrib/dnnl/dnnl.cc
@@ -81,7 +81,7 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) {
void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_,
int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph0_, int p_Pw0_, int p_Ph1_,
int p_Pw1_, int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_, primitive_attr attr,
bool channel_last) {
bool channel_last, bool pre_cast, bool post_cast) {
using tag = memory::format_tag;
using dt = memory::data_type;
engine eng(engine::kind::cpu, 0);
@@ -98,20 +98,31 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in
memory::dims conv2d_padding1 = {p_Ph1_, p_Pw1_};

auto user_src_memory =
memory({{conv2d_src_tz}, dt::f32, channel_last ? tag::nhwc : tag::nchw}, eng, data);
auto user_weights_memory =
memory({{conv2d_weights_tz}, dt::f32, channel_last ? tag::hwio : tag::oihw}, eng, weights);
memory({{conv2d_src_tz}, pre_cast ? dt::f32 : dt::bf16, channel_last ? tag::nhwc : tag::nchw},
eng, data);
auto user_weights_memory = memory({{conv2d_weights_tz},
(pre_cast && post_cast) ? dt::f32 : dt::bf16,
channel_last ? tag::hwio : tag::oihw},
eng, weights);
if (p_G_ > 1)
user_weights_memory = memory(
{{conv2d_weights_tz}, dt::f32, channel_last ? tag::ghwio : tag::goihw}, eng, weights);
auto conv2d_user_bias_memory = memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, bias);
auto user_dst_memory =
memory({{conv2d_dst_tz}, dt::f32, channel_last ? tag::nhwc : tag::nchw}, eng, out);

auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any);
auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any);
auto conv2d_weights_md = memory::desc({conv2d_weights_tz}, dt::f32, tag::any);
auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::any);
user_weights_memory = memory({{conv2d_weights_tz},
(pre_cast && post_cast) ? dt::f32 : dt::bf16,
channel_last ? tag::ghwio : tag::goihw},
eng, weights);
auto conv2d_user_bias_memory =
memory({{conv2d_bias_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::x}, eng, bias);
auto user_dst_memory = memory(
{{conv2d_dst_tz}, post_cast ? dt::f32 : dt::bf16, channel_last ? tag::nhwc : tag::nchw}, eng,
out);

auto conv2d_src_md =
memory::desc({conv2d_src_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::any);
auto conv2d_bias_md =
memory::desc({conv2d_bias_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::any);
auto conv2d_weights_md =
memory::desc({conv2d_weights_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::any);
auto conv2d_dst_md =
memory::desc({conv2d_dst_tz}, (pre_cast && post_cast) ? dt::f32 : dt::bf16, tag::any);

auto conv2d_desc = convolution_forward::desc(
prop_kind::forward_inference, algorithm::convolution_direct, conv2d_src_md, conv2d_weights_md,
@@ -161,8 +172,8 @@ extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, i
primitive_attr attr;
std::vector<float> bias(p_O_, 0);
return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr,
false);
p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr, false,
true, true);
}

primitive_attr create_attr_with_relu_post_op() {
@@ -182,7 +193,7 @@ extern "C" void dnnl_fused_conv2d_relu(float* data, float* weights, float* out,
std::vector<float> bias(p_O_, 0);
return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_,
p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
create_attr_with_relu_post_op(), false);
create_attr_with_relu_post_op(), false, true, true);
}

extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out,
@@ -192,7 +203,7 @@ extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float*
int p_Sw_) {
return dnnl_conv2d_common(data, weights, bias, out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph0_,
p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
create_attr_with_relu_post_op(), false);
create_attr_with_relu_post_op(), false, true, true);
}

extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_) {
@@ -345,6 +356,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d").set_body([](TVMArgs args, TVMRetV
int p_Ph0_ = args[3], p_Pw0_ = args[4], p_Ph1_ = args[5], p_Pw1_ = args[6], p_Sh_ = args[7],
p_Sw_ = args[8], p_G_ = args[9];
bool channel_last = args[10];
bool pre_cast = args[11];
bool post_cast = args[12];

int p_N_ = input->shape[0], p_C_ = input->shape[1], p_H_ = input->shape[2],
p_W_ = input->shape[3], p_O_ = output->shape[1], p_Kh_ = weights->shape[2],
@@ -365,7 +378,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d").set_body([](TVMArgs args, TVMRetV
return dnnl_conv2d_common(static_cast<float*>(input->data), static_cast<float*>(weights->data),
bias.data(), static_cast<float*>(output->data), p_N_, p_C_, p_H_, p_W_,
p_O_, p_G_, p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_,
attr, channel_last);
attr, channel_last, pre_cast, post_cast);
});

} // namespace contrib
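A compact way to read the data-type choices in dnnl_conv2d_common above, mirrored in Python for illustration (dnnl_data_types is not a function from the patch): the user src and dst descriptors follow pre_cast and post_cast individually, while the weights, bias and all primitive descriptors stay f32 only when both flags are set.

    # Illustrative mirror of the dt::f32 / dt::bf16 selection in dnnl_conv2d_common.
    def dnnl_data_types(pre_cast, post_cast):
        full_f32 = pre_cast and post_cast
        return {
            "user_src": "f32" if pre_cast else "bf16",
            "user_dst": "f32" if post_cast else "bf16",
            "user_weights": "f32" if full_f32 else "bf16",
            "user_bias": "f32" if full_f32 else "bf16",
            "primitive_mds": "f32" if full_f32 else "bf16",  # src/weights/bias/dst memory::desc
        }

    print(dnnl_data_types(True, True))    # all f32, unchanged behaviour
    print(dnnl_data_types(False, False))  # all bf16, the new path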
140 changes: 88 additions & 52 deletions tests/python/relay/test_op_level2.py
@@ -2003,6 +2003,22 @@ def test_conv2d_rocm_sdot4():
np.testing.assert_equal(out, ref)


def np_float2tvm_bf16(arr):
"""Convert a numpy array of float to a TVM array
of bf16"""
orig = arr.view("<u4")
bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
nparr = np.right_shift(orig + bias, 16).astype("uint16")
return tvm.nd.empty(nparr.shape, "bfloat16").copyfrom(nparr)


def np_bf162np_float(arr):
"""Convert a numpy array of bf16 (uint16) to a numpy array
of float"""
u32 = np.left_shift(arr.astype("uint32"), 16)
return u32.view("<f4")


@tvm.testing.requires_x86
def test_conv2d_nchw_dnnl():
if not tvm.get_global_func("tvm.contrib.dnnl.conv2d", allow_missing=True):
@@ -2016,39 +2032,49 @@ def test_conv2d_nchw_dnnl():
padding = (1, 1)
strides = (1, 1)

data = relay.var("data", shape=d_shape, dtype="float32")
weight = relay.var("weight", shape=w_shape, dtype="float32")
out_channel = w_shape[0]
conv2d = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=w_shape[2:],
channels=out_channel,
padding=padding,
strides=strides,
out_dtype="float32",
)
def get_subgraph(dtype):
data = relay.var("data", shape=d_shape, dtype=dtype)
weight = relay.var("weight", shape=w_shape, dtype=dtype)
out_channel = w_shape[0]
conv2d = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=w_shape[2:],
channels=out_channel,
padding=padding,
strides=strides,
out_dtype=dtype,
)
return conv2d

mod = tvm.IRModule.from_expr(conv2d)
for t in ["float32", "bfloat16"]:
mod = tvm.IRModule.from_expr(get_subgraph(t))

data_np = np.random.uniform(1, 10, d_shape).astype("float32")
weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32")
data_np = np.random.uniform(1, 10, d_shape).astype("float32")
weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32")
ref = tvm.topi.testing.conv2d_nchw_python(data_np, weight_np, strides, padding)

target = "llvm -mcpu=skylake-avx512 -libs=dnnl"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params={"weight": weight_np})
if t == "bfloat16":
data_np = np_float2tvm_bf16(data_np)
weight_np = np_float2tvm_bf16(weight_np)

dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
target = "llvm -mcpu=skylake-avx512 -libs=dnnl"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params={"weight": weight_np})

runtime.set_input("data", data_np)
runtime.run()
dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

out = runtime.get_output(0).numpy()
runtime.set_input("data", data_np)
runtime.run()

ref = tvm.topi.testing.conv2d_nchw_python(data_np, weight_np, strides, padding)
out = runtime.get_output(0).numpy()

np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
if t == "bfloat16":
out = np_bf162np_float(out)
np.testing.assert_allclose(out, ref, rtol=1e-2)
else:
np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)


@tvm.testing.requires_x86
@@ -2064,41 +2090,51 @@ def test_conv2d_nhwc_dnnl():
padding = (1, 1)
strides = (1, 1)

data = relay.var("data", shape=d_shape, dtype="float32")
weight = relay.var("weight", shape=w_shape, dtype="float32")
out_channel = w_shape[3]
conv2d = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=w_shape[:2],
channels=out_channel,
padding=padding,
strides=strides,
out_dtype="float32",
data_layout="NHWC",
kernel_layout="HWIO",
)
def get_subgraph(dtype):
data = relay.var("data", shape=d_shape, dtype=dtype)
weight = relay.var("weight", shape=w_shape, dtype=dtype)
out_channel = w_shape[3]
conv2d = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=w_shape[:2],
channels=out_channel,
padding=padding,
strides=strides,
out_dtype=dtype,
data_layout="NHWC",
kernel_layout="HWIO",
)
return conv2d

mod = tvm.IRModule.from_expr(conv2d)
for t in ["float32", "bfloat16"]:
mod = tvm.IRModule.from_expr(get_subgraph(t))

data_np = np.random.uniform(1, 10, d_shape).astype("float32")
weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32")
data_np = np.random.uniform(1, 10, d_shape).astype("float32")
weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32")
ref = tvm.topi.testing.conv2d_nhwc_python(data_np, weight_np, strides, padding)

target = "llvm -mcpu=skylake-avx512 -libs=dnnl"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params={"weight": weight_np})
if t == "bfloat16":
data_np = np_float2tvm_bf16(data_np)
weight_np = np_float2tvm_bf16(weight_np)

dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
target = "llvm -mcpu=skylake-avx512 -libs=dnnl"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params={"weight": weight_np})

runtime.set_input("data", data_np)
runtime.run()
dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

out = runtime.get_output(0).numpy()
runtime.set_input("data", data_np)
runtime.run()

ref = tvm.topi.testing.conv2d_nhwc_python(data_np, weight_np, strides, padding)
out = runtime.get_output(0).numpy()

np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
if t == "bfloat16":
out = np_bf162np_float(out)
np.testing.assert_allclose(out, ref, rtol=1e-2)
else:
np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
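The rtol=1e-2 tolerance used for the bfloat16 cases follows from the conversion helpers added in this file; a self-contained numpy round trip (illustrative only, skipping the tvm.nd step of np_float2tvm_bf16) shows that the conversion error stays within a fraction of a percent.

    import numpy as np

    def np_float2np_bf16(arr):
        """float32 -> bf16 bit pattern (uint16), round-to-nearest-even, as in np_float2tvm_bf16."""
        orig = arr.view("<u4")
        bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
        return np.right_shift(orig + bias, 16).astype("uint16")

    def np_bf162np_float(arr):
        """bf16 bit pattern (uint16) -> float32, same as the helper above."""
        return np.left_shift(arr.astype("uint32"), 16).view("<f4")

    x = np.random.uniform(1, 10, (2, 3)).astype("float32")
    roundtrip = np_bf162np_float(np_float2np_bf16(x))
    # bf16 keeps about 8 significant bits, so a 1e-2 relative tolerance is comfortably met.
    np.testing.assert_allclose(roundtrip, x, rtol=1e-2)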