diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h
index acd4a03aed03..1e9b86d9e0bc 100644
--- a/include/tvm/relay/op_attr_types.h
+++ b/include/tvm/relay/op_attr_types.h
@@ -175,15 +175,11 @@ using FTVMLegalize = runtime::TypedPackedFunc
-using FTVMAnnotateTarget = runtime::TypedPackedFunc<bool(const Attrs& attrs,
-                                                          const Array<Expr>& args)>;
+using FTVMAnnotateTarget = runtime::TypedPackedFunc<bool(const Expr& expr)>;
 
 /*!
  * \brief Forward rewriting rule for a specific op.
diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py
index 8dfb3b7e0bf4..80d64db693ce 100644
--- a/python/tvm/relay/op/contrib/arm_compute_lib.py
+++ b/python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -167,7 +167,7 @@ def check_conv(extract):
         call = extract
         while call.op.name != "nn.conv2d":
             call = call.args[0]
-        return conv2d(call.attrs, call.args)
+        return conv2d(call)
 
     def check_qnn_conv(extract):
         """Check qnn conv pattern is supported by ACL."""
@@ -176,14 +176,14 @@ def check_qnn_conv(extract):
         call = extract
         while call.op.name != "qnn.conv2d":
             call = call.args[0]
-        return qnn_conv2d(call.attrs, call.args)
+        return qnn_conv2d(call)
 
     def check_dense(extract):
         """Check conv pattern is supported by ACL."""
         call = extract
         while call.op.name != "nn.dense":
             call = call.args[0]
-        return dense(call.attrs, call.args)
+        return dense(call)
 
     def check_qnn_dense(extract):
         """Check qnn conv pattern is supported by ACL."""
@@ -192,7 +192,7 @@ def check_qnn_dense(extract):
         call = extract
         while call.op.name != "qnn.dense":
             call = call.args[0]
-        return qnn_dense(call.attrs, call.args)
+        return qnn_dense(call)
 
     def check_avg_pool2d(extract):
         """Check average pool2d pattern is supported by ACL."""
@@ -201,12 +201,12 @@ def check_avg_pool2d(extract):
         pool = extract.args[0]
         if pool.args[0].attrs.dtype != "int32":
             return False
-        return avg_pool2d(pool.attrs, pool.args, from_quantized_composite=True)
+        return avg_pool2d(pool, from_quantized_composite=True)
 
     def check_l2_pool2d(extract):
         """Check l2 pool2d pattern is supported by ACL."""
         pool = extract.args[0]
-        return avg_pool2d(pool.attrs, pool.args)
+        return avg_pool2d(pool)
 
     return [
         ("arm_compute_lib.conv2d", conv_pattern(), check_conv),
@@ -221,7 +221,7 @@ def check_l2_pool2d(extract):
 
 def _register_external_op_helper(op_name, supported=True):
     @tvm.ir.register_op_attr(op_name, "target.arm_compute_lib")
-    def _func_wrapper(attrs, args):
+    def _func_wrapper(expr):
         return supported
 
     return _func_wrapper
@@ -231,8 +231,9 @@ def _func_wrapper(attrs, args):
 
 
 @tvm.ir.register_op_attr("nn.conv2d", "target.arm_compute_lib")
-def conv2d(attrs, args):
+def conv2d(expr):
     """Check if the external ACL codegen for conv2d should be used."""
+    attrs, args = expr.attrs, expr.args
     if attrs.groups != 1:
         return False
     if attrs.data_layout != "NHWC":
@@ -248,8 +249,9 @@ def conv2d(attrs, args):
     return True
 
 
-def qnn_conv2d(attrs, args):
+def qnn_conv2d(expr):
     """Check if the external ACL codegen for qnn.conv2d should be used."""
+    attrs, args = expr.attrs, expr.args
     if attrs.groups != 1:
         return False
     if attrs.data_layout != "NHWC":
@@ -266,8 +268,9 @@ def qnn_conv2d(attrs, args):
 
 
 @tvm.ir.register_op_attr("nn.dense", "target.arm_compute_lib")
-def dense(attrs, args):
+def dense(expr):
     """Check if the external ACL codegen for dense should be used."""
+    attrs, args = expr.attrs, expr.args
     data_typ = args[0].checked_type
     if data_typ.dtype != "float32":
         return False
@@ -279,8 +282,9 @@ def dense(attrs, args):
     return True
 
 
-def qnn_dense(attrs, args):
+def qnn_dense(expr):
     """Check if the external ACL codegen for qnn.dense should be used."""
+    attrs, args = expr.attrs, expr.args
     data_typ = args[0].checked_type
     if data_typ.dtype != "uint8":
         return False
@@ -293,8 +297,9 @@ def qnn_dense(attrs, args):
 
 
 @tvm.ir.register_op_attr("nn.max_pool2d", "target.arm_compute_lib")
-def max_pool2d(attrs, args):
+def max_pool2d(expr):
     """Check if the external ACL codegen for maxpool2d should be used."""
+    attrs, args = expr.attrs, expr.args
     if attrs.layout != "NHWC":
         return False
     typ = args[0].checked_type
@@ -304,8 +309,9 @@ def max_pool2d(attrs, args):
 
 
 @tvm.ir.register_op_attr("nn.avg_pool2d", "target.arm_compute_lib")
-def avg_pool2d(attrs, args, from_quantized_composite=False):
+def avg_pool2d(expr, from_quantized_composite=False):
     """Check if the external ACL codegen for avgpool2d should be used."""
+    attrs, args = expr.attrs, expr.args
     typ = args[0].checked_type
     if from_quantized_composite:
         if typ.dtype != "int32":
@@ -319,8 +325,9 @@ def avg_pool2d(attrs, args, from_quantized_composite=False):
 
 
 @tvm.ir.register_op_attr("nn.global_max_pool2d", "target.arm_compute_lib")
-def global_max_pool2d(attrs, args):
+def global_max_pool2d(expr):
     """Check if the external ACL codegen for gloval_maxpool2d should be used."""
+    attrs, args = expr.attrs, expr.args
     typ = args[0].checked_type
     if typ.dtype not in ["float32", "uint8"]:
         return False
@@ -330,8 +337,9 @@ def global_max_pool2d(attrs, args):
 
 
 @tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.arm_compute_lib")
-def global_avg_pool2d(attrs, args):
+def global_avg_pool2d(expr):
     """Check if the external ACL codegen for global_avgpool2d should be used."""
+    attrs, args = expr.attrs, expr.args
     typ = args[0].checked_type
     if typ.dtype not in ["float32"]:
         return False
@@ -341,16 +349,18 @@ def global_avg_pool2d(attrs, args):
 
 
 @tvm.ir.register_op_attr("maximum", "target.arm_compute_lib")
-def maximum(attrs, args):
+def maximum(expr):
     """Check if the external ACL codegen for maximum should be used."""
+    args = expr.args
     type_a = args[0].checked_type
     type_b = args[0].checked_type
     return (type_a.dtype == "float32") and (type_b.dtype == "float32")
 
 
 @tvm.ir.register_op_attr("add", "target.arm_compute_lib")
-def add(attrs, args):
+def add(expr):
     """Check if the external ACL codegen for add should be used."""
+    args = expr.args
     for typ in [args[0].checked_type, args[1].checked_type]:
         if typ.dtype != "float32":
             return False
@@ -359,8 +369,9 @@ def add(attrs, args):
 
 
 @tvm.ir.register_op_attr("qnn.add", "target.arm_compute_lib")
-def qnn_add(attrs, args):
+def qnn_add(expr):
     """Check if the external ACL codegen for add should be used."""
+    args = expr.args
     for typ in [args[0].checked_type, args[1].checked_type]:
         if typ.dtype != "uint8":
             return False
diff --git a/python/tvm/relay/op/contrib/coreml.py b/python/tvm/relay/op/contrib/coreml.py
index 105009a9f9b0..c1c012199cec 100644
--- a/python/tvm/relay/op/contrib/coreml.py
+++ b/python/tvm/relay/op/contrib/coreml.py
@@ -31,7 +31,8 @@ def _register_coreml_op(op_name):
     """
 
-    def _check_supported(attrs, args):
+    def _check_supported(expr):
+        attrs, args = expr.attrs, expr.args
         if op_name == "nn.conv2d":
             if not isinstance(args[1], Constant):
                 return False
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 816cb3818409..79bd02db164b 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -53,7 +53,7 @@ def _register_external_op_helper(op_name, supported=True):
     """
 
     @tvm.ir.register_op_attr(op_name, "target.dnnl")
-    def _func_wrapper(attrs, args):
+    def _func_wrapper(expr):
         return supported
 
     return _func_wrapper
diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py
index 3c676f4d9623..3a05011242e7 100644
--- a/python/tvm/relay/op/contrib/ethosn.py
+++ b/python/tvm/relay/op/contrib/ethosn.py
@@ -128,21 +128,23 @@ def _is_ethosn_composite(node):
 
 
 @tvm.ir.register_op_attr("nn.max_pool2d", "target.ethos-n")
-def max_pool2d(attrs, args):
+def max_pool2d(expr):
     """Check if a max pool2d is supported by Ethos-N."""
     if not ethosn_available():
         return False
 
+    attrs, args = expr.attrs, expr.args
     pool = tvm.relay.nn.max_pool2d(*args, **attrs)
     return support.max_pool2d(pool)
 
 
 @tvm.ir.register_op_attr("reshape", "target.ethos-n")
-def reshape(attrs, args):
+def reshape(expr):
     """Check if a reshape is supported by Ethos-N."""
     if not ethosn_available():
         return False
 
+    attrs, args = expr.attrs, expr.args
     if not _is_ethosn_composite(args[0]):
         return False
 
@@ -151,21 +153,23 @@ def reshape(attrs, args):
 
 
 @tvm.ir.register_op_attr("qnn.add", "target.ethos-n")
-def qnn_add(attrs, args):
+def qnn_add(expr):
     """Check if an addition is supported by Ethos-N."""
     if not ethosn_available():
         return False
 
+    args = expr.args
     add = _qnn.op.add(*args)
     return support.addition(add)
 
 
 @tvm.ir.register_op_attr("qnn.concatenate", "target.ethos-n")
-def qnn_concatenate(attrs, args):
+def qnn_concatenate(expr):
     """Check if a concatenate is supported by Ethos-N."""
     if not ethosn_available():
         return False
 
+    attrs, args = expr.attrs, expr.args
     conc = _qnn.op.concatenate(*args, **attrs)
     if not support.concatenate(conc):
         return False
@@ -190,11 +194,12 @@ def qnn_concatenate(attrs, args):
 
 
 @tvm.ir.register_op_attr("split", "target.ethos-n")
-def split(attrs, args):
+def split(expr):
     """Check if a split is supported by Ethos-N."""
     if not ethosn_available():
         return False
 
+    attrs, args = expr.attrs, expr.args
     if isinstance(attrs["indices_or_sections"], tvm.tir.IntImm):
         sp = tvm.relay.split(
             *args, indices_or_sections=attrs["indices_or_sections"].value, axis=attrs["axis"]
@@ -210,11 +215,12 @@ def split(attrs, args):
 
 
 @tvm.ir.register_op_attr("nn.depth_to_space", "target.ethos-n")
-def depth_to_space(attrs, args):
+def depth_to_space(expr):
     """Check if a depth_to_space is supported by Ethos-N."""
     if not ethosn_available():
         return False
 
+    attrs, args = expr.attrs, expr.args
     depth = tvm.relay.nn.depth_to_space(*args, **attrs)
     if not support.depth_to_space(depth):
         return False
@@ -223,11 +229,12 @@ def depth_to_space(attrs, args):
 
 
 @tvm.ir.register_op_attr("clip", "target.ethos-n")
-def clip(attrs, args):
+def clip(expr):
     """Check if a clip is supported by Ethos-N."""
     if not ethosn_available():
         return False
 
+    attrs, args = expr.attrs, expr.args
     c = tvm.relay.clip(*args, **attrs)
     if not support.relu(c):
         return False
diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py
index a0e23a043a72..24c468fee0fe 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -157,7 +157,8 @@ def partition_for_tensorrt(
 
 def _register_external_op_helper_with_checker(op_name, checker):
     @tvm.ir.register_op_attr(op_name, "target.tensorrt")
-    def _func_wrapper(attrs, args):
+    def _func_wrapper(expr):
+        attrs, args = expr.attrs, expr.args
         if any([x.checked_type.dtype != "float32" for x in args]):
             logger.info("Only float32 inputs are supported for TensorRT.")
             return False
@@ -192,9 +193,10 @@ def _register_external_op_helper(op_name, supported=True):
 
 
 @tvm.ir.register_op_attr("add", "target.tensorrt")
-def add_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def add_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if add is supported by TensorRT."""
+    args = expr.args
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
@@ -211,8 +213,10 @@ def add_annotate_fn(attrs, args):  # pylint: disable=unused-variable
 
 
 @tvm.ir.register_op_attr("nn.batch_norm", "target.tensorrt")
-def batch_norm_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def batch_norm_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.batch_norm is supported by TensorRT."""
+
+    attrs, args = expr.attrs, expr.args
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
@@ -223,8 +227,10 @@ def batch_norm_annotate_fn(attrs, args):  # pylint: disable=unused-variable
 
 
 @tvm.ir.register_op_attr("nn.softmax", "target.tensorrt")
-def softmax_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def softmax_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.softmax is supported by TensorRT."""
+
+    attrs, args = expr.attrs, expr.args
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
@@ -235,8 +241,10 @@ def softmax_annotate_fn(attrs, args):  # pylint: disable=unused-variable
 
 
 @tvm.ir.register_op_attr("nn.conv2d", "target.tensorrt")
-def conv2d_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def conv2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv2d is supported by TensorRT."""
+
+    attrs, args = expr.attrs, expr.args
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
@@ -253,8 +261,10 @@ def conv2d_annotate_fn(attrs, args):  # pylint: disable=unused-variable
 
 
 @tvm.ir.register_op_attr("nn.dense", "target.tensorrt")
-def dense_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def dense_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if dense is supported by TensorRT."""
+
+    args = expr.args
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
@@ -270,8 +280,10 @@ def dense_annotate_fn(attrs, args):  # pylint: disable=unused-variable
 
 
 @tvm.ir.register_op_attr("nn.bias_add", "target.tensorrt")
-def bias_add_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def bias_add_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.bias_add is supported by TensorRT."""
+
+    args = expr.args
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
@@ -283,8 +295,10 @@ def bias_add_annotate_fn(attrs, args):  # pylint: disable=unused-variable
 
 
 @tvm.ir.register_op_attr("nn.max_pool2d", "target.tensorrt")
-def max_pool_2d_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.max_pool2d is supported by TensorRT."""
+
+    attrs, args = expr.attrs, expr.args
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
@@ -298,8 +312,10 @@ def max_pool_2d_annotate_fn(attrs, args):  # pylint: disable=unused-variable
 
 
 @tvm.ir.register_op_attr("nn.avg_pool2d", "target.tensorrt")
-def avg_pool_2d_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.avg_pool2d is supported by TensorRT."""
+
+    attrs, args = expr.attrs, expr.args
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
@@ -326,8 +342,10 @@ def avg_pool_2d_annotate_fn(attrs, args):  # pylint: disable=unused-variable
 
 
 @tvm.ir.register_op_attr("nn.global_max_pool2d", "target.tensorrt")
-def global_max_pool_2d_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def global_max_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.global_max_pool2d is supported by TensorRT."""
+
+    attrs, args = expr.attrs, expr.args
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
@@ -338,8 +356,10 @@ def global_max_pool_2d_annotate_fn(attrs, args):  # pylint: disable=unused-varia
 
 
 @tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.tensorrt")
-def global_avg_pool_2d_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def global_avg_pool_2d_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.global_avg_pool2d is supported by TensorRT."""
+
+    attrs, args = expr.attrs, expr.args
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
@@ -350,8 +370,10 @@ def global_avg_pool_2d_annotate_fn(attrs, args):  # pylint: disable=unused-varia
 
 
 @tvm.ir.register_op_attr("expand_dims", "target.tensorrt")
-def expand_dims_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def expand_dims_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if expand_dims is supported by TensorRT."""
+
+    attrs, args = expr.attrs, expr.args
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
@@ -362,8 +384,10 @@ def expand_dims_annotate_fn(attrs, args):  # pylint: disable=unused-variable
 
 
 @tvm.ir.register_op_attr("squeeze", "target.tensorrt")
-def squeeze_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def squeeze_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if squeeze is supported by TensorRT."""
+
+    attrs, args = expr.attrs, expr.args
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
@@ -377,8 +401,10 @@ def squeeze_annotate_fn(attrs, args):  # pylint: disable=unused-variable
 
 
 @tvm.ir.register_op_attr("concatenate", "target.tensorrt")
-def concatenate_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def concatenate_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if concatenate is supported by TensorRT."""
+
+    attrs, args = expr.attrs, expr.args
     if any([x.dtype != "float32" for x in args[0].checked_type.fields]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
@@ -396,8 +422,10 @@ def concatenate_annotate_fn(attrs, args):  # pylint: disable=unused-variable
 
 
 @tvm.ir.register_op_attr("nn.conv2d_transpose", "target.tensorrt")
-def conv2d_transpose_annotate_fn(attrs, args):  # pylint: disable=unused-variable
+def conv2d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv2d_transpose is supported by TensorRT."""
+
+    attrs, args = expr.attrs, expr.args
     if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.") return False @@ -419,8 +447,10 @@ def conv2d_transpose_annotate_fn(attrs, args): # pylint: disable=unused-variabl @tvm.ir.register_op_attr("transpose", "target.tensorrt") -def transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable +def transpose_annotate_fn(expr): # pylint: disable=unused-variable """Check if transpose is supported by TensorRT.""" + + attrs, args = expr.attrs, expr.args if any([x.checked_type.dtype != "float32" for x in args]): logger.info("Only float32 inputs are supported for TensorRT.") return False @@ -431,8 +461,10 @@ def transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable @tvm.ir.register_op_attr("layout_transform", "target.tensorrt") -def layout_transform_annotate_fn(attrs, args): # pylint: disable=unused-variable +def layout_transform_annotate_fn(expr): # pylint: disable=unused-variable """Check if layout_transform is supported by TensorRT.""" + + attrs, args = expr.attrs, expr.args if any([x.checked_type.dtype != "float32" for x in args]): logger.info("Only float32 inputs are supported for TensorRT.") return False @@ -450,8 +482,10 @@ def layout_transform_annotate_fn(attrs, args): # pylint: disable=unused-variabl @tvm.ir.register_op_attr("reshape", "target.tensorrt") -def reshape_annotate_fn(attrs, args): # pylint: disable=unused-variable +def reshape_annotate_fn(expr): # pylint: disable=unused-variable """Check if reshape is supported by TensorRT.""" + + attrs, args = expr.attrs, expr.args if args[0].checked_type.dtype != "float32": logger.info("Only float32 inputs are supported for TensorRT.") return False @@ -481,8 +515,10 @@ def reshape_annotate_fn(attrs, args): # pylint: disable=unused-variable @tvm.ir.register_op_attr("nn.pad", "target.tensorrt") -def pad_annotate_fn(attrs, args): # pylint: disable=unused-variable +def pad_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.pad is supported by TensorRT.""" + + attrs, args = expr.attrs, expr.args if any([x.checked_type.dtype != "float32" for x in args]): logger.info("Only float32 inputs are supported for TensorRT.") return False @@ -543,8 +579,10 @@ def _func_wrapper(attrs, args, op_name): @tvm.ir.register_op_attr("strided_slice", "target.tensorrt") -def strided_slice_annotate_fn(attrs, args): # pylint: disable=unused-variable +def strided_slice_annotate_fn(expr): # pylint: disable=unused-variable """Check if strided_slice is supported by TensorRT.""" + + attrs, args = expr.attrs, expr.args if args[0].checked_type.dtype != "float32": logger.info("Only float32 inputs are supported for TensorRT.") return False @@ -567,8 +605,10 @@ def strided_slice_annotate_fn(attrs, args): # pylint: disable=unused-variable @tvm.ir.register_op_attr("nn.adaptive_max_pool2d", "target.tensorrt") -def adapative_max_pool2d_annotate_fn(attrs, args): # pylint: disable=unused-variable +def adapative_max_pool2d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.adaptive_max_pool2d is supported by TensorRT.""" + + attrs, args = expr.attrs, expr.args if any([x.checked_type.dtype != "float32" for x in args]): logger.info("Only float32 inputs are supported for TensorRT.") return False @@ -579,8 +619,10 @@ def adapative_max_pool2d_annotate_fn(attrs, args): # pylint: disable=unused-var @tvm.ir.register_op_attr("nn.adaptive_avg_pool2d", "target.tensorrt") -def adapative_avg_pool2d_annotate_fn(attrs, args): # pylint: disable=unused-variable +def adapative_avg_pool2d_annotate_fn(expr): # pylint: disable=unused-variable 
"""Check if nn.adaptive_avg_pool2d is supported by TensorRT.""" + + attrs, args = expr.attrs, expr.args if any([x.checked_type.dtype != "float32" for x in args]): logger.info("Only float32 inputs are supported for TensorRT.") return False @@ -591,8 +633,10 @@ def adapative_avg_pool2d_annotate_fn(attrs, args): # pylint: disable=unused-var @tvm.ir.register_op_attr("nn.conv3d", "target.tensorrt") -def conv3d_annotate_fn(attrs, args): # pylint: disable=unused-variable +def conv3d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.conv3d is supported by TensorRT.""" + + attrs, args = expr.attrs, expr.args if any([x.checked_type.dtype != "float32" for x in args]): logger.info("Only float32 inputs are supported for TensorRT.") return False @@ -611,8 +655,10 @@ def conv3d_annotate_fn(attrs, args): # pylint: disable=unused-variable @tvm.ir.register_op_attr("nn.max_pool3d", "target.tensorrt") -def max_pool_3d_annotate_fn(attrs, args): # pylint: disable=unused-variable +def max_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.max_pool3d is supported by TensorRT.""" + + attrs, args = expr.attrs, expr.args if any([x.checked_type.dtype != "float32" for x in args]): logger.info("Only float32 inputs are supported for TensorRT.") return False @@ -625,8 +671,10 @@ def max_pool_3d_annotate_fn(attrs, args): # pylint: disable=unused-variable @tvm.ir.register_op_attr("nn.avg_pool3d", "target.tensorrt") -def avg_pool_3d_annotate_fn(attrs, args): # pylint: disable=unused-variable +def avg_pool_3d_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.avg_pool3d is supported by TensorRT.""" + + attrs, args = expr.attrs, expr.args if any([x.checked_type.dtype != "float32" for x in args]): logger.info("Only float32 inputs are supported for TensorRT.") return False @@ -639,8 +687,10 @@ def avg_pool_3d_annotate_fn(attrs, args): # pylint: disable=unused-variable @tvm.ir.register_op_attr("nn.conv3d_transpose", "target.tensorrt") -def conv3d_transpose_annotate_fn(attrs, args): # pylint: disable=unused-variable +def conv3d_transpose_annotate_fn(expr): # pylint: disable=unused-variable """Check if nn.conv3d_transpose is supported by TensorRT.""" + + attrs, args = expr.attrs, expr.args if any([x.checked_type.dtype != "float32" for x in args]): logger.info("Only float32 inputs are supported for TensorRT.") return False diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 9d160b26f1ad..d5f1e4cc1752 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -178,7 +178,8 @@ class AnnotateTargetRewriter : public ExprRewriter { continue; } auto fannotate = Op::GetAttrMap("target." 
-        if (fannotate.count(op) && fannotate[op](pre->attrs, pre->args)) {
+        const Expr& ex = GetRef<Expr>(pre);
+        if (fannotate.count(op) && fannotate[op](ex)) {
           supported_targets.push_back(target);
         }
       }
diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py
index 106909e16fa7..325826d183da 100644
--- a/tests/python/relay/test_pass_annotate_target.py
+++ b/tests/python/relay/test_pass_annotate_target.py
@@ -179,7 +179,7 @@ def test_extern_dnnl_mobilenet():
 
 def test_multiple_ends():
     @tvm.ir.register_op_attr("nn.relu", "target.test")
-    def relu(attrs, args):  # pylint: disable=unused-variable
+    def relu(expr):  # pylint: disable=unused-variable
         return True
 
     def before():
@@ -221,8 +221,8 @@ def test_type_propagation():
     target = "test_type_propagation"
 
     @tvm.ir.register_op_attr("nn.relu", "target." + target)
-    def relu(attrs, args):  # pylint: disable=unused-variable
-        return args[0].checked_type.dtype == "float32"
+    def relu(expr):  # pylint: disable=unused-variable
+        return expr.args[0].checked_type.dtype == "float32"
 
     def before():
         x = relay.var("x", shape=(10, 10))
@@ -240,11 +240,11 @@ def test_tuple():
     target = "test_tuple"
 
     @tvm.ir.register_op_attr("nn.relu", "target." + target)
-    def relu(attrs, args):  # pylint: disable=unused-variable
+    def relu(expr):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("concatenate", "target." + target)
-    def concatenate(attrs, args):  # pylint: disable=unused-variable
+    def concatenate(expr):  # pylint: disable=unused-variable
         return True
 
     """Test that TupleNode is included in annotation when surrounded by supported nodes."""
@@ -331,11 +331,11 @@ def after():
 
 def test_multiple_runs():
     @tvm.ir.register_op_attr("nn.relu", "target.A")
-    def relu(attrs, args):  # pylint: disable=unused-variable
+    def relu(expr):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("add", "target.B")
-    def add(attrs, args):  # pylint: disable=unused-variable
+    def add(expr):  # pylint: disable=unused-variable
         return True
 
     def before():
@@ -359,19 +359,19 @@ def test_if_else():
     target = "test_if_else"
 
     @tvm.ir.register_op_attr("equal", "target." + target)
-    def relu(attrs, args):  # pylint: disable=unused-variable
+    def relu(expr):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("tanh", "target." + target)
-    def tanh(attrs, args):  # pylint: disable=unused-variable
+    def tanh(expr):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("sigmoid", "target." + target)
-    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+    def sigmoid(expr):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("erf", "target." + target)
-    def erf(attrs, args):  # pylint: disable=unused-variable
+    def erf(expr):  # pylint: disable=unused-variable
         return True
 
     """Test that If-else nodes compiles correctly when surrounded by supported nodes."""
@@ -430,15 +430,15 @@ def test_while_let():
     target = "test_while_let"
 
     @tvm.ir.register_op_attr("less", "target." + target)
-    def less(attrs, args):  # pylint: disable=unused-variable
+    def less(expr):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("add", "target." + target)
-    def add(attrs, args):  # pylint: disable=unused-variable
+    def add(expr):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("zeros_like", "target." + target)
-    def zeros_like(attrs, args):  # pylint: disable=unused-variable
+    def zeros_like(expr):  # pylint: disable=unused-variable
         return True
 
     """Test that let nodes compiles correctly when surrounded by other nodes."""
@@ -514,15 +514,15 @@ def test_if_free_vars():
     target = "test_if_free_vars"
 
     @tvm.ir.register_op_attr("equal", "target." + target)
-    def equal(attrs, args):  # pylint: disable=unused-variable
+    def equal(expr):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("sigmoid", "target." + target)
-    def sigmoid(attrs, args):  # pylint: disable=unused-variable
+    def sigmoid(expr):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("erf", "target." + target)
-    def erf(attrs, args):  # pylint: disable=unused-variable
+    def erf(expr):  # pylint: disable=unused-variable
         return True
 
     """Test that If-else nodes compiles correctly when surrounded by free variables"""
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 8d0e2d5e22e0..059d0b4c8af8 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -1035,7 +1035,7 @@ def test_duplicate_outputs():
     target = "test_duplicate_outputs"
 
     @tvm.ir.register_op_attr("abs", "target." + target)
-    def abs(attrs, args):  # pylint: disable=unused-variable
+    def abs(expr):  # pylint: disable=unused-variable
         return True
 
     def create_graph():
@@ -1096,11 +1096,11 @@ def test_duplicate_merge_and_tuplegetitem():
     target = "test_duplicate_merge_and_tuplegetitem"
 
     @tvm.ir.register_op_attr("nn.batch_norm", "target." + target)
-    def batch_norm(attrs, args):  # pylint: disable=unused-variable
+    def batch_norm(expr):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("nn.relu", "target." + target)
-    def relu(attrs, args):  # pylint: disable=unused-variable
+    def relu(expr):  # pylint: disable=unused-variable
         return True
 
     def create_graph():
@@ -1177,7 +1177,7 @@ def expected():
 
 def test_constant_tuples():
     @tvm.ir.register_op_attr("qnn.concatenate", "target.const_tuples")
-    def add(attrs, args):  # pylint: disable=unused-variable
+    def add(expr):  # pylint: disable=unused-variable
         return True
 
     def create_graph():
@@ -1223,11 +1223,11 @@ def test_flatten_tuple_output():
     target = "test_flatten_tuple_output"
 
     @tvm.ir.register_op_attr("split", "target." + target)
-    def split(attrs, args):  # pylint: disable=unused-variable
+    def split(expr):  # pylint: disable=unused-variable
         return True
 
     @tvm.ir.register_op_attr("abs", "target." + target)
-    def abs(attrs, args):  # pylint: disable=unused-variable
+    def abs(expr):  # pylint: disable=unused-variable
         return True
 
     def create_graph():
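Reviewer note (not part of the patch): after this change, every annotation function registered under a `target.<backend>` attribute via `tvm.ir.register_op_attr` receives the call expression itself instead of an `(attrs, args)` pair, and recovers both from `expr`. A minimal sketch of a downstream checker written against the new signature is shown below; the target name `target.my_backend` and the function name `conv2d_checker` are illustrative only and are not identifiers introduced by this PR.

```python
import tvm


# "target.my_backend" is a hypothetical external codegen target used only for illustration.
@tvm.ir.register_op_attr("nn.conv2d", "target.my_backend")
def conv2d_checker(expr):
    """Annotation checker using the new single-argument FTVMAnnotateTarget signature."""
    # Attributes and arguments are now read off the call expression itself,
    # mirroring the `attrs, args = expr.attrs, expr.args` pattern used throughout this patch.
    attrs, args = expr.attrs, expr.args
    if attrs.data_layout != "NHWC":
        return False
    return args[0].checked_type.dtype == "float32"
```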