diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index d5c9edec51..bd05e1814d 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -97,6 +97,18 @@ struct PrintAttrs : public tvm::AttrsNode { } }; +struct AssertOpAttrs : public tvm::AttrsNode { + std::string format; + TVM_DECLARE_ATTRS(AssertOpAttrs, "relax.attrs.AssertOpAttrs") { + TVM_ATTR_FIELD(format) + .describe( + "Python-style format string to use for displaying " + "an error message if the assert fails. " + "Ignored if empty.") + .set_default(""); + } +}; + } // namespace relax } // namespace tvm #endif // TVM_RELAX_OP_ATTR_TYPES_H_ diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index b851aabfe2..6bfd0d0daa 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -107,6 +107,21 @@ class NameTable { */ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); +/*! + * \brief Check if the given type is a boolean scalar type (tensor of rank 0 with a boolean dtype). + * + * \param ty The input type. + * \param permit_unknown_rank If true, it will permit the input type to have unknown rank + * (ndim of -1), which will require a dynamic check. + * \param permit_unknown_dtype If true, it will permit the input type to have an unknown dtype + * (namely, void), which will require a dynamic check. + * + * \return True iff the input type is a boolean scalar type (or, depending on options, has unknown + * rank or dtype) + */ +TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true, + bool permit_unknown_dtype = true); + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 303f3cf1c6..f921d367da 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -13,21 +13,24 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations +# pylint: disable=redefined-builtin """The base Relax operators.""" -from typing import Union, List, Optional +from typing import List, Optional, Union import tvm from tvm.runtime.object import Object -from . import _ffi_api -from ..expr import Expr, ShapeExpr, Tuple, Call -from ..ty import DynTensorType, TupleType from ...ir import Array +from ..expr import Call, Expr, ExternFunc, ShapeExpr, Tuple +from ..ty import DynTensorType, TupleType +from . import _ffi_api + +py_print = print # pylint: disable=invalid-name def call_tir( - func: Expr, - args: Union[Tuple, List[Expr]], + func: Union[str, Expr], + args: Union[Expr, Tuple, List[Expr]], shape: Union[Tuple, ShapeExpr, List[int]], dtype: Union[str, List[str]], tir_vars: Optional[ShapeExpr] = None, @@ -37,10 +40,10 @@ def call_tir( Parameters ---------- - func : Expr + func : Union[str, Expr] The destination-passing-style function, can be ExternFunc or PrimFunc. - args : Union[Tuple, List[Expr]] + args : Union[Expr, Tuple, List[Expr]] The input arguments. shape: Union[Tuple, ShapeExpr, List[int]] @@ -57,9 +60,15 @@ def call_tir( ret: Call A call node for the call_tir operator. """ + if isinstance(func, str): + func = ExternFunc(func) + if isinstance(shape, (list, tuple, Array)): shape = ShapeExpr(shape) + if isinstance(args, Expr): + args = Tuple((args,)) + if isinstance(args, (list, tuple)): args = Tuple(args) @@ -131,17 +140,41 @@ def invoke_closure( return _ffi_api.invoke_closure(closure, args) +def render_object(val: tvm.Object) -> str: + """ + Given a TVM Object, renders it in string form. Used for Relax printing and assertions. 
+ + Parameters + ---------- + val: tvm.Object + An object to render + + Returns + ------- + ret: str + A string representing the value, ideally human-readable + """ + if isinstance(val, tvm.runtime.ndarray.NDArray): + return str(val) + # no pretty-printer by default, so if we don't handle this, + # then we can't look inside tuples + if isinstance(val, tvm.runtime.container.ADT): + # the fields array of an ADT cannot be directly accessed in Python + # so we have to get the length and index into the fields separately + fields = ", ".join([render_object(val[i]) for i in range(len(val))]) + # special case: tag = 0 is a tuple + if val.tag == 0: + return f"({fields})" + return f"ADT(tag={val.tag}, fields=[{fields}])" + return str(val) + + @tvm.register_func("relax.run.print") -def relax_print(*args: List[any]) -> None: +def relax_print(format_str: str, *format_args: tvm.Object) -> None: """ Takes a list of values to print, formats with the given format string. If the format string is empty, simply prints. - Since this function is called as a PackedFunc from the generated code, - we cannot have it be variadic _and_ have an optional format string attribute - except by taking in all the arguments as a single list. The last argument - should be a format string. - Call from TVM script like this: `relax.print(value1, value2, ..., valueN, format=format_str)` or @@ -149,36 +182,130 @@ def relax_print(*args: List[any]) -> None: Parameters ---------- - vals: List[Object] + format_str: str + The first argument is a Python-style format string for printing the value + + format_args: List[Object] + The values to print. + """ + val_strs = map(render_object, format_args) + if format_str == "": + py_print(*val_strs) + else: + py_print(format_str.format(*val_strs)) + + +def print(values: Union[Expr, List[Expr]], format: str) -> Expr: + """Print op to print the values + + Parameters + ---------- + values : List[Expr] The values to print. 
format_str: str - The last argument is a Python-style format string for printing the value + The format string. + + Returns + ------- + result : Expr + A relax Call, which will print the value during runtime. """ + if isinstance(values, Expr): + values = [values] + return _ffi_api.print(values, format) # type: ignore # pylint: disable=no-member + + +@tvm.register_func("relax.run.assert_op") +def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Object) -> None: + """ + A variadic function. The first value serves as the assertion condition: + If the condition is true, then the operator does nothing. + If the condition is false, then the operator raises an assertion error. + + Arguments after the second value serve as format arguments for the error message; + the second argument must be a format string for the error message (empty by default). + If the format string is the empty string, then the error message will simply include + a comma-separated list of the format arguments. + The condition argument is not included in the format string. + + Parameters + ---------- + condition: tvm.Object + The assertion condition. Must be a boolean scalar. - # there is no way to have a keyword arg to a packed function, - # so the format string is always the last argument - format_str = args[-1] + format_str: str + The second argument is a Python-style format string for printing the value + + format_args: List[tvm.Object] + Values used for formatting the string. 
+ """ if not isinstance(format_str, str): - raise ValueError("No valid format string given.") - - def render(val: tvm.Object) -> str: - if isinstance(val, tvm.runtime.ndarray.NDArray): - return str(val) - # no pretty-printer by default, so if we don't handle this, - # then we can't look inside tuples - if isinstance(val, tvm.runtime.container.ADT): - # the fields array of an ADT cannot be directly accessed in Python - # so we have to get the length and index into the fields separately - fields = ", ".join([render(val[i]) for i in range(len(val))]) - # special case: tag = 0 is a tuple - if val.tag == 0: - return f"({fields})" - return f"ADT(tag={val.tag}, fields=[{fields}])" - return str(val) + raise ValueError( + f"The format string argument to assert must be a string, given {type(format_str)})" + ) + + # should be guaranteed by the type system + if not isinstance(condition, tvm.runtime.ndarray.NDArray): + raise ValueError(f"The condition must be an NDArray, but given a {type(condition)}.") + + # may happen if the original program had unknown shape or dtype for the tensor's type + dtype = condition.dtype + if dtype != "bool": + raise ValueError(f"The condition must be a bool scalar, but given a {dtype} tensor") + shape = condition.shape + if len(shape) != 0: + raise ValueError(f"The condition must be a scalar, but it has a shape of {shape}") + + val = condition.numpy() + if not val: + error_message = "Assertion Failed" + if format_args or format_str != "": + rendered = map(render_object, format_args) + if format_str != "": + error_message = format_str.format(*rendered) + else: + error_message = ", ".join(rendered) + raise AssertionError(error_message) + + +def assert_op(condition: Expr, format_args: Optional[List[Expr]] = None, format: str = "") -> Expr: + """ + Create a call to Relax's assert_op operation (`assert` is reserved in Python, + so the name must be distinct). 
- val_strs = map(render, args[:-1]) - if format_str == "": - print(*val_strs) - else: - print(format_str.format(*val_strs)) + Parameters + ---------- + condition: Expr + The assertion condition. + + format_args: List[Expr] + Format arguments for the error message if the condition fails. + + format: str + The format string for the error message. + + Returns + ------- + result : Expr + A Call to the Relax assert operation. + """ + if format_args is None: + format_args = [] + return _ffi_api.assert_op(condition, format_args, format) # type: ignore + + +def shape_of(expr: Expr) -> Expr: + """Get shape of a tensor. + + Parameters + ---------- + expr : Expr + The input Expr. + + Returns + ------- + result : Expr + A relax Call, which gets the shape of the input + """ + return _ffi_api.shape_of(expr) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 6b27c29232..f8ef3107de 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -42,3 +42,8 @@ class UniqueAttrs(Attrs): @tvm._ffi.register_object("relax.attrs.PrintAttrs") class PrintAttrs(Attrs): """Attributes used for the print operator""" + + +@tvm._ffi.register_object("relax.attrs.AssertOpAttrs") +class AssertOpAttrs(Attrs): + """Attributes used for the assert operator""" diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 8919d4d98d..d8d09ddd88 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -385,7 +385,14 @@ class CodeGenVM : public ExprFunctor { } if (call_node->op == print_op_) { auto print_attrs = call_node->attrs.as(); - args.push_back(EmitConstantFromValue(print_attrs->format)); + // format string is the first argument + args.insert(args.begin(), EmitConstantFromValue(print_attrs->format)); + return; + } + if (call_node->op == assert_op_) { + auto assert_attrs = call_node->attrs.as(); + // format string comes before the 
format args + args.insert(args.begin() + 1, EmitConstantFromValue(assert_attrs->format)); return; } LOG(FATAL) << "Support for attributes of Op " << call_node->op @@ -520,6 +527,7 @@ class CodeGenVM : public ExprFunctor { const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); const Op& unique_op_ = Op::Get("relax.unique"); const Op& print_op_ = Op::Get("relax.print"); + const Op& assert_op_ = Op::Get("relax.assert_op"); const Op& make_closure_op_ = Op::Get("relax.make_closure"); const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); }; diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index f08e02f0cb..39b8a0b9a5 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include "op_common.h" @@ -118,6 +119,51 @@ Expr MakePrint(Array vals, std::string format_str) { TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint); +// assert_op + +// can't actually name it assert or else Python will consider it a syntax error + +Type InferAssertType(const Call& call, DiagnosticContext diag_ctx) { + // Ensure that the condition argument is a boolean scalar. + // Also permitted is a tensor with unknown shape and unknown dtype + // (checked dynamically in that case). Returns void. + if (call->args.size() < 1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Assert must have at least one argument (the condition)."); + } + Type arg_type = call->args[0]->checked_type(); + if (!IsBoolScalarType(arg_type)) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The argument to assert must be a boolean scalar type, but received " + << arg_type); + } + return VoidType(); +} + +TVM_REGISTER_NODE_TYPE(AssertOpAttrs); + +RELAY_REGISTER_OP("relax.assert_op") + .set_attrs_type() + .set_num_inputs(-1) + .add_argument("vals", "Array", + "The first value is used as the assertion condition. 
The others are used as " + "format arguments if there is an error.") + .set_attr("FInferType", InferAssertType) + .set_attr("FCallPacked", "relax.run.assert_op"); + +Expr MakeAssertOp(Expr condition, Array vals, std::string format) { + auto attrs = make_object(); + attrs->format = format; + static const Op& op = Op::Get("relax.assert_op"); + Array args = {condition}; + for (auto val : vals) { + args.push_back(val); + } + return Call(op, args, Attrs(attrs)); +} + +TVM_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp); + // make_closure RELAY_REGISTER_OP("relax.make_closure") diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 4484bd3281..75a882de45 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -67,5 +67,15 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { } } +bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) { + const DynTensorTypeNode* tt = ty.as(); + if (!tt) { + return false; + } + bool correct_dtype = tt->dtype.is_bool() || (permit_unknown_dtype && tt->dtype.is_void()); + bool correct_rank = tt->ndim == 0 || (permit_unknown_rank && tt->ndim == -1); + return correct_dtype && correct_rank; +} + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index dba0d81719..acb560d9fc 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -21,6 +21,7 @@ import pytest import tvm from tvm import relax +from tvm._ffi.base import TVMError from tvm.script import relax as R @@ -88,5 +89,60 @@ def test_print(): sys.stdout = stdout +@tvm.script.ir_module +class AssertOpTest: + @R.function + def passes(x: Tensor((), "int32")): + p1 = relax.assert_op(relax.const(True)) + return x + + @R.function + def pass_with_args(x: Tensor((), "int32")): + p1 = relax.assert_op(relax.const(True), x, format="You won't see me") + return x + + @R.function + def simple_fail(x: 
Tensor((), "int32")): + p1 = relax.assert_op(relax.const(False)) + return x + + @R.function + def fail_with_message(x: Tensor((), "int32")): + p1 = relax.assert_op(relax.const(False), format="I failed...") + return x + + @R.function + def fail_with_args(x: Tensor((), "int32")): + # no format + p1 = relax.assert_op(relax.const(False), x, x) + return x + + @R.function + def fail_with_formatted_message(x: Tensor((), "int32")): + p1 = relax.assert_op(relax.const(False), x, format="Number: {}") + return x + + +def test_assert_op(): + def check_assertion_error(func_name, func_arg, expected_message): + passed = False + try: + run_cpu(AssertOpTest, func_name, func_arg) + passed = True + except TVMError as e: + # TVM will print out a TVMError that will contain the + # generated error at the bottom of a stack trace + assert "AssertionError" in e.args[0] + assert expected_message in e.args[0] + assert not passed + + run_cpu(AssertOpTest, "passes", tvm.nd.array(1)) + run_cpu(AssertOpTest, "pass_with_args", tvm.nd.array(2)) + check_assertion_error("simple_fail", tvm.nd.array(3), "Assertion Failed") + check_assertion_error("fail_with_message", tvm.nd.array(4), "I failed...") + check_assertion_error("fail_with_args", tvm.nd.array(5), "5, 5") + check_assertion_error("fail_with_formatted_message", tvm.nd.array(6), "Number: 6") + + if __name__ == "__main__": pytest.main([__file__])