Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOC] FTVMAnnotateTarget method signature update #6786

Merged
merged 1 commit into from
Nov 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,11 @@ using FTVMLegalize = runtime::TypedPackedFunc<Expr(const Attrs& attrs, const Arr
/*!
* \brief Annotates an expression to indicate if an op should be compiled using
* the given compiler/target.
*
* \param attrs The attribute of the original expr.
* \param args The arguments of the original expr.
*
* \param expr The original expr.
* \return true if this op should be registered to invoke a specific compiler
* for codegen, otherwise, false.
*/
using FTVMAnnotateTarget = runtime::TypedPackedFunc<bool(const Attrs& attrs, // NOLINT(*)
const Array<Expr>& args)>;
using FTVMAnnotateTarget = runtime::TypedPackedFunc<bool(const Expr& expr)>;

/*!
* \brief Forward rewriting rule for a specific op.
Expand Down
47 changes: 29 additions & 18 deletions python/tvm/relay/op/contrib/arm_compute_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def check_conv(extract):
call = extract
while call.op.name != "nn.conv2d":
call = call.args[0]
return conv2d(call.attrs, call.args)
return conv2d(call)

def check_qnn_conv(extract):
"""Check qnn conv pattern is supported by ACL."""
Expand All @@ -176,14 +176,14 @@ def check_qnn_conv(extract):
call = extract
while call.op.name != "qnn.conv2d":
call = call.args[0]
return qnn_conv2d(call.attrs, call.args)
return qnn_conv2d(call)

def check_dense(extract):
"""Check conv pattern is supported by ACL."""
call = extract
while call.op.name != "nn.dense":
call = call.args[0]
return dense(call.attrs, call.args)
return dense(call)

def check_qnn_dense(extract):
"""Check qnn conv pattern is supported by ACL."""
Expand All @@ -192,7 +192,7 @@ def check_qnn_dense(extract):
call = extract
while call.op.name != "qnn.dense":
call = call.args[0]
return qnn_dense(call.attrs, call.args)
return qnn_dense(call)

def check_avg_pool2d(extract):
"""Check average pool2d pattern is supported by ACL."""
Expand All @@ -201,12 +201,12 @@ def check_avg_pool2d(extract):
pool = extract.args[0]
if pool.args[0].attrs.dtype != "int32":
return False
return avg_pool2d(pool.attrs, pool.args, from_quantized_composite=True)
return avg_pool2d(pool, from_quantized_composite=True)

def check_l2_pool2d(extract):
"""Check l2 pool2d pattern is supported by ACL."""
pool = extract.args[0]
return avg_pool2d(pool.attrs, pool.args)
return avg_pool2d(pool)

return [
("arm_compute_lib.conv2d", conv_pattern(), check_conv),
Expand All @@ -221,7 +221,7 @@ def check_l2_pool2d(extract):

def _register_external_op_helper(op_name, supported=True):
@tvm.ir.register_op_attr(op_name, "target.arm_compute_lib")
def _func_wrapper(attrs, args):
def _func_wrapper(expr):
return supported

return _func_wrapper
Expand All @@ -231,8 +231,9 @@ def _func_wrapper(attrs, args):


@tvm.ir.register_op_attr("nn.conv2d", "target.arm_compute_lib")
def conv2d(attrs, args):
def conv2d(expr):
"""Check if the external ACL codegen for conv2d should be used."""
attrs, args = expr.attrs, expr.args
if attrs.groups != 1:
return False
if attrs.data_layout != "NHWC":
Expand All @@ -248,8 +249,9 @@ def conv2d(attrs, args):
return True


def qnn_conv2d(attrs, args):
def qnn_conv2d(expr):
"""Check if the external ACL codegen for qnn.conv2d should be used."""
attrs, args = expr.attrs, expr.args
if attrs.groups != 1:
return False
if attrs.data_layout != "NHWC":
Expand All @@ -266,8 +268,9 @@ def qnn_conv2d(attrs, args):


@tvm.ir.register_op_attr("nn.dense", "target.arm_compute_lib")
def dense(attrs, args):
def dense(expr):
"""Check if the external ACL codegen for dense should be used."""
attrs, args = expr.attrs, expr.args
data_typ = args[0].checked_type
if data_typ.dtype != "float32":
return False
Expand All @@ -279,8 +282,9 @@ def dense(attrs, args):
return True


def qnn_dense(attrs, args):
def qnn_dense(expr):
"""Check if the external ACL codegen for qnn.dense should be used."""
attrs, args = expr.attrs, expr.args
data_typ = args[0].checked_type
if data_typ.dtype != "uint8":
return False
Expand All @@ -293,8 +297,9 @@ def qnn_dense(attrs, args):


@tvm.ir.register_op_attr("nn.max_pool2d", "target.arm_compute_lib")
def max_pool2d(attrs, args):
def max_pool2d(expr):
"""Check if the external ACL codegen for maxpool2d should be used."""
attrs, args = expr.attrs, expr.args
if attrs.layout != "NHWC":
return False
typ = args[0].checked_type
Expand All @@ -304,8 +309,9 @@ def max_pool2d(attrs, args):


@tvm.ir.register_op_attr("nn.avg_pool2d", "target.arm_compute_lib")
def avg_pool2d(attrs, args, from_quantized_composite=False):
def avg_pool2d(expr, from_quantized_composite=False):
"""Check if the external ACL codegen for avgpool2d should be used."""
attrs, args = expr.attrs, expr.args
typ = args[0].checked_type
if from_quantized_composite:
if typ.dtype != "int32":
Expand All @@ -319,8 +325,9 @@ def avg_pool2d(attrs, args, from_quantized_composite=False):


@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.arm_compute_lib")
def global_max_pool2d(attrs, args):
def global_max_pool2d(expr):
"""Check if the external ACL codegen for gloval_maxpool2d should be used."""
attrs, args = expr.attrs, expr.args
typ = args[0].checked_type
if typ.dtype not in ["float32", "uint8"]:
return False
Expand All @@ -330,8 +337,9 @@ def global_max_pool2d(attrs, args):


@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.arm_compute_lib")
def global_avg_pool2d(attrs, args):
def global_avg_pool2d(expr):
"""Check if the external ACL codegen for global_avgpool2d should be used."""
attrs, args = expr.attrs, expr.args
typ = args[0].checked_type
if typ.dtype not in ["float32"]:
return False
Expand All @@ -341,16 +349,18 @@ def global_avg_pool2d(attrs, args):


@tvm.ir.register_op_attr("maximum", "target.arm_compute_lib")
def maximum(attrs, args):
def maximum(expr):
"""Check if the external ACL codegen for maximum should be used."""
args = expr.args
type_a = args[0].checked_type
type_b = args[0].checked_type
return (type_a.dtype == "float32") and (type_b.dtype == "float32")


@tvm.ir.register_op_attr("add", "target.arm_compute_lib")
def add(attrs, args):
def add(expr):
"""Check if the external ACL codegen for add should be used."""
args = expr.args
for typ in [args[0].checked_type, args[1].checked_type]:
if typ.dtype != "float32":
return False
Expand All @@ -359,8 +369,9 @@ def add(attrs, args):


@tvm.ir.register_op_attr("qnn.add", "target.arm_compute_lib")
def qnn_add(attrs, args):
def qnn_add(expr):
"""Check if the external ACL codegen for add should be used."""
args = expr.args
for typ in [args[0].checked_type, args[1].checked_type]:
if typ.dtype != "uint8":
return False
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relay/op/contrib/coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def _register_coreml_op(op_name):

"""

def _check_supported(attrs, args):
def _check_supported(expr):
attrs, args = expr.attrs, expr.args
if op_name == "nn.conv2d":
if not isinstance(args[1], Constant):
return False
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/contrib/dnnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _register_external_op_helper(op_name, supported=True):
"""

@tvm.ir.register_op_attr(op_name, "target.dnnl")
def _func_wrapper(attrs, args):
def _func_wrapper(expr):
return supported

return _func_wrapper
Expand Down
21 changes: 14 additions & 7 deletions python/tvm/relay/op/contrib/ethosn.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,21 +128,23 @@ def _is_ethosn_composite(node):


@tvm.ir.register_op_attr("nn.max_pool2d", "target.ethos-n")
def max_pool2d(attrs, args):
def max_pool2d(expr):
"""Check if a max pool2d is supported by Ethos-N."""
if not ethosn_available():
return False

attrs, args = expr.attrs, expr.args
pool = tvm.relay.nn.max_pool2d(*args, **attrs)
return support.max_pool2d(pool)


@tvm.ir.register_op_attr("reshape", "target.ethos-n")
def reshape(attrs, args):
def reshape(expr):
"""Check if a reshape is supported by Ethos-N."""
if not ethosn_available():
return False

attrs, args = expr.attrs, expr.args
if not _is_ethosn_composite(args[0]):
return False

Expand All @@ -151,21 +153,23 @@ def reshape(attrs, args):


@tvm.ir.register_op_attr("qnn.add", "target.ethos-n")
def qnn_add(attrs, args):
def qnn_add(expr):
"""Check if an addition is supported by Ethos-N."""
if not ethosn_available():
return False

args = expr.args
add = _qnn.op.add(*args)
return support.addition(add)


@tvm.ir.register_op_attr("qnn.concatenate", "target.ethos-n")
def qnn_concatenate(attrs, args):
def qnn_concatenate(expr):
"""Check if a concatenate is supported by Ethos-N."""
if not ethosn_available():
return False

attrs, args = expr.attrs, expr.args
conc = _qnn.op.concatenate(*args, **attrs)
if not support.concatenate(conc):
return False
Expand All @@ -190,11 +194,12 @@ def qnn_concatenate(attrs, args):


@tvm.ir.register_op_attr("split", "target.ethos-n")
def split(attrs, args):
def split(expr):
"""Check if a split is supported by Ethos-N."""
if not ethosn_available():
return False

attrs, args = expr.attrs, expr.args
if isinstance(attrs["indices_or_sections"], tvm.tir.IntImm):
sp = tvm.relay.split(
*args, indices_or_sections=attrs["indices_or_sections"].value, axis=attrs["axis"]
Expand All @@ -210,11 +215,12 @@ def split(attrs, args):


@tvm.ir.register_op_attr("nn.depth_to_space", "target.ethos-n")
def depth_to_space(attrs, args):
def depth_to_space(expr):
"""Check if a depth_to_space is supported by Ethos-N."""
if not ethosn_available():
return False

attrs, args = expr.attrs, expr.args
depth = tvm.relay.nn.depth_to_space(*args, **attrs)
if not support.depth_to_space(depth):
return False
Expand All @@ -223,11 +229,12 @@ def depth_to_space(attrs, args):


@tvm.ir.register_op_attr("clip", "target.ethos-n")
def clip(attrs, args):
def clip(expr):
"""Check if a clip is supported by Ethos-N."""
if not ethosn_available():
return False

attrs, args = expr.attrs, expr.args
c = tvm.relay.clip(*args, **attrs)
if not support.relu(c):
return False
Expand Down
Loading