[BYOC][DNNL] Improve performance of DNNL BYOC dense operator (#11513)
* Enhance DNNL BYOC dense operator performance by 1) introducing GELU fusion and 2) introducing an altered dense weight layout.

* fix lint issue

* add unittest for dense pack

* Make code compatible after the introduction of TensorRequisite (PR-11345)

* Fix comments & refactor code

* Fix lint

* Fix partition graph unittest case

* Fix comments

* Fix comments

* Fix lint
billishyahao authored Jun 10, 2022
1 parent dc522a6 commit e8712a9
Showing 6 changed files with 250 additions and 27 deletions.
123 changes: 118 additions & 5 deletions python/tvm/relay/op/contrib/dnnl.py
@@ -40,10 +40,15 @@
from tvm.relay.expr import GlobalVar
from tvm.relay.expr_functor import ExprMutator, ExprVisitor

from tvm.relay.analysis import analysis as _analysis
from tvm.relay import expr as _expr


from ... import _ffi_api
from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback
from .register import register_pattern_table


logger = logging.getLogger("DNNL")


@@ -139,12 +144,22 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
data = wildcard()
weight = wildcard()
bias = wildcard()

dense = is_op("nn.dense")(data, weight)
if with_bias:
dense_out = is_op("add")(dense, bias)
else:
dense_out = dense
if with_eltwise:
if with_eltwise == "gelu":
const1 = wildcard()
const2 = wildcard()
const3 = wildcard()
div = is_op("divide")(dense_out, const1)
erf_val = is_op("erf")(div)
added_erf_val = is_op("add")(erf_val, const2)
mul_val = is_op("multiply")(dense_out, added_erf_val)
dense_out = is_op("multiply")(mul_val, const3)
elif with_eltwise:
dense_out = is_op(with_eltwise)(dense_out)
return dense_out
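
For context, the new "gelu" branch matches the erf-based decomposition gelu(y) = 0.5 * y * (1 + erf(y / sqrt(2))) as it typically appears after frontend import. A minimal Relay sketch of the kind of expression this pattern is meant to capture (shapes and constant values here are illustrative, not taken from the diff):

from tvm import relay

x = relay.var("x", shape=(3136, 512), dtype="float32")
w = relay.var("w", shape=(512, 512), dtype="float32")
b = relay.var("b", shape=(512,), dtype="float32")

dense = relay.nn.dense(x, w)
biased = relay.add(dense, b)
# erf-based GELU: 0.5 * y * (1 + erf(y / sqrt(2))), with 1.41421356 ~= sqrt(2)
div = relay.divide(biased, relay.const(1.41421356, "float32"))
one_plus_erf = relay.add(relay.erf(div), relay.const(1.0, "float32"))
gelu = relay.multiply(relay.multiply(biased, one_plus_erf), relay.const(0.5, "float32"))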

@@ -176,7 +191,7 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
else:
logger.warning(
"Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose and "
"Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose, "
"dense op are supported, but got %s.",
op_name,
)
@@ -193,20 +208,21 @@ def pattern_table():
dnnl_patterns : List[dnnl_pattern]
Created patterns.
"""
elt_list = ["nn.relu", "tanh", "sigmoid", None]
elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", None]
dnnl_patterns = []
for with_bias in [True, False]:
for elt in elt_list:
if not with_bias and not elt:
return dnnl_patterns
continue
for conv_name in [
"nn.conv1d",
"nn.conv2d",
"nn.conv3d",
"nn.conv2d_transpose",
"nn.conv3d_transpose",
]:
dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
if elt != "gelu":
dnnl_patterns.append(make_dnnl_pattern(conv_name, with_bias, elt))
dnnl_patterns.append(make_dnnl_pattern("nn.dense", with_bias, elt))
return dnnl_patterns
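
The resulting table plugs into the standard BYOC partitioning flow. A minimal sketch, assuming `mod` is a Relay module obtained from any frontend; after partitioning, the fused composites end up with names along the lines of "dnnl.dense_bias_gelu":

import tvm
from tvm import relay
from tvm.relay.op.contrib import dnnl  # registers the "dnnl" pattern table

seq = tvm.transform.Sequential(
    [
        relay.transform.MergeComposite(dnnl.pattern_table()),
        relay.transform.AnnotateTarget("dnnl"),
        relay.transform.MergeCompilerRegions(),
        relay.transform.PartitionGraph(),
    ]
)
partitioned = seq(mod)  # dense/conv composites are now offloaded to the DNNL runtime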

@@ -339,6 +355,7 @@ def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
res += i
else:
raise ValueError("Unsupport layout format: %s" % input_data)

return res


@@ -594,3 +611,99 @@ def rewrite_layer_norm(mod):
"""
mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
return mod


class DenseReshapeBiasGeluRewrite(DFPatternCallback):
"""
A callback to reorder reshape operators when the patterns are as below:
Pattern #1:
1 %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */,
units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */;
2 %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */;
3 %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63)
/* ty=Tensor[(1, 3136, 64), float32] */;
Pattern #2:
1 %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */,
units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */;
2 %77 = reshape(%76, newshape=[1, 3136, 512]) /* ty=Tensor[(1, 3136, 512), float32] */;
3 %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77)
/* ty=Tensor[(1, 3136, 512), float32] */;
4 %79 = divide(%78, 1.41421f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
5 %80 = erf(%79) /* ty=Tensor[(1, 3136, 512), float32] */;
6 %81 = add(%80, 1f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
7 %82 = multiply(%78, %81) /* ty=Tensor[(1, 3136, 512), float32] */;
8 %83 = multiply(%82, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
"""

def __init__(self, has_gelu=True):
super(DenseReshapeBiasGeluRewrite, self).__init__()
self.data = wildcard()
self.weight = wildcard()
self.bias = wildcard()
self.const1 = wildcard()
self.const2 = wildcard()
self.const3 = wildcard()

self.attr_map = {}
self.has_gelu = has_gelu

den = is_op("nn.dense")(self.data, self.weight)
re_den = is_op("reshape")(den)
added = is_op("add")(self.bias, re_den)
if self.has_gelu:
divisor = is_op("divide")(added, self.const1)
val_erf = is_op("erf")(divisor)
added_erf = is_op("add")(val_erf, self.const2)
mul1 = is_op("multiply")(added, added_erf)
mul2 = is_op("multiply")(mul1, self.const3)
self.pattern = mul2
else:
self.pattern = added

def get_attr(self, pre):
"""Recursively retrieve attributes from reshape operator."""

def visit_func(expr):
if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"):
new_attrs = {}
for k in expr.attrs.keys():
new_attrs[k] = expr.attrs[k]
self.attr_map["reshape"] = new_attrs

_analysis.post_order_visit(pre, visit_func)

def callback(self, pre, post, node_map):
self.get_attr(pre)

data = node_map[self.data][0]
weight = node_map[self.weight][0]
bias = node_map[self.bias][0]

den = relay.op.nn.dense(data, weight)
added = relay.op.add(bias, den)
if not self.has_gelu:
return relay.op.reshape(added, self.attr_map["reshape"]["newshape"])

const1 = node_map[self.const1][0]
const2 = node_map[self.const2][0]
const3 = node_map[self.const3][0]

divisor = relay.op.divide(added, const1)
val_erf = relay.op.erf(divisor)
added_erf = relay.op.add(val_erf, const2)
mul1 = relay.op.multiply(added, added_erf)
mul2 = relay.op.multiply(mul1, const3)
return relay.op.reshape(mul2, self.attr_map["reshape"]["newshape"])


def rewrite_dense_bias_gelu_reshape_last(mod):
"""Rewrite the input graph to reorder reshape operators so that
we can perform dense_bias_gelu/dense_bias fusion and then offload
them to the DNNL BYOC part.
"""
mod["main"] = rewrite(
[DenseReshapeBiasGeluRewrite(), DenseReshapeBiasGeluRewrite(has_gelu=False)], mod["main"]
)
return mod
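
As with rewrite_layer_norm above, this rewrite is meant to run before partitioning. A minimal usage sketch, assuming `mod` is the imported Relay module:

from tvm.relay.op.contrib.dnnl import rewrite_dense_bias_gelu_reshape_last

# Move reshape past bias-add (and the erf-based GELU tail, when present) so that
# dense + bias + gelu forms a contiguous chain the DNNL dense patterns can match.
mod = rewrite_dense_bias_gelu_reshape_last(mod)
# ... then run MergeComposite / AnnotateTarget("dnnl") / PartitionGraph as usual.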
4 changes: 2 additions & 2 deletions src/relay/backend/contrib/dnnl/codegen.cc
@@ -454,6 +454,7 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
ICHECK_NE(pattern_name, "");
std::vector<std::string> op_list;
size_t pos = 0, start = 0;

while ((pos = pattern_name.find(interval, start)) != std::string::npos) {
std::string op_name = pattern_name.substr(start, pos - start);
if (op_name.find("dnnl") != std::string::npos) {
@@ -508,8 +509,7 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name.find("dnnl.dense") != std::string::npos) {
std::vector<std::string> op_list = ParsingOpList(name);
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
call = GetRootCall(fn->body.as<CallNode>(), 10, "nn.dense");
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else {
LOG(FATAL) << "Unrecognized DNNL pattern: " << name;
32 changes: 32 additions & 0 deletions src/relay/backend/utils.h
@@ -480,6 +480,38 @@ inline const CallNode* GetRootCall(const CallNode* current_call, const std::string& op_name) {
return GetRootCall(next_call, op_name);
}

/*!
* \brief Retrieve the expected "root" op nested inside a fused call, such as conv2d in
* relu(add(conv2d))
* \param current_call A Relay call node. Typically nn.relu when called the first time.
* \param max_depth The maximum number of calls before the root op, counting from current_call.
* \param op_name The name of expected "root" op in this fused call.
* \return A CallNode corresponding to the root op
*/
inline const CallNode* GetRootCall(const CallNode* current_call, int max_depth,
const std::string& op_name) {
ICHECK(current_call && max_depth >= 0);

if (max_depth == 0) {
ICHECK(current_call && IsOp(current_call, op_name));
return current_call;
}
if (IsOp(current_call, op_name)) {
return current_call;
}

ICHECK_GT(current_call->args.size(), 0);

size_t valid_node_idx = 0;
while (valid_node_idx < current_call->args.size() &&
current_call->args[valid_node_idx].as<VarNode>()) {
valid_node_idx++;
}

const auto* next_call = current_call->args[valid_node_idx].as<CallNode>();
return GetRootCall(next_call, max_depth - 1, op_name);
}

/*!
* \brief Get the external symbol of the Relay function name.
*
9 changes: 7 additions & 2 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -83,7 +83,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
// Find proper dnnl::memory buffers
std::unordered_map<int, dnnl::memory> mem_args;
for (const auto& kvp : arg_reqs) mem_args[kvp.first] = mem_solver(kvp.second);

prim.execute(stream_, mem_args);
}
}
@@ -143,6 +142,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::regex relu_pat(".*_relu.*");
std::regex tanh_pat(".*_tanh.*");
std::regex sigmoid_pat(".*_sigmoid.*");
std::regex gelu_pat(".*_gelu.*");

// Parsing post-ops.
dnnl::post_ops ops;
@@ -155,7 +155,12 @@
if (std::regex_match(op_name, sigmoid_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
}
attr.set_post_ops(ops);
if (std::regex_match(op_name, gelu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
}
if (ops.len() != 0) {
attr.set_post_ops(ops);
}

// Parsing bias_add.
return std::regex_match(op_name, bias_add_pat) ? true : false;