From 0564d38e7965152f94d61528020aa19fac064b0a Mon Sep 17 00:00:00 2001
From: sunway
Date: Mon, 27 Sep 2021 03:47:16 +0800
Subject: [PATCH] add `multiply` and remove `subtract` for dnnl json runtime
 (#9120)

---
 python/tvm/relay/op/contrib/dnnl.py           |  1 -
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 15 ++++---
 tests/python/relay/test_json_runtime.py       | 45 +++++++++++++++++++
 3 files changed, 53 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 79bd02db164b..a2fdc19badab 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -64,7 +64,6 @@ def _func_wrapper(expr):
 _register_external_op_helper("nn.dense")
 _register_external_op_helper("nn.relu")
 _register_external_op_helper("add")
-_register_external_op_helper("subtract")
 _register_external_op_helper("multiply")
 
 
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 66378d74f5d7..b32d137a2566 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -113,7 +113,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
       } else if ("nn.relu" == op_name) {
         Relu(nid);
       } else if ("add" == op_name) {
-        Add(nid);
+        Binary(nid, dnnl::algorithm::binary_add);
+      } else if ("multiply" == op_name) {
+        Binary(nid, dnnl::algorithm::binary_mul);
       } else {
         LOG(FATAL) << "Unsupported op: " << op_name;
       }
@@ -355,7 +357,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}});
   }
 
-  void Add(const size_t& nid) {
+  void Binary(const size_t& nid, dnnl::algorithm algo) {
     auto node = nodes_[nid];
 
     // Memory and compute description.
@@ -377,11 +379,10 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     JSONGraphNodeEntry out_entry(nid, 0);
     auto out_memory = BindDNNLMemory(out_entry, out_md);
 
-    auto add_desc =
-        dnnl::binary::desc(dnnl::algorithm::binary_add, data_mds[0], data_mds[1], out_md);
-    auto add_prim_desc = dnnl::binary::primitive_desc(add_desc, engine_);
-    auto add = dnnl::binary(add_prim_desc);
-    net_.push_back(add);
+    auto binary_desc = dnnl::binary::desc(algo, data_mds[0], data_mds[1], out_md);
+    auto binary_prim_desc = dnnl::binary::primitive_desc(binary_desc, engine_);
+    auto binary = dnnl::binary(binary_prim_desc);
+    net_.push_back(binary);
 
     net_args_.push_back({{DNNL_ARG_SRC_0, data_memories[0]},
                          {DNNL_ARG_SRC_1, data_memories[1]},
diff --git a/tests/python/relay/test_json_runtime.py b/tests/python/relay/test_json_runtime.py
index 721271ac70f1..ca792204c835 100644
--- a/tests/python/relay/test_json_runtime.py
+++ b/tests/python/relay/test_json_runtime.py
@@ -216,6 +216,50 @@ def gen_add():
     check_result(mod, ref_mod, {"data0": data0, "data1": data1}, shape, tol=1e-5)
 
 
+def test_multiply():
+    """Test a subgraph with a single multiply operator."""
+    if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
+        print("skip because DNNL codegen is not available")
+        return
+
+    dtype = "float32"
+    shape = (10, 10)
+
+    def gen_multiply():
+        data0 = relay.var("data0", shape=shape, dtype=dtype)
+        data1 = relay.var("data1", shape=shape, dtype=dtype)
+        out = relay.multiply(data0, data1)
+
+        func = relay.Function([data0, data1], out)
+        func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_0")
+        glb_var = relay.GlobalVar("tvmgen_default_dnnl_0")
+        mod = tvm.IRModule()
+        mod[glb_var] = func
+        mod = transform.InferType()(mod)
+
+        data0 = relay.var("data0", shape=shape, dtype=dtype)
+        data1 = relay.var("data1", shape=shape, dtype=dtype)
+        main_f = relay.Function([data0, data1], glb_var(data0, data1))
+        mod["main"] = main_f
+        mod = transform.InferType()(mod)
+
+        data0 = relay.var("data0", shape=shape, dtype=dtype)
+        data1 = relay.var("data1", shape=shape, dtype=dtype)
+        out = relay.multiply(data0, data1)
+        main_f = relay.Function([data0, data1], out)
+        ref_mod = tvm.IRModule()
+        ref_mod["main"] = main_f
+        ref_mod = transform.InferType()(ref_mod)
+
+        return mod, ref_mod
+
+    mod, ref_mod = gen_multiply()
+
+    data0 = np.random.uniform(0, 1, shape).astype(dtype)
+    data1 = np.random.uniform(0, 1, shape).astype(dtype)
+    check_result(mod, ref_mod, {"data0": data0, "data1": data1}, shape, tol=1e-5)
+
+
 def test_relu():
     """Test a subgraph with a single ReLU operator."""
     if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True):
@@ -672,6 +716,7 @@ def test_partial_constant():
 if __name__ == "__main__":
     test_conv2d()
     test_add()
+    test_multiply()
     test_relu()
     test_dense()
     test_bn()
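
Note (illustration, not part of the patch): the refactored `Binary()` helper builds a plain oneDNN binary primitive parameterized by an algorithm kind, which is why adding `multiply` support costs only one extra dispatch arm. The standalone sketch below shows the same construction sequence outside TVM, specialized to `binary_mul` on the 10x10 float32 shape used by the new test; it assumes the oneDNN v1/v2 `dnnl::binary::desc` API that the patch itself targets (oneDNN v3 later removed the separate op-descriptor classes).

    #include <dnnl.hpp>

    #include <vector>

    int main() {
      dnnl::engine engine(dnnl::engine::kind::cpu, 0);
      dnnl::stream stream(engine);

      // Two 10x10 f32 inputs and one output, matching test_multiply.
      dnnl::memory::dims shape = {10, 10};
      auto md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32,
                                   dnnl::memory::format_tag::ab);

      std::vector<float> a(100, 2.0f), b(100, 3.0f), c(100, 0.0f);
      auto a_mem = dnnl::memory(md, engine, a.data());
      auto b_mem = dnnl::memory(md, engine, b.data());
      auto c_mem = dnnl::memory(md, engine, c.data());

      // Same sequence as Binary() in dnnl_json_runtime.cc, with
      // algo = dnnl::algorithm::binary_mul.
      auto binary_desc = dnnl::binary::desc(dnnl::algorithm::binary_mul, md, md, md);
      auto binary_prim_desc = dnnl::binary::primitive_desc(binary_desc, engine);
      auto binary = dnnl::binary(binary_prim_desc);

      binary.execute(stream, {{DNNL_ARG_SRC_0, a_mem},
                              {DNNL_ARG_SRC_1, b_mem},
                              {DNNL_ARG_DST, c_mem}});
      stream.wait();
      // Every element of c now holds 2.0f * 3.0f = 6.0f.
      return 0;
    }

By the same token, re-enabling `subtract` later would presumably need only a `Binary(nid, dnnl::algorithm::binary_sub)` dispatch arm plus re-registering the op in python/tvm/relay/op/contrib/dnnl.py, assuming the linked oneDNN build provides `binary_sub` (hypothetical, not part of this change).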