This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

[BYOC-DNNL] support more post-ops (apache#12002)
* support post-op swish

* support post-op clip

* enhance get_shape and get_dtype in dnnl.py to support efficientnet

* add checks that with_eltwise is in the supported list

* fix lint

* fix test
crazydemo authored and xinetzone committed Nov 25, 2022
1 parent 30d80ff commit acc7458
Showing 6 changed files with 95 additions and 104 deletions.
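For context, here is a minimal sketch (not part of this commit) of the kind of Relay graph the new swish post-op pattern is meant to match: a conv2d followed by the x * sigmoid(x) decomposition, handed to the existing partition_for_dnnl helper that the tests below also use. Shapes and parameter names are illustrative.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import partition_for_dnnl

# Build conv2d followed by swish expressed as multiply(conv, sigmoid(conv)).
x = relay.var("x", shape=(1, 32, 8, 8), dtype="float32")
w = relay.var("w", shape=(16, 32, 3, 3), dtype="float32")
conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=16)
out = relay.multiply(conv, relay.sigmoid(conv))

mod = tvm.IRModule.from_expr(out)
params = {"w": np.random.uniform(-1, 1, size=(16, 32, 3, 3)).astype("float32")}
# After partitioning, conv2d + sigmoid + multiply should be grouped into one
# DNNL composite function whose name carries the "_sigmoid_mul" suffix.
mod = partition_for_dnnl(mod, params)
print(mod)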
22 changes: 19 additions & 3 deletions python/tvm/relay/op/contrib/dnnl.py
@@ -51,6 +51,7 @@


logger = logging.getLogger("DNNL")
supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]


def _register_external_op_helper(op_name, supported=True):
@@ -120,6 +121,8 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
conv_out : CallPattern
Call node sequence.
"""
if with_eltwise not in supported_post_elts:
raise ValueError("Unsupported eltwise post-op: %s" % with_eltwise)
data = wildcard()
weight = wildcard()
bias = wildcard()
@@ -128,8 +131,11 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
conv_out = is_op("add")(conv, bias)
else:
conv_out = conv
if with_eltwise:
return is_op(with_eltwise)(conv_out)
if with_eltwise == "swish":
sig_out = is_op("sigmoid")(conv_out)
conv_out = is_op("multiply")(conv_out, sig_out)
elif with_eltwise:
conv_out = is_op(with_eltwise)(conv_out)
return conv_out


@@ -147,6 +153,8 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
dense_out : CallPattern
Call node sequence.
"""
if with_eltwise not in supported_post_elts:
raise ValueError("Unsupported eltwise post-op: %s" % with_eltwise)
data = wildcard()
weight = wildcard()
bias = wildcard()
@@ -165,6 +173,9 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
added_erf_val = is_op("add")(erf_val, const2)
mul_val = is_op("multiply")(dense_out, added_erf_val)
dense_out = is_op("multiply")(mul_val, const3)
elif with_eltwise == "swish":
sig_out = is_op("sigmoid")(dense_out)
dense_out = is_op("multiply")(dense_out, sig_out)
elif with_eltwise:
dense_out = is_op(with_eltwise)(dense_out)
return dense_out
@@ -191,6 +202,7 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::]
pat_name += "_bias" if with_bias else ""
pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
pat_name = pat_name.replace("_swish", "_sigmoid_mul")
if "conv" in op_name:
dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise))
elif op_name == "nn.dense":
@@ -282,7 +294,7 @@ def pattern_table():
dnnl_patterns.append(make_qnn_conv2d_pattern())
dnnl_patterns.append(make_qnn_dense_pattern())

elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", None]
elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
for with_bias in [True, False]:
for elt in elt_list:
if not with_bias and not elt:
@@ -380,6 +392,8 @@ def get_shape(tensor):
if isinstance(tensor, tvm.ir.container.Array):
return tensor[-1].shape
if isinstance(tensor, relay.expr.Call):
if tensor.op.name == "multiply":
return tensor.type_args[0].shape
return tensor.checked_type.shape
raise TypeError("Unsupport data type: %s" % type(tensor))

@@ -395,6 +409,8 @@ def get_dtype(tensor):
if isinstance(tensor, tvm.ir.container.Array):
return tensor[-1].dtype
if isinstance(tensor, relay.expr.Call):
if tensor.op.name == "multiply":
return tensor.type_args[0].dtype
return tensor.checked_type.dtype
raise TypeError("Unsupport data type: %s" % type(tensor))

9 changes: 9 additions & 0 deletions src/relay/backend/contrib/dnnl/codegen.cc
@@ -470,6 +470,8 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
{"relu", "nn.relu"},
{"tanh", "tanh"},
{"sigmoid", "sigmoid"},
{"clip", "clip"},
{"mul", "multiply"},
{"nn.deconv2d", "nn.conv2d_transpose"},
{"nn.deconv3d", "nn.conv3d_transpose"},
};
@@ -566,6 +568,13 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
"kernel", /* op_type_ */
inputs, 1 /* num_outputs_ */);
SetCallNodeAttribute(node, call);
// If the pattern has post-op `clip`, assume the last op is clip and add clip's attrs to the pattern attrs.
if (name.find("_clip") != std::string::npos) {
auto clip_call = cn->op.as<FunctionNode>()->body.as<CallNode>();
ICHECK(IsOp(clip_call, "clip"));
SetCallNodeAttribute(node, clip_call);
}
// For QNN.
for (const auto& kvp : extra_attrs) node->SetAttr(kvp.first, kvp.second);

return AddNode(node, GetRef<Expr>(cn));
4 changes: 4 additions & 0 deletions src/relay/backend/utils.h
@@ -470,6 +470,10 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
current_call->args[valid_node_idx].as<VarNode>()) {
valid_node_idx++;
}
while (valid_node_idx < current_call->args.size() &&
!(IsOp(current_call->args[valid_node_idx].as<CallNode>(), expected_op_names[depth - 1]))) {
valid_node_idx++;
}
const auto* next_call = current_call->args[valid_node_idx].as<CallNode>();
return GetRootCall(next_call, depth - 1, expected_op_names);
}
12 changes: 11 additions & 1 deletion src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -189,6 +189,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::regex relu_pat(".*_relu.*");
std::regex tanh_pat(".*_tanh.*");
std::regex sigmoid_pat(".*_sigmoid.*");
std::regex clip_pat(".*_clip.*");
std::regex gelu_pat(".*_gelu.*");

// Parsing post-ops.
@@ -199,8 +200,17 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
if (std::regex_match(op_name, tanh_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
}
if (std::regex_match(op_name, clip_pat)) {
float a_min = GetNodeAttr<float>(nodes_[nid], "a_min");
float a_max = GetNodeAttr<float>(nodes_[nid], "a_max");
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
}
if (std::regex_match(op_name, sigmoid_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
if (op_name.find("_sigmoid_mul") != std::string::npos) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
} else {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
}
}
if (std::regex_match(op_name, gelu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
126 changes: 31 additions & 95 deletions tests/python/contrib/test_dnnl.py
Expand Up @@ -192,7 +192,6 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, te
if use_dnnl:
processed_mod = partition_for_dnnl(processed_mod, params, alter_layout)
check_dnnl_used(processed_mod)

with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
mode, mod=processed_mod, device=dev, target=target
@@ -237,6 +236,23 @@ def run_and_verify_func(
)


def add_activation(activation, out, dic, param_lst):
if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
elif activation == "clip":
return relay.clip(out, 0.0, 6.0), dic, param_lst
elif activation == "swish":
sig_out = relay.sigmoid(out)
out = relay.multiply(out, sig_out)
return out, dic, param_lst
else:
return out, dic, param_lst


def get_conv1d(
x_shape=((1, 3, 224)),
k_shape=(16, 3, 3),
@@ -262,15 +278,7 @@ def get_conv1d(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dtype="float32"):
@@ -279,15 +287,7 @@ def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dt
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[0],)
param_lst += ["bias"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv1d_bias_bn_relu(x_shape=(1, 3, 224), k_shape=(10, 3, 3), dtype="float32"):
@@ -334,15 +334,7 @@ def get_conv2d(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv2d_transpose(
@@ -367,15 +359,7 @@ def get_conv2d_transpose(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv2d_weights_const(
Expand Down Expand Up @@ -412,15 +396,7 @@ def get_conv2d_bias(
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[0],)
param_lst += ["bias"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv2d_transpose_bias(
@@ -431,15 +407,7 @@ def get_conv2d_transpose_bias(
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[1],)
param_lst += ["bias"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
@@ -503,15 +471,7 @@ def get_conv3d(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv3d_transpose(
@@ -542,15 +502,7 @@ def get_conv3d_transpose(
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv3d_bias(
Expand All @@ -561,15 +513,7 @@ def get_conv3d_bias(
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[0],)
param_lst += ["bias"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def get_conv3d_transpose_bias(
Expand All @@ -580,15 +524,7 @@ def get_conv3d_transpose_bias(
out = relay.nn.bias_add(conv, bias)
dic["bias"] = (k_shape[1],)
param_lst += ["bias"]

if activation == "relu":
return relay.nn.relu(out), dic, param_lst
elif activation == "tanh":
return relay.tanh(out), dic, param_lst
elif activation == "sigmoid":
return relay.sigmoid(out), dic, param_lst
else:
return out, dic, param_lst
return add_activation(activation, out, dic, param_lst)


def gelu_helper(data):
Expand Down Expand Up @@ -797,7 +733,7 @@ def test_conv2d_weights_const(run_module, dtype="float32"):
def test_conv2d_pattern(run_module, dtype="float32"):
x_shape = (1, 32, 8, 8)
k_shape = (16, 32, 3, 3)
activation_lst = [None, "relu", "tanh", "sigmoid"]
activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
for a in activation_lst:
conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype)
conv2d = tvm.IRModule.from_expr(conv2d)
@@ -839,7 +775,7 @@ def test_conv2d_transpose(run_module, dtype="float32"):


def test_conv2d_transpose_pattern(run_module, dtype="float32"):
activation_lst = [None, "relu", "tanh", "sigmoid"]
activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
for a in activation_lst:
conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype)
conv2d = tvm.IRModule.from_expr(conv2d)
@@ -872,7 +808,7 @@ def test_conv3d(run_module, dtype="float32"):


def test_conv3d_pattern(run_module, dtype="float32"):
activation_lst = [None, "relu", "tanh", "sigmoid"]
activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
for a in activation_lst:
conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype)
conv3d = tvm.IRModule.from_expr(conv3d)
@@ -905,7 +841,7 @@ def test_conv3d_transpose(run_module, dtype="float32"):


def test_conv3d_transpose_pattern(run_module, dtype="float32"):
activation_lst = [None, "relu", "tanh", "sigmoid"]
activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
for a in activation_lst:
conv3d, dic, param_lst = get_conv3d_transpose(activation=a, dtype=dtype)
conv3d = tvm.IRModule.from_expr(conv3d)