[BYOC] FTVMAnnotateTarget method signature update (apache#6786)
The signature of FTVMAnnotateTarget was changed to runtime::TypedPackedFunc<bool(const Expr& expr)>, which lets annotation functions use extra information from the passed expr argument.
d-smirnov authored and trevor-m committed Dec 4, 2020
1 parent 6f9f443 commit 92d13a5
Showing 9 changed files with 149 additions and 83 deletions.
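For orientation, a minimal sketch of what the change means for a BYOC annotator, using a hypothetical "target.example" compiler and the real "nn.relu" op as placeholders (not part of this commit):

import tvm

# Old signature: the callback received only the call's attributes and arguments.
# @tvm.ir.register_op_attr("nn.relu", "target.example")
# def relu(attrs, args):
#     return args[0].checked_type.dtype == "float32"

# New signature: the callback receives the whole call expression; attrs and args
# are still reachable through it, along with everything else the Expr carries.
@tvm.ir.register_op_attr("nn.relu", "target.example")
def relu(expr):
    args = expr.args
    return args[0].checked_type.dtype == "float32"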
8 changes: 2 additions & 6 deletions include/tvm/relay/op_attr_types.h
@@ -175,15 +175,11 @@ using FTVMLegalize = runtime::TypedPackedFunc<Expr(const Attrs& attrs, const Arr
/*!
* \brief Annotates an expression to indicate if an op should be compiled using
* the given compiler/target.
*
* \param attrs The attribute of the original expr.
* \param args The arguments of the original expr.
*
* \param expr The original expr.
* \return true if this op should be registered to invoke a specific compiler
* for codegen, otherwise, false.
*/
using FTVMAnnotateTarget = runtime::TypedPackedFunc<bool(const Attrs& attrs, // NOLINT(*)
const Array<Expr>& args)>;
using FTVMAnnotateTarget = runtime::TypedPackedFunc<bool(const Expr& expr)>;

/*!
* \brief Forward rewriting rule for a specific op.
47 changes: 29 additions & 18 deletions python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -167,7 +167,7 @@ def check_conv(extract):
call = extract
while call.op.name != "nn.conv2d":
call = call.args[0]
return conv2d(call.attrs, call.args)
return conv2d(call)

def check_qnn_conv(extract):
"""Check qnn conv pattern is supported by ACL."""
@@ -176,14 +176,14 @@ def check_qnn_conv(extract):
call = extract
while call.op.name != "qnn.conv2d":
call = call.args[0]
return qnn_conv2d(call.attrs, call.args)
return qnn_conv2d(call)

def check_dense(extract):
"""Check conv pattern is supported by ACL."""
call = extract
while call.op.name != "nn.dense":
call = call.args[0]
return dense(call.attrs, call.args)
return dense(call)

def check_qnn_dense(extract):
"""Check qnn conv pattern is supported by ACL."""
@@ -192,7 +192,7 @@ def check_qnn_dense(extract):
call = extract
while call.op.name != "qnn.dense":
call = call.args[0]
return qnn_dense(call.attrs, call.args)
return qnn_dense(call)

def check_avg_pool2d(extract):
"""Check average pool2d pattern is supported by ACL."""
@@ -201,12 +201,12 @@ def check_avg_pool2d(extract):
pool = extract.args[0]
if pool.args[0].attrs.dtype != "int32":
return False
return avg_pool2d(pool.attrs, pool.args, from_quantized_composite=True)
return avg_pool2d(pool, from_quantized_composite=True)

def check_l2_pool2d(extract):
"""Check l2 pool2d pattern is supported by ACL."""
pool = extract.args[0]
return avg_pool2d(pool.attrs, pool.args)
return avg_pool2d(pool)

return [
("arm_compute_lib.conv2d", conv_pattern(), check_conv),
@@ -221,7 +221,7 @@ def check_l2_pool2d(extract):

def _register_external_op_helper(op_name, supported=True):
@tvm.ir.register_op_attr(op_name, "target.arm_compute_lib")
def _func_wrapper(attrs, args):
def _func_wrapper(expr):
return supported

return _func_wrapper
@@ -231,8 +231,9 @@ def _func_wrapper(attrs, args):


@tvm.ir.register_op_attr("nn.conv2d", "target.arm_compute_lib")
def conv2d(attrs, args):
def conv2d(expr):
"""Check if the external ACL codegen for conv2d should be used."""
attrs, args = expr.attrs, expr.args
if attrs.groups != 1:
return False
if attrs.data_layout != "NHWC":
@@ -248,8 +249,9 @@ def conv2d(attrs, args):
return True


def qnn_conv2d(attrs, args):
def qnn_conv2d(expr):
"""Check if the external ACL codegen for qnn.conv2d should be used."""
attrs, args = expr.attrs, expr.args
if attrs.groups != 1:
return False
if attrs.data_layout != "NHWC":
@@ -266,8 +268,9 @@ def qnn_conv2d(attrs, args):


@tvm.ir.register_op_attr("nn.dense", "target.arm_compute_lib")
def dense(attrs, args):
def dense(expr):
"""Check if the external ACL codegen for dense should be used."""
attrs, args = expr.attrs, expr.args
data_typ = args[0].checked_type
if data_typ.dtype != "float32":
return False
@@ -279,8 +282,9 @@ def dense(attrs, args):
return True


def qnn_dense(attrs, args):
def qnn_dense(expr):
"""Check if the external ACL codegen for qnn.dense should be used."""
attrs, args = expr.attrs, expr.args
data_typ = args[0].checked_type
if data_typ.dtype != "uint8":
return False
@@ -293,8 +297,9 @@ def qnn_dense(attrs, args):


@tvm.ir.register_op_attr("nn.max_pool2d", "target.arm_compute_lib")
def max_pool2d(attrs, args):
def max_pool2d(expr):
"""Check if the external ACL codegen for maxpool2d should be used."""
attrs, args = expr.attrs, expr.args
if attrs.layout != "NHWC":
return False
typ = args[0].checked_type
@@ -304,8 +309,9 @@ def max_pool2d(attrs, args):


@tvm.ir.register_op_attr("nn.avg_pool2d", "target.arm_compute_lib")
def avg_pool2d(attrs, args, from_quantized_composite=False):
def avg_pool2d(expr, from_quantized_composite=False):
"""Check if the external ACL codegen for avgpool2d should be used."""
attrs, args = expr.attrs, expr.args
typ = args[0].checked_type
if from_quantized_composite:
if typ.dtype != "int32":
@@ -319,8 +325,9 @@ def avg_pool2d(attrs, args, from_quantized_composite=False):


@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.arm_compute_lib")
def global_max_pool2d(attrs, args):
def global_max_pool2d(expr):
"""Check if the external ACL codegen for gloval_maxpool2d should be used."""
attrs, args = expr.attrs, expr.args
typ = args[0].checked_type
if typ.dtype not in ["float32", "uint8"]:
return False
@@ -330,8 +337,9 @@ def global_max_pool2d(attrs, args):


@tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.arm_compute_lib")
def global_avg_pool2d(attrs, args):
def global_avg_pool2d(expr):
"""Check if the external ACL codegen for global_avgpool2d should be used."""
attrs, args = expr.attrs, expr.args
typ = args[0].checked_type
if typ.dtype not in ["float32"]:
return False
@@ -341,16 +349,18 @@ def global_avg_pool2d(attrs, args):


@tvm.ir.register_op_attr("maximum", "target.arm_compute_lib")
def maximum(attrs, args):
def maximum(expr):
"""Check if the external ACL codegen for maximum should be used."""
args = expr.args
type_a = args[0].checked_type
type_b = args[1].checked_type
return (type_a.dtype == "float32") and (type_b.dtype == "float32")


@tvm.ir.register_op_attr("add", "target.arm_compute_lib")
def add(attrs, args):
def add(expr):
"""Check if the external ACL codegen for add should be used."""
args = expr.args
for typ in [args[0].checked_type, args[1].checked_type]:
if typ.dtype != "float32":
return False
Expand All @@ -359,8 +369,9 @@ def add(attrs, args):


@tvm.ir.register_op_attr("qnn.add", "target.arm_compute_lib")
def qnn_add(attrs, args):
def qnn_add(expr):
"""Check if the external ACL codegen for add should be used."""
args = expr.args
for typ in [args[0].checked_type, args[1].checked_type]:
if typ.dtype != "uint8":
return False
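A note on why the richer argument helps: with the whole call expression in hand, a check can consult information the old (attrs, args) pair never carried, such as the call's own inferred output type. A minimal sketch, assuming a hypothetical "target.example" compiler and that type inference has already run:

import tvm

@tvm.ir.register_op_attr("nn.softmax", "target.example")
def softmax(expr):
    """Accept only float32-in, float32-out softmax calls."""
    args = expr.args
    if args[0].checked_type.dtype != "float32":
        return False
    # expr.checked_type is the output type of the softmax call itself,
    # which was not reachable under the old (attrs, args) signature.
    return expr.checked_type.dtype == "float32"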
3 changes: 2 additions & 1 deletion python/tvm/relay/op/contrib/coreml.py
@@ -31,7 +31,8 @@ def _register_coreml_op(op_name):
"""

def _check_supported(attrs, args):
def _check_supported(expr):
attrs, args = expr.attrs, expr.args
if op_name == "nn.conv2d":
if not isinstance(args[1], Constant):
return False
2 changes: 1 addition & 1 deletion python/tvm/relay/op/contrib/dnnl.py
@@ -53,7 +53,7 @@ def _register_external_op_helper(op_name, supported=True):
"""

@tvm.ir.register_op_attr(op_name, "target.dnnl")
def _func_wrapper(attrs, args):
def _func_wrapper(expr):
return supported

return _func_wrapper
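For context, these attribute functions are consumed by the standard BYOC partitioning flow: AnnotateTarget invokes each op's FTVMAnnotateTarget function with the call expression and marks supported regions. A minimal sketch of that flow (the helper name partition_for_dnnl is illustrative; the pass pipeline reflects the mainline TVM API at the time of this commit):

import tvm
from tvm import relay

def partition_for_dnnl(mod):
    # AnnotateTarget calls the registered checks, MergeCompilerRegions fuses
    # adjacent supported ops, and PartitionGraph outlines them for the
    # external codegen.
    seq = tvm.transform.Sequential(
        [
            relay.transform.AnnotateTarget(["dnnl"]),
            relay.transform.MergeCompilerRegions(),
            relay.transform.PartitionGraph(),
        ]
    )
    return seq(mod)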
21 changes: 14 additions & 7 deletions python/tvm/relay/op/contrib/ethosn.py
@@ -128,21 +128,23 @@ def _is_ethosn_composite(node):


@tvm.ir.register_op_attr("nn.max_pool2d", "target.ethos-n")
def max_pool2d(attrs, args):
def max_pool2d(expr):
"""Check if a max pool2d is supported by Ethos-N."""
if not ethosn_available():
return False

attrs, args = expr.attrs, expr.args
pool = tvm.relay.nn.max_pool2d(*args, **attrs)
return support.max_pool2d(pool)


@tvm.ir.register_op_attr("reshape", "target.ethos-n")
def reshape(attrs, args):
def reshape(expr):
"""Check if a reshape is supported by Ethos-N."""
if not ethosn_available():
return False

attrs, args = expr.attrs, expr.args
if not _is_ethosn_composite(args[0]):
return False

@@ -151,21 +153,23 @@ def reshape(attrs, args):


@tvm.ir.register_op_attr("qnn.add", "target.ethos-n")
def qnn_add(attrs, args):
def qnn_add(expr):
"""Check if an addition is supported by Ethos-N."""
if not ethosn_available():
return False

args = expr.args
add = _qnn.op.add(*args)
return support.addition(add)


@tvm.ir.register_op_attr("qnn.concatenate", "target.ethos-n")
def qnn_concatenate(attrs, args):
def qnn_concatenate(expr):
"""Check if a concatenate is supported by Ethos-N."""
if not ethosn_available():
return False

attrs, args = expr.attrs, expr.args
conc = _qnn.op.concatenate(*args, **attrs)
if not support.concatenate(conc):
return False
@@ -190,11 +194,12 @@ def qnn_concatenate(attrs, args):


@tvm.ir.register_op_attr("split", "target.ethos-n")
def split(attrs, args):
def split(expr):
"""Check if a split is supported by Ethos-N."""
if not ethosn_available():
return False

attrs, args = expr.attrs, expr.args
if isinstance(attrs["indices_or_sections"], tvm.tir.IntImm):
sp = tvm.relay.split(
*args, indices_or_sections=attrs["indices_or_sections"].value, axis=attrs["axis"]
@@ -210,11 +215,12 @@ def split(attrs, args):


@tvm.ir.register_op_attr("nn.depth_to_space", "target.ethos-n")
def depth_to_space(attrs, args):
def depth_to_space(expr):
"""Check if a depth_to_space is supported by Ethos-N."""
if not ethosn_available():
return False

attrs, args = expr.attrs, expr.args
depth = tvm.relay.nn.depth_to_space(*args, **attrs)
if not support.depth_to_space(depth):
return False
@@ -223,11 +229,12 @@ def depth_to_space(attrs, args):


@tvm.ir.register_op_attr("clip", "target.ethos-n")
def clip(attrs, args):
def clip(expr):
"""Check if a clip is supported by Ethos-N."""
if not ethosn_available():
return False

attrs, args = expr.attrs, expr.args
c = tvm.relay.clip(*args, **attrs)
if not support.relu(c):
return False
