From 28222657a3449db7d594d8eae0a19351fa4ed616 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 29 Dec 2022 18:33:17 -0500 Subject: [PATCH] [REFACTOR] StructInfo M3: MatchShape=>MatchCast (#323) * Introduce match cast, and code changes along * add match_cast parser support (#9) * Match cast support for VMShapeLower CanonicalizeBinding * Remove `match_shape` (#12) * Refactor ExprVisitor/Mutator to consider Expr in StructInfo. Co-authored-by: Siyuan Feng --- include/tvm/relax/block_builder.h | 10 +- include/tvm/relax/expr.h | 58 ++-- include/tvm/relax/expr_functor.h | 139 ++++++++-- include/tvm/relax/ir_functor.h | 4 +- include/tvm/script/ir_builder/relax/ir.h | 16 +- python/tvm/relax/__init__.py | 3 +- python/tvm/relax/block_builder.py | 16 +- python/tvm/relax/expr.py | 30 +- python/tvm/relax/expr_functor.py | 40 +-- python/tvm/relax/testing/ast_printer.py | 16 +- python/tvm/script/ir_builder/relax/ir.py | 20 +- python/tvm/script/parser/relax/__init__.py | 11 +- python/tvm/script/parser/relax/entry.py | 21 +- python/tvm/script/parser/relax/parser.py | 12 +- src/relax/analysis/analysis.cc | 6 +- src/relax/analysis/var2value.cc | 4 +- src/relax/analysis/well_formed.cc | 17 +- .../contrib/codegen_json/codegen_json.h | 6 +- src/relax/backend/vm/vm_shape_lower.cc | 25 +- src/relax/ir/binding_rewrite.cc | 19 +- src/relax/ir/block_builder.cc | 91 +++--- src/relax/ir/emit_te.cc | 2 +- src/relax/ir/expr.cc | 19 +- src/relax/ir/expr_functor.cc | 261 ++++++++++++------ src/relax/ir/transform.cc | 60 ++-- src/relax/transform/bind_params.cc | 2 +- src/relax/transform/canonicalize_bindings.cc | 31 +-- src/relax/transform/fuse_ops.cc | 29 +- src/relax/transform/fuse_tir.cc | 5 +- src/relax/transform/normalize.cc | 14 +- src/relay/printer/relax_script_printer.cc | 24 +- src/relay/printer/text_printer.h | 2 +- src/script/ir_builder/relax/frame.cc | 13 +- src/script/ir_builder/relax/ir.cc | 26 +- src/script/ir_builder/relax/utils.h | 7 +- tests/python/relax/test_analysis.py | 10 +- tests/python/relax/test_ast_printer.py | 44 +-- .../python/relax/test_autotir_integration.py | 4 +- tests/python/relax/test_blockbuilder.py | 33 ++- tests/python/relax/test_expr.py | 49 +++- tests/python/relax/test_expr_functor.py | 62 ++--- tests/python/relax/test_parser.py | 15 +- tests/python/relax/test_printer.py | 10 +- .../python/relax/test_structual_equal_hash.py | 8 +- tests/python/relax/test_transform.py | 10 +- .../test_transform_canonicalize_bindings.py | 48 +--- .../relax/test_transform_fold_constant.py | 4 +- .../python/relax/test_transform_normalize.py | 13 +- .../python/relax/test_tvmscript_ir_builder.py | 14 +- tests/python/relax/test_tvmscript_parser.py | 49 ++-- tests/python/relax/test_vm.py | 4 +- tests/python/relay/test_dataflow_pattern.py | 4 +- 52 files changed, 789 insertions(+), 651 deletions(-) diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index fce2151651..d4fc93ac7e 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -160,13 +160,13 @@ class BlockBuilderNode : public Object { virtual Var Emit(Expr expr, String name_hint = "") = 0; /*! - * \brief Emit a MatchShape. - * \param value The value of the MatchShape to be emitted. - * \param pattern The pattern of the MatchShape to be emitted. + * \brief Emit a MatchCast. + * \param value The input value. + * \param struct_info The struct info to be matched. * \param name_hint Name hint for the bound variable. - * \return The variable bound to the MatchShape. + * \return The variable bound to the MatchCast. */ - virtual Var EmitMatchShape(Expr value, Array pattern, String name_hint = "") = 0; + virtual Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint = "") = 0; /*! * \brief Generate an output for the current dataflow block. diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index e893affac9..2b562af453 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -531,12 +531,10 @@ class Constant : public Expr { /*! \brief The base class of a variable binding in Relax. */ class BindingNode : public Object { public: + /*! \brief The return variable to bound to. */ + Var var; mutable Span span; - void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); } - bool SEqualReduce(const BindingNode* other, SEqualReducer equal) const { return true; } - void SHashReduce(SHashReducer hash_reduce) const {} - static constexpr const char* _type_key = "relax.expr.Binding"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -555,51 +553,61 @@ class Binding : public ObjectRef { using ContainerType = BindingNode; }; -/*! \brief Symbolic shape match, binds the variable of the lhs with the rhs. */ -class MatchShape; -class MatchShapeNode : public BindingNode { +/*! + * \brief Runtime-match the value to the struct info. + * + * This operation does runtime check, populates the un-defined symbolic shape vars + * and vars in struct_info in first occurance, and insert equality assertions in + * other cases. + */ +class MatchCastNode : public BindingNode { public: + /*! \brief The input value to match cast. */ Expr value; - Array pattern; - Var var; + /*! \brief The struct info pattern to match to. */ + StructInfo struct_info; void VisitAttrs(AttrVisitor* v) { - v->Visit("value", &value); - v->Visit("pattern", &pattern); v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("struct_info", &struct_info); v->Visit("span", &span); } - bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const { + bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const { // NOTE: pattern can contain ShapeExpr which defines the vars - return equal(value, other->value) && equal.DefEqual(pattern, other->pattern) && - equal.DefEqual(var, other->var); + return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) && + equal(value, other->value); } void SHashReduce(SHashReducer hash_reduce) const { // NOTE: pattern can contain ShapeExpr which defines the vars - hash_reduce(value); - hash_reduce.DefHash(pattern); hash_reduce.DefHash(var); + hash_reduce.DefHash(struct_info); + hash_reduce(value); } - static constexpr const char* _type_key = "relax.expr.MatchShape"; + static constexpr const char* _type_key = "relax.expr.MatchCast"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; - TVM_DECLARE_FINAL_OBJECT_INFO(MatchShapeNode, BindingNode); + TVM_DECLARE_FINAL_OBJECT_INFO(MatchCastNode, BindingNode); }; -class MatchShape : public Binding { +/*! + * \brief Managed reference to MatchCastNode. + * \sa MatchCastNode + */ +class MatchCast : public Binding { public: - TVM_DLL explicit MatchShape(Expr value, Array pattern, Var var, Span span = Span()); - TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchShapeNode); + TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(MatchCast, Binding, MatchCastNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode); }; -class VarBinding; class VarBindingNode : public BindingNode { public: - Var var; + /*! \brief The binding value. */ Expr value; void VisitAttrs(AttrVisitor* v) { @@ -628,8 +636,6 @@ class VarBinding : public Binding { TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode); }; -class BindingBlock; - class BindingBlockNode : public Object { public: mutable Span span; diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 236233e6c7..9f4c2b58f8 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -29,9 +29,8 @@ #include #include #include -#include -#include -#include +#include +#include #include #include @@ -213,7 +212,7 @@ class ExprVisitor : public ExprFunctor { virtual void VisitBinding(const Binding& binding); // specific leaf level visitor functions virtual void VisitBinding_(const VarBindingNode* binding); - virtual void VisitBinding_(const MatchShapeNode* binding); + virtual void VisitBinding_(const MatchCastNode* binding); // second level dispatching based on binding value type. // these dispatching functions get called from first-level dispatch on VarBinding virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val); @@ -244,6 +243,23 @@ class ExprVisitor : public ExprFunctor { * \note VisitExpr_(const VarNode*) will only visit the usage site of an Var */ virtual void VisitVarDef(const Var& var); + + /*! + * \brief Visit struct_info may recursively contain Expr/PrimExpr. + * + * By default, this function recurse into struct info such as + * TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr + * accordingly. It does not recurse into FunctionStructInfo as it does + * not contain Expr defined in the current scope. + * + * Pass writers can overload this function to change to other behaviors. + * For example, if we are not interested in Expr in StructInfo, we can + * override this function by a no-op. + * + * \param struct_info Input struct info field. + */ + virtual void VisitExprDepStructInfoField(const StructInfo& struct_info); + // specific leaf level visitor functions virtual void VisitVarDef_(const VarNode* var); virtual void VisitVarDef_(const DataflowVarNode* var); @@ -258,6 +274,30 @@ class ExprVisitor : public ExprFunctor { tvm::NodeFunctor; // initialize the vtable. static VisitBindingVTable InitVisitBindingVTable(); + /*! + * \brief Private internal struct info field visitor. + * + * Support default visiting of struct info field and recursive into + * their Expr fields. + * + * We use component instead of sub-classing so there can be other + * joint inheritance between ExprVisitor and StructInfoVisitor. + */ + class DefaultStructInfoFieldVisitor : public StructInfoVisitor { + public: + explicit DefaultStructInfoFieldVisitor(ExprVisitor* parent); + + // Override defaults in struct info visitor. + void VisitStructInfoExprField(const Expr& expr) final; + void VisitStructInfoExprField(const PrimExpr& expr) final; + void VisitStructInfo_(const FuncStructInfoNode* op) final; + + private: + ExprVisitor* parent_; + }; + // This visitor is not visible to child classes and only + // used to supportd default visiting behavior. + DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this}; }; void PostOrderVisit(const Expr& node, std::function fvisit); @@ -309,6 +349,64 @@ class ExprMutatorBase : public ExprFunctor { * Can be overloaded to transform the shape expressions. */ virtual PrimExpr VisitPrimExpr(const PrimExpr& expr); + + /*! + * \brief Visit struct_info that may recursively contain Expr/PrimExpr. + * + * By default, this function recurse into struct info such as + * TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr + * accordingly. It does not recurse into FunctionStructInfo as it does + * not contain Expr defined in the current scope. + * + * Pass writers can overload this function to change to other behaviors. + * For example, if in Expr in StructInfo won't change, we can + * override this function by an identity function. + * + * \param struct_info Input struct info field. + * \return The updated struct info. + */ + virtual StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info); + + protected: + /*! + * \brief Check whether VisitExprDepStructInfoField change struct_info. + * \return Whether struct info changed. + * \note This function is used by mutator implementations to check if + * previous Expr update will trigger a change in struct_info. + * If change is detected, the implementation can generate a fresh + * node without struct_info, and trigger normalizer to re-derive. + */ + bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef& struct_info) { + if (const StructInfoNode* sinfo = struct_info.as()) { + return this->VisitExprDepStructInfoField(GetRef(sinfo)).same_as(struct_info); + } else { + return true; + } + } + + private: + /*! + * \brief Private internal struct info field visitor to support + * Default visiting of struct info field and recursive into their Expr fields. + * + * We use component instead of sub-classing so there can be other + * joint inheritance between ExprMutator and StructInfoMutator. + */ + class DefaultStructInfoFieldMutator : public StructInfoMutator { + public: + explicit DefaultStructInfoFieldMutator(ExprMutatorBase* parent); + + // Override defaults in struct info visitor. + Expr VisitStructInfoExprField(const Expr& expr) final; + PrimExpr VisitStructInfoExprField(const PrimExpr& expr) final; + StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final; + + private: + ExprMutatorBase* parent_; + }; + // This visitor is not visible to child classes and only + // used to supportd default visiting behavior. + DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this}; }; /*! @@ -324,7 +422,6 @@ class ExprMutator : public ExprMutatorBase { ExprMutator(Optional mod = NullOpt) { builder_ = BlockBuilder::Create(mod); } Expr VisitExpr(const Expr& expr) override; - Expr VisitExpr_(const TupleNode* op) override; Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const DataflowVarNode* op) override; Expr VisitExpr_(const FunctionNode* op) override; @@ -338,7 +435,7 @@ class ExprMutator : public ExprMutatorBase { virtual void VisitBinding(const Binding& binding); // specific leaf level visitor functions virtual void VisitBinding_(const VarBindingNode* binding); - virtual void VisitBinding_(const MatchShapeNode* binding); + virtual void VisitBinding_(const MatchCastNode* binding); // second level dispatching based on binding value type. // these dispatching functions get called from first-level dispatch on VarBinding virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val); @@ -484,9 +581,9 @@ class PyExprVisitorNode : public Object, public ExprVisitor { /*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` * function. */ PackedFunc f_visit_var_binding_{nullptr}; - /*! \brief The packed function to the `VisitBinding_(const MatchShapeNode* binding)` + /*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)` * function. */ - PackedFunc f_visit_match_shape_{nullptr}; + PackedFunc f_visit_match_cast_{nullptr}; /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` * function. */ PackedFunc f_visit_binding_block{nullptr}; @@ -523,8 +620,8 @@ class PyExprVisitorNode : public Object, public ExprVisitor { void VisitBinding_(const VarBindingNode* binding) PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_var_binding_, ExprVisitor::VisitBinding_(binding)); - void VisitBinding_(const MatchShapeNode* binding) - PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_match_shape_, + void VisitBinding_(const MatchCastNode* binding) + PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_match_cast_, ExprVisitor::VisitBinding_(binding)); void VisitBindingBlock(const BindingBlock& block) @@ -602,7 +699,7 @@ class PyExprVisitor : public ObjectRef { * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode* * binding)`. - * \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode* + * \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode* * binding)`. * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock& * block)`. @@ -624,7 +721,7 @@ class PyExprVisitor : public ObjectRef { PackedFunc f_visit_extern_func_, PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding, - PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_, + PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_, PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) { @@ -649,7 +746,7 @@ class PyExprVisitor : public ObjectRef { n->f_visit_op_ = f_visit_op_; n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; n->f_visit_var_binding_ = f_visit_var_binding_; - n->f_visit_match_shape_ = f_visit_match_shape_; + n->f_visit_match_cast_ = f_visit_match_cast_; n->f_visit_binding_block_ = f_visit_binding_block_; n->f_visit_dataflow_block_ = f_visit_dataflow_block_; n->f_visit_var_def_ = f_visit_var_def_; @@ -702,9 +799,9 @@ class PyExprMutatorNode : public Object, public ExprMutator { /*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` * function. */ PackedFunc f_visit_var_binding_{nullptr}; - /*! \brief The packed function to the `VisitBinding_(const MatchShapeNode* binding)` + /*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)` * function. */ - PackedFunc f_visit_match_shape_{nullptr}; + PackedFunc f_visit_match_cast_{nullptr}; /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` * function. */ PackedFunc f_visit_binding_block{nullptr}; @@ -748,9 +845,9 @@ class PyExprMutatorNode : public Object, public ExprMutator { ExprMutator::VisitBinding_(binding); } - void VisitBinding_(const MatchShapeNode* binding) { - if (f_visit_match_shape_ != nullptr) - f_visit_match_shape_(GetRef(binding)); + void VisitBinding_(const MatchCastNode* binding) { + if (f_visit_match_cast_ != nullptr) + f_visit_match_cast_(GetRef(binding)); else ExprMutator::VisitBinding_(binding); } @@ -866,7 +963,7 @@ class PyExprMutator : public ObjectRef { * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode* * binding)`. - * \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode* + * \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode* * binding)`. * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock& * block)`. @@ -889,7 +986,7 @@ class PyExprMutator : public ObjectRef { PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding, - PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_, + PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_, PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) { @@ -911,7 +1008,7 @@ class PyExprMutator : public ObjectRef { n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; n->f_visit_binding = f_visit_binding; n->f_visit_var_binding_ = f_visit_var_binding_; - n->f_visit_match_shape_ = f_visit_match_shape_; + n->f_visit_match_cast_ = f_visit_match_cast_; n->f_visit_binding_block = f_visit_binding_block; n->f_visit_binding_block_ = f_visit_binding_block_; n->f_visit_dataflow_block_ = f_visit_dataflow_block_; diff --git a/include/tvm/relax/ir_functor.h b/include/tvm/relax/ir_functor.h index 5615e00188..f162c5a28a 100644 --- a/include/tvm/relax/ir_functor.h +++ b/include/tvm/relax/ir_functor.h @@ -77,7 +77,7 @@ class IRFunctor { virtual R VisitNode_(const relax::VarNode* op, Args... args) IR_FUNCTOR_DEFAULT; virtual R VisitNode_(const relax::DataflowVarNode* op, Args... args) IR_FUNCTOR_DEFAULT; virtual R VisitNode_(const relax::ShapeExprNode* op, Args... args) IR_FUNCTOR_DEFAULT; - virtual R VisitNode_(const relax::MatchShapeNode* op, Args... args) IR_FUNCTOR_DEFAULT; + virtual R VisitNode_(const relax::MatchCastNode* op, Args... args) IR_FUNCTOR_DEFAULT; virtual R VisitNode_(const relax::VarBindingNode* op, Args... args) IR_FUNCTOR_DEFAULT; virtual R VisitNode_(const relax::BindingBlockNode* op, Args... args) IR_FUNCTOR_DEFAULT; virtual R VisitNode_(const relax::DataflowBlockNode* op, Args... args) IR_FUNCTOR_DEFAULT; @@ -103,7 +103,7 @@ class IRFunctor { RELAX_IR_FUNCTOR_DISPATCH(relax::VarNode); RELAX_IR_FUNCTOR_DISPATCH(relax::DataflowVarNode); RELAX_IR_FUNCTOR_DISPATCH(relax::ShapeExprNode); - RELAX_IR_FUNCTOR_DISPATCH(relax::MatchShapeNode); + RELAX_IR_FUNCTOR_DISPATCH(relax::MatchCastNode); RELAX_IR_FUNCTOR_DISPATCH(relax::VarBindingNode); RELAX_IR_FUNCTOR_DISPATCH(relax::BindingBlockNode); RELAX_IR_FUNCTOR_DISPATCH(relax::DataflowBlockNode); diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h index 432d4fd340..f9f4588d81 100644 --- a/include/tvm/script/ir_builder/relax/ir.h +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -111,15 +111,13 @@ TVM_DLL void DataflowBlockOutput(const Array& vars); TVM_DLL tvm::relax::Var Emit(const tvm::relax::Expr& value); /*! - * \brief Emit a match_shape binding to the last binding block frame. - * \param value The value of the MatchShape to be emitted. - * \param pattern The pattern of the MatchShape to be emitted. - * \param emit_var A boolean indicating if the MatchShape contains the emitted variable. - * \return The emitted var if `emit_var` is true. Otherwise, return `NullOpt`. - */ -TVM_DLL Optional EmitMatchShape(const tvm::relax::Expr& value, // - const Array& pattern, // - bool emit_var); + * \brief Emit a match_cast binding to the last binding block frame. + * \param value The value of the MatchCast to be emitted. + * \param struct_info The struct info of the MatchCast to be emitted. + * \return The left side var of the emitted binding. + */ +TVM_DLL tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, + const tvm::relax::StructInfo& struct_info); ///////////////////////////// Type Deduce ////////////////////////////// diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index d19bc70f04..609c7514e5 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -28,6 +28,7 @@ from . import struct_info # Expr + from .expr import ( Expr, Span, @@ -37,7 +38,7 @@ Var, DataflowVar, Binding, - MatchShape, + MatchCast, VarBinding, BindingBlock, DataflowBlock, diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 679370bff6..acc0e14ca5 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -29,12 +29,12 @@ Var, ShapeExpr, GlobalVar, - PrimExpr, BindingBlock, Tuple, BaseFunc, Binding, ) +from .struct_info import StructInfo from .op.base import call_tir from . import _ffi_api @@ -540,23 +540,23 @@ def rx_func(x: Tensor((n,), "float32"), y: Tensor(((n + 1),), "float32")) """ return self.emit(self.call_te(func, *args, **kwargs)) - def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var: - """Emit a MatchShape. + def match_cast(self, value: Expr, struct_info: StructInfo) -> Var: + """Emit a MatchCast. Parameters ---------- value : tvm.relax.Expr - The value of the MatchShape to be emitted. + The value of the MatchCast to be emitted. - pattern : List[PrimExpr] - The pattern of the MatchShape to be emitted. + struct_info : StructInfo + The struct info to be matched. Returns ------- ret : tvm.relax.Var - A newly created variable that gets bound to the call code. + A newly created variable that get bounds to be the casted result. """ - return _ffi_api.BlockBuilderEmitMatchShape(self, value, pattern) # type: ignore + return _ffi_api.BlockBuilderEmitMatchCast(self, value, struct_info) # type: ignore def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None: """Emit output for the current dataflow block or function. diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 17e36fca21..35f388dc6d 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -308,17 +308,35 @@ class Binding(Node): ... -@tvm._ffi.register_object("relax.expr.MatchShape") -class MatchShape(Binding): - """Symbolic shape match, binds the variable of the lhs with the rhs.""" +@tvm._ffi.register_object("relax.expr.MatchCast") +class MatchCast(Binding): + """Runtime-match the value to the struct info. + + This operation does runtime check, populates the un-defined symbolic shape vars + and vars in struct_info in the first occurrence, and insert equality assertions in + other cases. + + Parameters + ---------- + var: Var + The return variable that the match cast bind to. value: Expr - pattern: List[PrimExpr] + The input value expression. + + struct_info: tvm.relax.StructInfo + The struct info to match cast to. + """ + var: Var + struct_info: "tvm.relax.StructInfo" + value: Expr - def __init__(self, value: Expr, pattern: List[PrimExpr], var: Var, span: Span = None) -> None: + def __init__( + self, var: Var, value: Expr, struct_info: "tvm.relax.StructInfo", span: Span = None + ) -> None: self.__init_handle_by_constructor__( - _ffi_api.MatchShape, value, pattern, var, span # type: ignore + _ffi_api.MatchCast, var, value, struct_info, span # type: ignore ) diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py index 7370d2aa14..bd32231376 100644 --- a/python/tvm/relax/expr_functor.py +++ b/python/tvm/relax/expr_functor.py @@ -29,7 +29,7 @@ from .expr import ShapeExpr from .expr import GlobalVar, SeqExpr, Tuple from .expr import Call, If, TupleGetItem -from .expr import Binding, MatchShape, VarBinding +from .expr import Binding, MatchCast, VarBinding from .expr import BindingBlock, DataflowBlock from .struct_info import StructInfo from ..relay import Id @@ -190,7 +190,7 @@ def visit_tuple_getitem_(self, op: TupleGetItem): def visit_var_binding_(self, binding: VarBinding): raise NotImplementedError() - def visit_match_shape_(self, binding: MatchShape): + def visit_match_cast_(self, binding: MatchCast): raise NotImplementedError() def visit_binding_block_(self, block: BindingBlock): @@ -206,8 +206,8 @@ def visit_dataflow_var_def_(self, var: DataflowVar): raise NotImplementedError() def visit_binding(self, binding: Binding): - if isinstance(binding, MatchShape): - self.visit_match_shape_(binding) + if isinstance(binding, MatchCast): + self.visit_match_cast_(binding) elif isinstance(binding, VarBinding): self.visit_var_binding_(binding) else: @@ -259,7 +259,7 @@ def __init__( f_visit_tuple_getitem_: Callable = None, f_visit_binding: Callable = None, f_visit_var_binding_: Callable = None, - f_visit_match_shape_: Callable = None, + f_visit_match_cast_: Callable = None, f_visit_binding_block: Callable = None, f_visit_binding_block_: Callable = None, f_visit_dataflow_block_: Callable = None, @@ -289,7 +289,7 @@ def __init__( f_visit_tuple_getitem_, f_visit_binding, f_visit_var_binding_, - f_visit_match_shape_, + f_visit_match_cast_, f_visit_binding_block, f_visit_binding_block_, f_visit_dataflow_block_, @@ -378,7 +378,7 @@ def MyExprVisitor(PyExprVisitor): "visit_tuple_getitem_", "visit_binding", "visit_var_binding_", - "visit_match_shape_", + "visit_match_cast_", "visit_binding_block", "visit_binding_block_", "visit_dataflow_block_", @@ -623,15 +623,15 @@ def visit_var_binding_(self, binding: VarBinding) -> None: # Using self._outer() to ref _PyExprVisitor return _ffi_api.ExprVisitorVisitBinding(self._outer(), binding) # type: ignore - def visit_match_shape_(self, binding: MatchShape) -> None: - """Visit MatchShape. - Users can customized this function to overwrite VisitBinding_(const MatchShapeNode* binding) + def visit_match_cast_(self, binding: MatchCast) -> None: + """Visit MatchCast. + Users can customized this function to overwrite VisitBinding_(const MatchCastNode* binding) on the C++ side. Parameters ---------- - binding : MatchShape - The MatchShape to be visited. + binding : MatchCast + The MatchCast to be visited. """ # Using self._outer() to ref _PyExprVisitor return _ffi_api.ExprVisitorVisitBinding(self._outer(), binding) # type: ignore @@ -743,7 +743,7 @@ def __init__( f_visit_tuple_getitem_: Callable = None, f_visit_binding: Callable = None, f_visit_var_binding_: Callable = None, - f_visit_match_shape_: Callable = None, + f_visit_match_cast_: Callable = None, f_visit_binding_block: Callable = None, f_visit_binding_block_: Callable = None, f_visit_dataflow_block_: Callable = None, @@ -774,7 +774,7 @@ def __init__( f_visit_tuple_getitem_, f_visit_binding, f_visit_var_binding_, - f_visit_match_shape_, + f_visit_match_cast_, f_visit_binding_block, f_visit_binding_block_, f_visit_dataflow_block_, @@ -879,7 +879,7 @@ def MyExprMutator(PyExprMutator): "visit_tuple_getitem_", "visit_binding", "visit_var_binding_", - "visit_match_shape_", + "visit_match_cast_", "visit_binding_block", "visit_binding_block_", "visit_dataflow_block_", @@ -1208,15 +1208,15 @@ def visit_var_binding_(self, binding: VarBinding) -> None: # Using self._outer() to ref _PyExprMutator return _ffi_api.ExprMutatorVisitBinding(self._outer(), binding) # type: ignore - def visit_match_shape_(self, binding: MatchShape) -> None: - """Visit MatchShape. - Users can customized this function to overwrite VisitBinding_(const MatchShapeNode* binding) + def visit_match_cast_(self, binding: MatchCast) -> None: + """Visit MatchCast. + Users can customized this function to overwrite VisitBinding_(const MatchCastNode* binding) on the C++ side. Parameters ---------- - binding : MatchShape - The MatchShape to be visited. + binding : MatchCast + The MatchCast to be visited. """ # Using self._outer() to ref _PyExprMutator return _ffi_api.ExprMutatorVisitBinding(self._outer(), binding) # type: ignore diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index f272b3aa5a..959e6e5ef0 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -312,24 +312,22 @@ def visit_binding_(self, binding: relax.Binding) -> str: """ Distinguish between binding types """ - if isinstance(binding, relax.MatchShape): - return self.visit_match_shape_(binding) + if isinstance(binding, relax.MatchCast): + return self.visit_match_cast_(binding) if isinstance(binding, relax.VarBinding): return self.visit_var_binding_(binding) raise ValueError(f"Invalid binding type in {binding}: {type(binding)}") - def visit_match_shape_(self, match_shape: relax.MatchShape) -> str: + def visit_match_cast_(self, match_cast: relax.MatchCast) -> str: """ Handle match shape """ fields = { - "value": self.visit_expr(match_shape.value), - "pattern": self.build_list(map(self.visit_prim_expr_, match_shape.pattern)), + "var": self.visit_expr(match_cast.var), + "value": self.visit_expr(match_cast.value), + "struct_info": self.visit_struct_info_(match_cast.struct_info), } - # not always defined - if match_shape.var: - fields["var"] = self.visit_expr(match_shape.var) - return self.build_ast_node("MatchShape", **fields) + return self.build_ast_node("MatchCast", **fields) def visit_var_binding_(self, var_binding: relax.VarBinding) -> str: """ diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index afc7a5cf5c..ad218e1996 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -281,23 +281,21 @@ def emit(value: Expr) -> Var: return _ffi_api.Emit(value) # pylint: disable=no-member # type: ignore -def emit_match_shape(value: Expr, pattern: List[PrimExpr], emit_var: bool) -> Optional[Var]: - """Emit a match_shape binding to the last binding block frame. +def emit_match_cast(value: Expr, struct_info: StructInfo) -> Var: + """Emit a match_cast binding to the last binding block frame. Parameters ---------- value: Expr - The value of the MatchShape to be emitted. - pattern: List[PrimExpr] - The pattern of the MatchShape to be emitted. - emit_var: bool - A boolean indicating if the MatchShape contains the emitted variable. + The value of the MatchCast to be emitted. + struct_info: StructInfo + The struct_info of the MatchCast to be emitted. Returns ------- - var: Optional[Var] - The emitted var if `emit_var` is True. Otherwise, return `None`. + var: Var + The left side var of the emitted binding. """ - return _ffi_api.EmitMatchShape(value, pattern, emit_var) # type: ignore + return _ffi_api.EmitMatchCast(value, struct_info) # type: ignore ############################# Type Deduce ############################## @@ -407,7 +405,7 @@ def RewriteSymbolicShape( "const", "dataflow", "emit", - "emit_match_shape", + "emit_match_cast", "ewise_fma", "func_attr", "func_name", diff --git a/python/tvm/script/parser/relax/__init__.py b/python/tvm/script/parser/relax/__init__.py index 4e4f924035..53bc3b3626 100644 --- a/python/tvm/script/parser/relax/__init__.py +++ b/python/tvm/script/parser/relax/__init__.py @@ -18,6 +18,13 @@ from ...ir_builder.relax import * # pylint: disable=redefined-builtin from ...ir_builder.relax import ir as _relax from . import parser as _parser -from .entry import Callable, Shape, Tensor, Tuple, function, match_shape +from .entry import Callable, Shape, Tensor, Tuple, function, match_cast -__all__ = _relax.__all__ + ["Callable", "Shape", "Tensor", "Tuple", "function", "match_shape"] +__all__ = _relax.__all__ + [ + "Callable", + "Shape", + "Tensor", + "Tuple", + "function", + "match_cast", +] diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 58d857d0f8..1f8caeb694 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -177,19 +177,20 @@ def __getitem__(self, keys) -> Var: Shape = ShapeProxy() -############################ R.match_shape ############################# -class MatchShapePair: + +############################ R.match_cast ############################# +class MatchCastPair: value: Expr - pattern: List[PrimExpr] + struct_info: StructInfo - def __init__(self, value: Expr, pattern: List[PrimExpr]) -> None: + def __init__(self, value: Expr, struct_info: StructInfo) -> None: self.value = value - self.pattern = pattern + self.struct_info = struct_info -def match_shape(value: Expr, pattern: List[PrimExpr]): +def match_cast(value: Expr, struct_info: StructInfo): if value is None: - raise ValueError("value of match_shape cannot be None") - if pattern is None: - raise ValueError("pattern of match_shape cannot be None") - return MatchShapePair(value, pattern) + raise ValueError("value of match_cast cannot be None") + if struct_info is None: + raise ValueError("struct_info of match_cast cannot be None") + return MatchCastPair(value, struct_info) diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 51dfce26ef..3cb80acb86 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -29,7 +29,7 @@ from ...ir_builder import relax as R from ...ir_builder.base import IRBuilder from .._core import Parser, dispatch, doc -from .entry import MatchShapePair +from .entry import MatchCastPair def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: @@ -70,10 +70,8 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - assert var is not None IRBuilder.name(var_name, var) return var - elif isinstance(value, MatchShapePair): - var = R.emit_match_shape(value.value, value.pattern, emit_var=True) - # It's an internal check, so directly use assert here. - assert var is not None + elif isinstance(value, MatchCastPair): + var = R.emit_match_cast(value.value, value.struct_info) IRBuilder.name(var_name, var) return var else: @@ -150,9 +148,7 @@ def post_token_switch(self: Parser, node: doc.Expr) -> None: @dispatch.register(token="relax", type_name="Expr") def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: value = self.eval_expr(node.value) - if isinstance(value, MatchShapePair): - R.emit_match_shape(value.value, value.pattern, emit_var=False) - elif value is not None: + if value is not None: self.report_error(node, f"Unsupported Expr stmt type {value}.") diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc index 6bdaa52810..1fac9f84ff 100644 --- a/src/relax/analysis/analysis.cc +++ b/src/relax/analysis/analysis.cc @@ -134,10 +134,8 @@ class VarVisitor : protected ExprVisitor { VisitVarDef(binding->var); } - void VisitBinding_(const MatchShapeNode* binding) final { - if (binding->var.defined()) { - MarkBounded(binding->var); - } + void VisitBinding_(const MatchCastNode* binding) final { + MarkBounded(binding->var); ExprVisitor::VisitBinding_(binding); } diff --git a/src/relax/analysis/var2value.cc b/src/relax/analysis/var2value.cc index 680a2a7261..d034afeb21 100644 --- a/src/relax/analysis/var2value.cc +++ b/src/relax/analysis/var2value.cc @@ -70,9 +70,9 @@ class Name2BindingAnalysis : public relax::ExprVisitor { name2bindings_[vname].push_back(GetRef(binding)); } - void VisitBinding_(const MatchShapeNode* binding) override { + void VisitBinding_(const MatchCastNode* binding) override { const auto& vname = binding->var->name_hint(); - name2bindings_[vname].push_back(GetRef(binding)); + name2bindings_[vname].push_back(GetRef(binding)); } }; diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 67f16f17f1..5859ef4bed 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -298,22 +298,13 @@ class WellFormedChecker : public relax::ExprVisitor, this->VisitVarDef(binding->var); } - void VisitBinding_(const MatchShapeNode* binding) final { + void VisitBinding_(const MatchCastNode* binding) final { this->VisitExpr(binding->value); // define the vars - WithMode(VisitMode::kMatchVarDef, [&]() { - for (PrimExpr expr : binding->pattern) { - this->VisitStructInfoExprField(expr); - } - }); + WithMode(VisitMode::kMatchVarDef, [&]() { this->VisitStructInfo(binding->struct_info); }); - for (PrimExpr expr : binding->pattern) { - this->VisitStructInfoExprField(expr); - } - - if (binding->var.defined()) { - this->VisitVarDef(binding->var); - } + this->VisitStructInfo(binding->struct_info); + this->VisitVarDef(binding->var); } void VisitBindingBlock_(const DataflowBlockNode* block) final { diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index 10a769d1b6..809156bfeb 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -252,8 +252,8 @@ class JSONSerializer return VisitExpr(binding->value); } - std::vector VisitBinding_(const MatchShapeNode* binding) { - LOG(FATAL) << "JSON runtime currently doesn't shape expr\n"; + std::vector VisitBinding_(const MatchCastNode* binding) { + LOG(FATAL) << "JSON runtime currently doesn't match cast\n"; return {}; } @@ -262,7 +262,7 @@ class JSONSerializer if (const auto* node = binding.as()) { auto from_b = VisitBinding_(node); nodes.insert(nodes.end(), from_b.begin(), from_b.end()); - } else if (const auto* node = binding.as()) { + } else if (const auto* node = binding.as()) { auto from_b = VisitBinding_(node); nodes.insert(nodes.end(), from_b.begin(), from_b.end()); } else { diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 29a6f38990..703cc47280 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -59,6 +59,12 @@ class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { } } + void VisitBinding_(const MatchCastNode* op) final { + // Visit the match cast struct info so we can define + // the symbolic variables here. + this->VisitStructInfo(op->struct_info); + } + void VisitExpr_(const FunctionNode* op) final { // Do not recurse into function node as it is self-contained } @@ -99,12 +105,21 @@ class VMShapeLowerMutator : public ExprMutator { return builder_->GetContextIRModule(); } - void VisitBinding_(const MatchShapeNode* binding) override { + void VisitBinding_(const MatchCastNode* binding) override { + // TODO(@tqchen): match_cast support for general struct info Expr value = ExprMutator::VisitExpr(binding->value); - - // TODO(@yuchen): match_shape overloaded semantic: value is ShapeType - Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {value}), "sh"); - StoreShape(shape, binding->pattern); + auto* tinfo = binding->struct_info.as(); + ICHECK(tinfo != nullptr) << "Match cast only support TensorStructInfo for now"; + auto* shape_expr = tinfo->shape.as(); + + if (shape_expr) { + bool dyn_shape = std::any_of(shape_expr->values.begin(), shape_expr->values.end(), + [](const PrimExpr& e) { return !e->IsInstance(); }); + if (dyn_shape) { + Var shape = builder_->Emit(Call(ExternFunc("vm.builtin.shape_of"), {value}), "sh"); + StoreShape(shape, shape_expr->values); + } + } } using ExprMutator::VisitExpr_; diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index a6d6b27d81..d1afdd74cc 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -129,13 +129,12 @@ void DataflowBlockRewriteNode::Add(Binding binding) { auto p = [binding] { if (auto vb = binding.as()) { return std::make_pair(vb->var, vb->value); - } else if (auto ms = binding.as()) { - return std::make_pair(ms->var, ms->value); + } else if (auto mc = binding.as()) { + return std::make_pair(mc->var, mc->value); } LOG(FATAL) << "Unsupported binding type"; return std::make_pair(Var{}, Expr{}); }(); - Var var = p.first; Expr val = p.second; @@ -156,11 +155,7 @@ void DataflowBlockRewriteNode::Add(Binding binding) { size_t line_last_req_def = 0; for (size_t i = 0; i < dfb_.value()->bindings.size(); ++i) { auto line = dfb_.value()->bindings[i]; - if (auto varbind = line.as()) { - if (used_vars.find(varbind->var.get()) != used_vars.cend()) line_last_req_def = i; - } else if (auto mshape = line.as()) { - if (used_vars.find(mshape->var.get()) != used_vars.cend()) line_last_req_def = i; - } + if (used_vars.find(line->var.get()) != used_vars.cend()) line_last_req_def = i; } auto old_dfb = dfb_.value(); @@ -240,12 +235,8 @@ class RemoveUnusedVars : public ExprMutator { auto prev_dfb = GetRef(block); builder_->BeginDataflowBlock(); for (Binding binding : block->bindings) { - if (const auto* node = binding.as()) { - if (!unused_vars.count(node->var)) VisitBinding_(node); - } else if (const auto* node = binding.as()) { - if (!unused_vars.count(node->var)) VisitBinding_(node); - } else { - LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + if (!unused_vars.count(binding->var)) { + VisitBinding(binding); } } auto new_dfb = builder_->EndBlock(); diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index bea13438ff..08c427be65 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -194,24 +194,21 @@ class BlockBuilderImpl : public BlockBuilderNode { return this->Emit(expr, CurrentBlockFrame()->is_dataflow, name_hint); } - Var EmitMatchShape(Expr value, Array pattern, String name_hint) final { + Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint) final { value = this->Normalize(value); + CHECK(StructInfoBaseCheck(GetStructInfo(value), struct_info) != BaseCheckResult::kFailL0) + << "It is impossible to match cast any value into the target struct_info. " + "But got value struct info: " + << GetStructInfo(value) << ", given struct info: " << struct_info; + + // NOTE: do match cast checking later in a pass. BlockFrame* cur_frame = CurrentBlockFrame(); Var var = CreateVar(cur_frame->is_dataflow, name_hint); + UpdateStructInfo(var, struct_info); - if (value->struct_info_.as()) { - UpdateStructInfo(var, ShapeStructInfo(pattern.size())); - } else if (const auto* tensor_sinfo = value->struct_info_.as()) { - UpdateStructInfo(var, TensorStructInfo(ShapeExpr(pattern), tensor_sinfo->dtype)); - } else { - this->ReportFatal( - Diagnostic::Error(value) - << "The value passed to EmitMatchShape must be of TensorStructInfo or ShapeStructInfo."); - } - - MatchShape match_shape = MatchShape(value, pattern, var); - cur_frame->bindings.push_back(match_shape); + MatchCast match_cast(var, value, struct_info); + cur_frame->bindings.push_back(match_cast); // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. return var; @@ -228,31 +225,29 @@ class BlockBuilderImpl : public BlockBuilderNode { void EmitNormalized(Binding binding) final { BlockFrame* cur_frame = CurrentBlockFrame(); - if (auto* var_binding = binding.as()) { + if (const auto* var_binding = binding.as()) { if (!cur_frame->is_dataflow) { ICHECK(!var_binding->var.as()) - << "Cannot emit dataflowvar in non-dataflow block"; + << "Cannot emit dataflow var in non-dataflow block"; } // normalized check ICHECK(var_binding->var->struct_info_.defined()); ICHECK(var_binding->value->struct_info_.defined()); cur_frame->bindings.push_back(binding); binding_table_[var_binding->var->vid] = var_binding->value; - } else { - auto* match_shape = binding.as(); - ICHECK(match_shape); - if (match_shape->var.defined()) { - if (!cur_frame->is_dataflow) { - ICHECK(!match_shape->var.as()) - << "Cannot emit dataflowvar in non-dataflow block"; - } - ICHECK(match_shape->var->struct_info_.defined()); + } else if (const auto* match_cast = binding.as()) { + if (!cur_frame->is_dataflow) { + ICHECK(!match_cast->var.as()) + << "Cannot emit dataflow var in non-dataflow block"; } // normalized check - ICHECK(match_shape->value->struct_info_.defined()); + ICHECK(match_cast->var->struct_info_.defined()); + ICHECK(match_cast->value->struct_info_.defined()); // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. cur_frame->bindings.push_back(binding); + } else { + LOG(FATAL) << "Unsupported binding type: " << binding->GetTypeKey(); } } @@ -682,11 +677,12 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor()) { - return this->VisitVarBinding(Downcast(binding)); + if (auto* var_binding = binding.as()) { + return this->VisitVarBinding(GetRef(var_binding)); } else { - ICHECK(binding.as()) << "expected VarBinding or MatchShape, got " << binding; - return this->VisitMatchShape(Downcast(binding)); + auto* match_cast = binding.as(); + ICHECK(match_cast) << "Unsupported binding type: " << binding->GetTypeKey(); + return this->VisitMatchCast(GetRef(match_cast)); } } @@ -701,13 +697,13 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorVisitExpr(binding->value); if (!new_value.same_as(binding->value)) { - binding = MatchShape(new_value, binding->pattern, binding->var, binding->span); + binding = MatchCast(binding->var, new_value, binding->struct_info, binding->span); } - if (binding->var.defined() && !binding->var->struct_info_.defined()) { - UpdateStructInfo(binding->var, GetStructInfo(new_value)); + if (!binding->var->struct_info_.defined()) { + UpdateStructInfo(binding->var, binding->struct_info); } return binding; } @@ -808,9 +804,14 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorIsInstance(); Array current; for (const Binding& binding : block->bindings) { - auto match_shape = binding.as(); - auto var_binding = binding.as(); - const Expr& value = match_shape ? match_shape->value : var_binding->value; + Expr value; + if (const auto* var_binding = binding.as()) { + value = var_binding->value; + } else if (const auto* match_cast = binding.as()) { + value = match_cast->value; + } else { + LOG(FATAL) << "Unknown binding type: " << binding->GetTypeKey(); + } // if we encounter a nested seq, we have to flatten it: // 1. Append the binding block we've accumulated so far // 2. Reset the current block @@ -831,10 +832,14 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor(MatchShape(seq->body, match_shape->pattern, match_shape->var)) - : Downcast(VarBinding(var_binding->var, seq->body))); + + if (const auto* var_binding = binding.as()) { + current.push_back(VarBinding(var_binding->var, seq->body)); + } else if (const auto* match_cast = binding.as()) { + current.push_back(MatchCast(match_cast->var, seq->body, match_cast->struct_info)); + } else { + LOG(FATAL) << "Unknown binding type: " << binding->GetTypeKey(); + } } else { current.push_back(binding); } @@ -916,9 +921,9 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit").set_body_typed([](BlockBuilder bui return builder->Emit(expr); }); -TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchShape") - .set_body_typed([](BlockBuilder builder, Expr value, Array pattern) { - return builder->EmitMatchShape(value, pattern); +TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast") + .set_body_typed([](BlockBuilder builder, Expr value, StructInfo struct_info) { + return builder->EmitMatchCast(value, struct_info); }); TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput") diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index 48828689ad..3b4303a5c7 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -63,7 +63,7 @@ te::Tensor TETensor(Expr value, std::string name) { auto* shape_expr = tensor_sinfo->shape.as(); CHECK(shape_expr) << "ValueError: Expression does not have an known symbolic shape, please consider use " - "match_shape " + "match_cast " << "to constrain the shape before passing into te_tensor"; n->shape = shape_expr->values; n->dtype = tensor_sinfo->dtype; diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 0f0ce72211..fb0e452062 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -322,22 +322,21 @@ TVM_REGISTER_GLOBAL("relax.Constant").set_body_typed([](runtime::NDArray data, S return Constant(data, span); }); -TVM_REGISTER_NODE_TYPE(BindingNode); +TVM_REGISTER_NODE_TYPE(MatchCastNode); -TVM_REGISTER_NODE_TYPE(MatchShapeNode); - -MatchShape::MatchShape(Expr value, Array pattern, Var var, Span span) { - ObjectPtr n = make_object(); - n->value = std::move(value); - n->pattern = std::move(pattern); +MatchCast::MatchCast(Var var, Expr value, StructInfo struct_info, Span span) { + ObjectPtr n = make_object(); + ICHECK(var.defined()) << "MatchCast requires var to be defined"; n->var = std::move(var); + n->value = std::move(value); + n->struct_info = std::move(struct_info); n->span = span; data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.MatchShape") - .set_body_typed([](Expr value, Array pattern, Var var, Span span) { - return MatchShape(value, pattern, var, span); +TVM_REGISTER_GLOBAL("relax.MatchCast") + .set_body_typed([](Var var, Expr value, StructInfo struct_info, Span span) { + return MatchCast(var, value, struct_info, span); }); TVM_REGISTER_NODE_TYPE(VarBindingNode); diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index d2cbe85623..a7bedc2ce2 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -83,24 +83,65 @@ namespace relax { // ================== // ExprVisitor +void ExprVisitor::VisitExprDepStructInfoField(const StructInfo& struct_info) { + // recurse into struct info in case they depend on value + // under the current scope. + default_struct_info_field_visitor_.VisitStructInfo(struct_info); +} + +ExprVisitor::DefaultStructInfoFieldVisitor::DefaultStructInfoFieldVisitor(ExprVisitor* parent) + : parent_(parent) {} + +void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfoExprField(const Expr& expr) { + parent_->VisitExpr(expr); +} + +void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfoExprField(const PrimExpr& expr) { + parent_->VisitPrimExpr(expr); +} + +void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfo_(const FuncStructInfoNode* op) { + // Do not recurse into function struct info + // as they won't contain ref to values in current scope. +} + void ExprVisitor::VisitExpr(const Expr& expr) { ExprFunctor::VisitExpr(expr); } -void ExprVisitor::VisitExpr_(const ConstantNode* op) { this->VisitSpan(op->span); } +void ExprVisitor::VisitExpr_(const ConstantNode* op) { + this->VisitSpan(op->span); + // Constant's StructInfo does not depend on Expr. +} -void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); } +void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { + this->VisitSpan(op->span); + // FuncStructInfo is not value-dep +} void ExprVisitor::VisitExpr_(const TupleNode* op) { this->VisitSpan(op->span); for (Expr field : op->fields) { this->VisitExpr(field); } + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } } // Visit the use-site of a defined Var -void ExprVisitor::VisitExpr_(const VarNode* op) { this->VisitSpan(op->span); } +void ExprVisitor::VisitExpr_(const VarNode* op) { + this->VisitSpan(op->span); + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} // Visit the use-site of a defined DataflowVar -void ExprVisitor::VisitExpr_(const DataflowVarNode* op) { this->VisitSpan(op->span); } +void ExprVisitor::VisitExpr_(const DataflowVarNode* op) { + this->VisitSpan(op->span); + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } +} void ExprVisitor::VisitExpr_(const FunctionNode* op) { this->VisitSpan(op->span); @@ -109,6 +150,7 @@ void ExprVisitor::VisitExpr_(const FunctionNode* op) { } this->VisitExpr(op->body); + // FuncStructInfo does not depend on Expr. } void ExprVisitor::VisitExpr_(const CallNode* op) { @@ -122,6 +164,10 @@ void ExprVisitor::VisitExpr_(const CallNode* op) { for (Expr arg : op->args) { this->VisitExpr(arg); } + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } } void ExprVisitor::VisitExpr_(const IfNode* op) { @@ -129,6 +175,10 @@ void ExprVisitor::VisitExpr_(const IfNode* op) { this->VisitExpr(op->cond); this->VisitExpr(op->true_branch); this->VisitExpr(op->false_branch); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } } void ExprVisitor::VisitExpr_(const OpNode* op) {} @@ -136,6 +186,10 @@ void ExprVisitor::VisitExpr_(const OpNode* op) {} void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitSpan(op->span); this->VisitExpr(op->tuple); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } } void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { @@ -143,9 +197,16 @@ void ExprVisitor::VisitExpr_(const ShapeExprNode* op) { this->VisitPrimExpr(val); } this->VisitSpan(op->span); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } } -void ExprVisitor::VisitExpr_(const ExternFuncNode* op) { this->VisitSpan(op->span); } +void ExprVisitor::VisitExpr_(const ExternFuncNode* op) { + this->VisitSpan(op->span); + // FuncStructInfo does not depend on Expr. +} void ExprVisitor::VisitExpr_(const SeqExprNode* op) { this->VisitSpan(op->span); @@ -153,6 +214,10 @@ void ExprVisitor::VisitExpr_(const SeqExprNode* op) { this->VisitBindingBlock(block); } this->VisitExpr(op->body); + + if (auto* sinfo = op->struct_info_.as()) { + this->VisitExprDepStructInfoField(GetRef(sinfo)); + } } void ExprVisitor::VisitType(const Type& t) {} @@ -177,14 +242,9 @@ RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(IfNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(OpNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(TupleGetItemNode); -void ExprVisitor::VisitBinding_(const MatchShapeNode* binding) { +void ExprVisitor::VisitBinding_(const MatchCastNode* binding) { this->VisitExpr(binding->value); - // TODO(ziheng): should we change pattern from - // Array to ShapeExpr? - this->VisitExpr(ShapeExpr(binding->pattern)); - if (binding->var.defined()) { - this->VisitVarDef(binding->var); - } + this->VisitVarDef(binding->var); } void ExprVisitor::VisitBindingBlock_(const BindingBlockNode* block) { @@ -206,7 +266,7 @@ void ExprVisitor::VisitVarDef_(const VarNode* var) { this->VisitSpan(var->span); void ExprVisitor::VisitBinding(const Binding& binding) { if (const auto* node = binding.as()) { VisitBinding_(node); - } else if (const auto* node = binding.as()) { + } else if (const auto* node = binding.as()) { VisitBinding_(node); } else { LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); @@ -257,11 +317,43 @@ TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit").set_body_typed([](Expr ex // ================== // ExprMutatorBase +StructInfo ExprMutatorBase::VisitExprDepStructInfoField(const StructInfo& struct_info) { + // recurse into struct info in case they depend on value + // under the current scope. + return default_struct_info_field_mutator_.VisitStructInfo(struct_info); +} + +ExprMutatorBase::DefaultStructInfoFieldMutator::DefaultStructInfoFieldMutator( + ExprMutatorBase* parent) + : parent_(parent) {} + +Expr ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfoExprField(const Expr& expr) { + return parent_->VisitExpr(expr); +} + +PrimExpr ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfoExprField( + const PrimExpr& expr) { + return parent_->VisitPrimExpr(expr); +} + +StructInfo ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfo_( + const FuncStructInfoNode* op) { + // Do not recurse into function struct info + // as they won't contain ref to values in current scope. + return GetRef(op); +} + Expr ExprMutatorBase::VisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); } -Expr ExprMutatorBase::VisitExpr_(const ConstantNode* op) { return GetRef(op); } +Expr ExprMutatorBase::VisitExpr_(const ConstantNode* op) { + // Constant' struct info won't be affected by Expr/PrimExpr change. + return GetRef(op); +} -Expr ExprMutatorBase::VisitExpr_(const GlobalVarNode* op) { return GetRef(op); } +Expr ExprMutatorBase::VisitExpr_(const GlobalVarNode* op) { + // FuncStructInfo won't be affected by Expr/PrimExpr change. + return GetRef(op); +} Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { bool unchanged = true; @@ -273,20 +365,33 @@ Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { } if (unchanged) { + // If tuple's struct info change it means that + // one of its fields' struct info will change + // so un-changed already implies that struct info won't change return GetRef(op); } else { - Expr new_tuple = Tuple(fields, op->span); - return new_tuple; + // when there is a change return a new tuple node + return Tuple(fields, op->span); } } // Visit the use-site of a defined Var -Expr ExprMutatorBase::VisitExpr_(const VarNode* op) { return GetRef(op); } +Expr ExprMutatorBase::VisitExpr_(const VarNode* op) { + // struct info of var-use should remain stable + // or the var itself will get replaced + return GetRef(op); +} // Visit the use-site of a defined DataflowVar -Expr ExprMutatorBase::VisitExpr_(const DataflowVarNode* op) { return GetRef(op); } +Expr ExprMutatorBase::VisitExpr_(const DataflowVarNode* op) { + // struct info of var-use should remain stable + // or the var itself will get replaced + return GetRef(op); +} Expr ExprMutatorBase::VisitExpr_(const FunctionNode* op) { + // struct info of function is not value dependent + // so no need to check struct_info field Expr body = this->VisitExpr(op->body); if (body.same_as(op->body)) { @@ -314,11 +419,10 @@ Expr ExprMutatorBase::VisitExpr_(const CallNode* call_node) { unchanged &= new_arg.same_as(arg); } - if (unchanged) { + if (unchanged && VisitAndCheckStructInfoFieldUnchanged(call_node->struct_info_)) { return GetRef(call_node); } else { - Expr new_call = Call(new_op, call_args, call_node->attrs, ty_args, call_node->span); - return new_call; + return Call(new_op, call_args, call_node->attrs, ty_args, call_node->span); } } @@ -327,7 +431,8 @@ Expr ExprMutatorBase::VisitExpr_(const IfNode* op) { Expr true_b = this->VisitExpr(op->true_branch); Expr false_b = this->VisitExpr(op->false_branch); if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && - op->false_branch.same_as(false_b)) { + op->false_branch.same_as(false_b) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { return GetRef(op); } else { return If(guard, true_b, false_b, op->span); @@ -339,6 +444,8 @@ Expr ExprMutatorBase::VisitExpr_(const OpNode* op) { return GetRef(op); } Expr ExprMutatorBase::VisitExpr_(const TupleGetItemNode* op) { auto t = this->VisitExpr(op->tuple); if (op->tuple.same_as(t)) { + // struct info can be deterministically derived by tuple and index + // if t does not change, then struct info won't change. return GetRef(op); } else { return TupleGetItem(t, op->index, op->span); @@ -349,13 +456,17 @@ Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { auto values = op->values.Map([this](const PrimExpr& e) { return this->VisitPrimExpr(e); }); if (values.same_as(op->values)) { + // If values does not change, struct info won't change. return GetRef(op); } else { return ShapeExpr(values, op->span); } } -Expr ExprMutatorBase::VisitExpr_(const ExternFuncNode* op) { return GetRef(op); } +Expr ExprMutatorBase::VisitExpr_(const ExternFuncNode* op) { + // StructInfo of function remains value independent. + return GetRef(op); +} Expr ExprMutatorBase::VisitExpr_(const SeqExprNode* op) { bool all_blocks_unchanged = true; @@ -370,11 +481,11 @@ Expr ExprMutatorBase::VisitExpr_(const SeqExprNode* op) { Expr body = this->VisitExpr(op->body); - if (all_blocks_unchanged && body.same_as(op->body)) { + if (all_blocks_unchanged && body.same_as(op->body) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { return GetRef(op); - } else { - return SeqExpr(blocks, body); } + return SeqExpr(blocks, body); } BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { @@ -384,10 +495,9 @@ BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { if (auto var_binding = binding.as()) { Expr new_value = this->VisitExpr(var_binding->value); bindings.push_back(VarBinding(var_binding->var, new_value)); - } else if (auto match_shape_binding = binding.as()) { - Expr new_value = this->VisitExpr(match_shape_binding->value); - bindings.push_back( - MatchShape(new_value, match_shape_binding->pattern, match_shape_binding->var)); + } else if (auto match_cast = binding.as()) { + Expr new_value = this->VisitExpr(match_cast->value); + bindings.push_back(MatchCast(match_cast->var, new_value, match_cast->struct_info)); } else { LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); } @@ -414,23 +524,6 @@ Expr ExprMutator::VisitExpr(const Expr& expr) { return builder_->Normalize(ExprFunctor::VisitExpr(expr)); } -Expr ExprMutator::VisitExpr_(const TupleNode* op) { - bool unchanged = true; - tvm::Array fields; - for (Expr field : op->fields) { - Expr new_field = this->VisitExpr(field); - fields.push_back(new_field); - unchanged &= new_field.same_as(field); - } - - if (unchanged) { - return GetRef(op); - } else { - Expr new_tuple = Tuple(fields, op->span); - return new_tuple; - } -} - // Visit the use-site of a defined Var Expr ExprMutator::VisitExpr_(const VarNode* op) { auto it = var_remap_.find(op->vid); @@ -464,6 +557,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { Expr body = this->VisitWithNewScope(op->body, params); + // FuncStructInfo does not depend on Expr if (all_params_unchanged && body.same_as(op->body)) { return GetRef(op); } else { @@ -476,7 +570,8 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) { Expr true_b = this->VisitWithNewScope(op->true_branch); Expr false_b = this->VisitWithNewScope(op->false_branch); if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && - op->false_branch.same_as(false_b)) { + op->false_branch.same_as(false_b) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { return GetRef(op); } else { return If(guard, true_b, false_b, op->span); @@ -502,7 +597,8 @@ Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { all_blocks_unchanged = false; } - if (all_blocks_unchanged && body.same_as(op->body)) { + if (all_blocks_unchanged && body.same_as(op->body) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { return GetRef(op); } else { return SeqExpr(blocks, body); @@ -542,36 +638,17 @@ void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { builder_->EmitNormalized(VarBinding(new_var, new_value)); } -void ExprMutator::VisitBinding_(const MatchShapeNode* binding) { +void ExprMutator::VisitBinding_(const MatchCastNode* binding) { + Var new_var = this->VisitVarDef(binding->var); Expr new_value = this->VisitExpr(binding->value); - Expr new_pattern = this->VisitExpr(ShapeExpr(binding->pattern)); - - Var new_var; - if (binding->var.defined()) { - StructInfo new_sinfo = GetStructInfo(new_value); - - if (auto* ptr = new_sinfo.as()) { - new_sinfo = TensorStructInfo(new_pattern, ptr->dtype); - } - new_var = this->VisitVarDef(binding->var); - Var temp = WithStructInfo(new_var, new_sinfo); - if (!temp.same_as(new_var)) { - new_var = temp; - this->var_remap_[binding->var->vid] = new_var; - } - } - - // reemit old binding if nothing changes - if (new_value.same_as(binding->value) && new_pattern.same_as(binding->pattern)) { - if (!binding->var.defined() || (binding->var.defined() && new_var.same_as(binding->var))) { - builder_->EmitNormalized(GetRef(binding)); - return; - } + // re-emit old binding if nothing changes + if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + new_value = builder_->NormalizeArgument(new_value); + builder_->EmitNormalized(MatchCast(new_var, new_value, binding->struct_info, binding->span)); } - - builder_->EmitNormalized( - MatchShape(new_value, Downcast(new_pattern)->values, new_var)); } BindingBlock ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) { @@ -591,19 +668,35 @@ BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) { } Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) { - // If an Expr have struct info, they must already be normalized, - // This invariant is checked at the constructor location. - // to simplify our overall development complexity and keep var def - // stable by default. - return GetRef(var); + if (auto* sinfo = var->struct_info_.as()) { + StructInfo struct_info = this->VisitExprDepStructInfoField(GetRef(sinfo)); + if (struct_info.same_as(var->struct_info_)) { + return GetRef(var); + } else { + return DataflowVar(var->vid, struct_info, var->span); + } + } else { + return GetRef(var); + } } -Var ExprMutator::VisitVarDef_(const VarNode* var) { return GetRef(var); } +Var ExprMutator::VisitVarDef_(const VarNode* var) { + if (auto* sinfo = var->struct_info_.as()) { + StructInfo struct_info = this->VisitExprDepStructInfoField(GetRef(sinfo)); + if (struct_info.same_as(var->struct_info_)) { + return GetRef(var); + } else { + return Var(var->vid, struct_info, var->span); + } + } else { + return GetRef(var); + } +} void ExprMutator::VisitBinding(const Binding& binding) { if (const auto* node = binding.as()) { VisitBinding_(node); - } else if (const auto* node = binding.as()) { + } else if (const auto* node = binding.as()) { VisitBinding_(node); } else { LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index 5b2c01e54c..e984f58bbe 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -245,18 +246,12 @@ class DataflowBlockMutator : public ExprMutator { Map global_scope_vars; Map symbolic_vars; for (const Binding& binding : n->bindings) { - Var var; - if (const auto* node = binding.as()) { - var = node->var; - } else if (const auto* node = binding.as()) { - var = node->var; - for (PrimExpr expr : node->pattern) { - if (const tir::VarNode* sym_var = expr.as()) { - symbolic_vars.Set(sym_var->name_hint, Downcast(expr)); - } + Var var = binding->var; + if (const auto* match_cast = binding.as()) { + auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info); + for (const tir::VarNode* var : collected_vars) { + symbolic_vars.Set(var->name_hint, GetRef(var)); } - } else { - LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); } if (!var.as()) { global_scope_vars.Set(var->name_hint(), var); @@ -269,23 +264,17 @@ class DataflowBlockMutator : public ExprMutator { // raise error if there are updates of recorded Global Scope Vars and Symbolic Vars for (const Binding& binding : updated_block->bindings) { - Var var; - if (const auto* node = binding.as()) { - var = node->var; - } else if (const auto* node = binding.as()) { - var = node->var; - for (PrimExpr expr : node->pattern) { - if (const tir::VarNode* sym_var = expr.as()) { - if (symbolic_vars.count(sym_var->name_hint) > 0) { - tir::Var old_var = symbolic_vars[sym_var->name_hint]; - ICHECK(expr.same_as(old_var)) - << "Error: DataflowBlock Pass should not rewrite any Symbolic Var."; - symbolic_vars.erase(sym_var->name_hint); - } + Var var = binding->var; + if (const auto* match_cast = binding.as()) { + auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info); + for (const tir::VarNode* var : collected_vars) { + if (symbolic_vars.count(var->name_hint) > 0) { + tir::Var old_var = symbolic_vars[var->name_hint]; + ICHECK(var == old_var.get()) + << "Error: DataflowBlock Pass should not rewrite any Symbolic Var."; + symbolic_vars.erase(var->name_hint); } } - } else { - LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); } if (!var.as() && global_scope_vars.count(var->name_hint()) > 0) { ICHECK(var.same_as(global_scope_vars[var->name_hint()])) @@ -300,6 +289,25 @@ class DataflowBlockMutator : public ExprMutator { } private: + class SymbolicVarCollector : public StructInfoVisitor { + public: + static std::unordered_set Collect(const StructInfo& info) { + SymbolicVarCollector collector; + collector.VisitStructInfo(info); + return std::move(collector.symbolic_vars_); + } + + private: + void VisitStructInfoExprField(const PrimExpr& expr) final { + if (const tir::VarNode* sym_var = expr.as()) { + symbolic_vars_.insert(sym_var); + } + } + + private: + std::unordered_set symbolic_vars_; + }; + runtime::TypedPackedFunc pass_func_; IRModule mod_; PassContext pass_ctx_; diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index ac0168846d..c5c02434cd 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -20,10 +20,10 @@ #include #include #include +#include #include #include #include -#include #include #include diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index b0af3d7d4f..f12fb79fdb 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -68,34 +68,19 @@ class BindingCanonicalizer : public ExprMutator { this->builder_->EmitNormalized(VarBinding(new_var, new_value)); } - void VisitBinding_(const MatchShapeNode* binding) override { + void VisitBinding_(const MatchCastNode* binding) override { // If we have a trivial shape check (the shape_ of LHS and RHS is the same), // we can canonicalize to a var binding Expr new_value = this->VisitExpr(binding->value); - Var new_var; - // since we do not permit the checked_type to change and don't make any changes - // to the shape pattern, there is no reason to do any more checking like in the - // original mutator - if (binding->var.defined()) { - new_var = this->VisitVarDef(binding->var); + // if the LHS and RHS have the same struct info, we canonicalize to a var binding instead + if (StructuralEqual()(binding->struct_info, GetStructInfo(new_value))) { + builder_->EmitNormalized(VarBinding(binding->var, new_value)); + } else if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + builder_->EmitNormalized(MatchCast(binding->var, new_value, binding->struct_info)); } - - // if the LHS and RHS have the same shape_, we canonicalize to a var binding instead - if (new_var.defined() && StructuralEqual()(GetStructInfo(new_var), GetStructInfo(new_value))) { - builder_->EmitNormalized(VarBinding(new_var, new_value)); - return; - } - - // reemit old binding if nothing changes - if (new_value.same_as(binding->value)) { - if (!binding->var.defined() || (binding->var.defined() && new_var.same_as(binding->var))) { - builder_->EmitNormalized(GetRef(binding)); - return; - } - } - - builder_->EmitNormalized(MatchShape(new_value, binding->pattern, new_var)); } private: diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 2e9465bbe1..0983db3989 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -136,7 +136,7 @@ class GraphCreator : public ExprVisitor { // We skip ordinary binding blocks since they might be impure (with side effect or control flow) } - // TODO(tvm-team): how to deal with MatchShape binding here + // TODO(tvm-team): how to deal with MatchCast binding here void VisitBinding_(const VarBindingNode* binding) final { IndexedForwardGraph::Node* node = CreateNode(binding->var.get()); @@ -394,7 +394,7 @@ class FunctionCreator : public ExprMutator { AppendOutput(var_binding->var); } } else { - // TODO(tvm-team): handle match_shape + // TODO(tvm-team): handle match_cast } bindings_.push_back(binding); } @@ -416,14 +416,7 @@ class FunctionCreator : public ExprMutator { // Step 2. Visit each binding and collect outputs one by one. Array outputs; for (const Binding& binding : bindings_) { - const VarNode* var = nullptr; - if (const auto* var_binding = binding.as()) { - var = var_binding->var.get(); - } else if (const auto* match_shape = binding.as()) { - var = match_shape->var.get(); - } else { - ICHECK(false); - } + const VarNode* var = binding->var.get(); if (output_vars_.count(var)) { // Case 1. It is an output binding // We only allow VarBinding as output. @@ -687,12 +680,13 @@ class OperatorFusor : public ExprMutator { } } }; + if (const auto* var_binding = binding.as()) { PostOrderVisit(var_binding->value, update_boundary); } else { - const auto* match_shape = binding.as(); - ICHECK_NOTNULL(match_shape); - PostOrderVisit(match_shape->value, update_boundary); + const auto* match_cast = binding.as(); + ICHECK_NOTNULL(match_cast); + PostOrderVisit(match_cast->value, update_boundary); } } } @@ -703,14 +697,7 @@ class OperatorFusor : public ExprMutator { * \return The pointer to the group which the input binding is in */ GraphPartitioner::Group* GetGroupFromBinding(const Binding& binding) { - Var var{nullptr}; - if (const auto* var_binding = binding.as()) { - var = var_binding->var; - } else { - const auto* match_shape = binding.as(); - ICHECK(match_shape != nullptr); - var = match_shape->var; - } + Var var = binding->var; return GetGroupFromVar(var); } diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 15dc33f9c9..d77a8bd576 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -289,9 +289,8 @@ class FusedTIRConstructor : public ExprVisitor { } } - void VisitBinding_(const MatchShapeNode* match_shape) final { - // TODO(relax-team): support match shape in primitive functions; - LOG(FATAL) << "MatchShape is unsupported in primitive functions"; + void VisitBinding_(const MatchCastNode* match_cast) final { + LOG(FATAL) << "MatchCast is unsupported in primitive functions"; } void VisitExpr_(const CallNode* call) final { diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index 335e2fc7a1..679cd55e46 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -132,7 +132,7 @@ class NormalizeMutator : public ExprMutatorBase { void VisitBinding(const Binding& binding) { if (const auto* node = binding.as()) { VisitBinding_(node); - } else if (const auto* node = binding.as()) { + } else if (const auto* node = binding.as()) { VisitBinding_(node); } else { LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); @@ -152,18 +152,14 @@ class NormalizeMutator : public ExprMutatorBase { } } - void VisitBinding_(const MatchShapeNode* binding) { + void VisitBinding_(const MatchCastNode* binding) { Expr new_value = this->VisitExpr(binding->value); - if (binding->var.defined()) { - if (!binding->var->struct_info_.defined()) { - UpdateStructInfo(binding->var, GetStructInfo(new_value)); - } - } if (new_value.same_as(binding->value)) { - builder_->EmitNormalized(GetRef(binding)); + builder_->EmitNormalized(GetRef(binding)); } else { - builder_->EmitNormalized(MatchShape(new_value, binding->pattern, binding->var)); + builder_->EmitNormalized( + MatchCast(binding->var, builder_->NormalizeArgument(new_value), binding->struct_info)); } } diff --git a/src/relay/printer/relax_script_printer.cc b/src/relay/printer/relax_script_printer.cc index cb18f0fe4e..c0b10a37fe 100644 --- a/src/relay/printer/relax_script_printer.cc +++ b/src/relay/printer/relax_script_printer.cc @@ -262,18 +262,15 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::ShapeExprNode* op) { return doc; } -Doc RelaxScriptPrinter::VisitNode_(const relax::MatchShapeNode* op) { +Doc RelaxScriptPrinter::VisitNode_(const relax::MatchCastNode* op) { Doc doc; - if (op->var.defined()) { - doc << Print(op->var); - if (const auto& sinfo = MatchStructInfo(op->var)) { - doc << ": " << Print(sinfo); - } - doc << " = "; + doc << Print(op->var); + if (const auto& sinfo = MatchStructInfo(op->var)) { + doc << ": " << Print(sinfo); } - doc << "R.match_shape("; - // TODO(@altanh): maybe op->pattern should just be a ShapeExpr? - doc << Print(op->value) << ", " << Print(relax::ShapeExpr(op->pattern)); + doc << " = "; + doc << "R.match_cast("; + doc << Print(op->value) << ", " << Print(op->struct_info); doc << ")"; return doc; } @@ -318,12 +315,7 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowBlockNode* op) { std::vector return_vars; for (const relax::Binding& binding : op->bindings) { body << Print(binding) << Doc::NewLine(); - Var var; - if (const relax::VarBindingNode* var_binding = binding.as()) { - var = var_binding->var; - } else if (const relax::MatchShapeNode* shape_binding = binding.as()) { - var = shape_binding->var; - } + Var var = binding->var; if (var.defined() && !var.as()) { return_vars.push_back(Print(var)); } diff --git a/src/relay/printer/text_printer.h b/src/relay/printer/text_printer.h index 42b95db854..ad22c738b4 100644 --- a/src/relay/printer/text_printer.h +++ b/src/relay/printer/text_printer.h @@ -276,7 +276,7 @@ class RelaxScriptPrinter : public relax::IRFunctor, Doc VisitNode_(const relax::VarNode* op) override; Doc VisitNode_(const relax::DataflowVarNode* op) override; Doc VisitNode_(const relax::ShapeExprNode* op) override; - Doc VisitNode_(const relax::MatchShapeNode* op) override; + Doc VisitNode_(const relax::MatchCastNode* op) override; Doc VisitNode_(const relax::VarBindingNode* op) override; Doc VisitNode_(const relax::BindingBlockNode* op) override; Doc VisitNode_(const relax::DataflowBlockNode* op) override; diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 07a4cc8e06..72c111e67e 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -160,17 +160,8 @@ void BlockFrameNode::ExitWithScope() { // Step 3.2. Collect global vars' reference in bindings Map new_global_vars; for (const tvm::relax::Binding& binding : block->bindings) { - if (const auto* var_binding = binding.as()) { - if (!var_binding->var->IsInstance()) { - new_global_vars.Set(var_binding->var->vid, var_binding->var); - } - } else if (const auto* match_shape = binding.as()) { - if (match_shape->var.defined() && - !match_shape->var->IsInstance()) { - new_global_vars.Set(match_shape->var->vid, match_shape->var); - } - } else { - LOG(FATAL) << "ValueError: Unsupported binding type: " << binding; + if (!binding->var->IsInstance()) { + new_global_vars.Set(binding->var->vid, binding->var); } } diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index da0531ac98..2c8b17460c 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -204,28 +204,18 @@ tvm::relax::Var Emit(const tvm::relax::Expr& expr) { return var; } -Optional EmitMatchShape(const tvm::relax::Expr& value, // - const Array& pattern, // - bool emit_var) { +tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value, + const tvm::relax::StructInfo& struct_info) { BlockFrame block_frame = CheckBlockFrameExistAndUnended(); - tvm::relax::BlockBuilder block_builder = GetBlockBuilder(); - - if (!emit_var) { - // If we don't intend to emit a variable, just emit the binding and return. - tvm::relax::MatchShape match_shape(block_builder->Normalize(value), pattern, - tvm::relax::Var{nullptr}); - block_builder->EmitNormalized(match_shape); - return NullOpt; - } else { - // Otherwise, we need to emit a variable and bind it to the match shape. - tvm::relax::Var var = block_builder->EmitMatchShape(value, pattern); - block_frame->emitted_vars.push_back(var); - return var; - } + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + + tvm::relax::Var var = block_builder->EmitMatchCast(value, struct_info); + block_frame->emitted_vars.push_back(var); + return var; } TVM_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit); -TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchShape").set_body_typed(EmitMatchShape); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchCast").set_body_typed(EmitMatchCast); ///////////////////////////// Type Deduce ////////////////////////////// diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h index 3a72eb8e36..019532071d 100644 --- a/src/script/ir_builder/relax/utils.h +++ b/src/script/ir_builder/relax/utils.h @@ -90,12 +90,11 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String << "A non-dataflow var is expected in the last binding of '" << method << "'."; body = var_binding->value; *var_name = var_binding->var->name_hint(); - } else if (const auto* match_shape = last_binding.as()) { - CHECK(match_shape->var.defined() && - !match_shape->var->IsInstance()) + } else if (const auto* match_cast = last_binding.as()) { + CHECK(!match_cast->var->IsInstance()) << "A non-dataflow var is expected in the last binding of '" << method << "'."; body = var_binding->value; - *var_name = match_shape->var->name_hint(); + *var_name = match_cast->var->name_hint(); } else { ICHECK(false) << "TypeError: Unsupported binding type: " << last_binding->GetTypeKey(); } diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 2e638bfd32..ee235cc6e9 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -267,11 +267,11 @@ def func(a: R.Tensor) -> R.Tensor: def main(x: R.Tensor, y: R.Tensor) -> R.Tensor: z = R.add(x, y) # no binding here - R.match_shape(x, (5, 5)) + _ = R.match_cast(x, R.Tensor((5, 5))) with R.dataflow(): q = R.add(z, z) p = func(q) - r = R.match_shape(p, (5, 5)) + r = R.match_cast(p, R.Tensor((5, 5))) s = r R.output(s) return s @@ -285,7 +285,7 @@ def test_all_vars(): assert vars[1] == VarExample["func"].body.body var_names = var_name_set(all_vars(VarExample["main"])) - assert var_names == {"x", "y", "z", "p", "q", "r", "s"} + assert var_names == {"_", "x", "y", "z", "p", "q", "r", "s"} def test_bound_vars(): @@ -297,11 +297,11 @@ def test_bound_vars(): # all the vars are bound var_names = var_name_set(bound_vars(VarExample["main"])) - assert var_names == {"x", "y", "z", "p", "q", "r", "s"} + assert var_names == {"_", "x", "y", "z", "p", "q", "r", "s"} # if we consider only the body, then the function arguments are not bound body_names = var_name_set(bound_vars(VarExample["main"].body)) - assert body_names == {"z", "p", "q", "r", "s"} + assert body_names == {"_", "z", "p", "q", "r", "s"} # only binding is in the (normalized) body simple_body_vars = bound_vars(VarExample["func"].body) diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 9f4cea057b..f4cf5d98ac 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -106,15 +106,15 @@ def test_dataflow_var() -> None: assert "checked_type_" in v1_annos -def test_match_shape() -> None: - # match_shape([16, 8], [m, n]) +def test_match_cast() -> None: + # match_cast([16, 8], [m, n]) m = tir.Var("m", dtype="int64") n = tir.Var("n", dtype="int64") shape = rx.const([16, 8], "int32") var = rx.Var("v0", R.Shape()) - b0 = rx.MatchShape(shape, [m, n], var) + b0 = rx.MatchCast(var, shape, R.Tensor([m, n], "int32")) b0_str = dump_ast(b0) - assert b0_str.startswith("MatchShape(") + assert b0_str.startswith("MatchCast(") assert "Constant" in b0_str assert "PrimExpr(value=`m: int64`)" in b0_str assert "PrimExpr(value=`n: int64`)" in b0_str @@ -123,12 +123,12 @@ def test_match_shape() -> None: assert b0_str != dump_ast(b0, include_type_annotations=False) # var1: Tensor((m, n), "float32") = - # match_shape(var0: R.Tensor("float32"), [m, n]) + # match_cast(var0: R.Tensor("float32"), [m, n]) value = rx.Var("value", R.Tensor("float32")) var = rx.Var("v1", R.Tensor([m, n], "float32")) - b1 = rx.MatchShape(value, [m, n], var) + b1 = rx.MatchCast(var, value, R.Tensor([m, n], "float32")) b1_str = dump_ast(b1) - assert b1_str.startswith("MatchShape(") + assert b1_str.startswith("MatchCast(") assert "PrimExpr(value=`m: int64`)" in b1_str assert "PrimExpr(value=`n: int64`)" in b1_str assert b1_str != dump_ast( @@ -136,20 +136,6 @@ def test_match_shape() -> None: ) -def test_match_shape_unbound() -> None: - @R.function - def func(x: R.Tensor) -> R.Tensor: - R.match_shape(x, (1, 1)) - return x - - # no var field on the match shape! - func_str = strip_whitespace(dump_ast(func)) - assert "MatchShape" in func_str - assert "value=Var(" in func_str - assert "pattern=[PrimExpr(" in func_str - assert "var=" not in func_str - - def test_var_binding() -> None: v0 = rx.Var("v0") val = rx.const(np.random.rand(24, 56)) @@ -162,10 +148,10 @@ def test_var_binding() -> None: def test_binding_block() -> None: - m = tir.Var("m", dtype="int32") - n = tir.Var("n", dtype="int32") + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") shape = rx.const([16, 8], "int32") - b0 = rx.MatchShape(shape, [m, n], rx.Var("v0")) + b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) v0 = rx.Var("v0") val = rx.const(np.random.rand(24, 56)) @@ -176,15 +162,15 @@ def test_binding_block() -> None: assert block0_str.startswith("BindingBlock(") assert "bindings=" in block0_str assert "VarBinding(" in block0_str - assert "MatchShape(" in block0_str + assert "MatchCast(" in block0_str assert '"v0"' in block0_str def test_dataflow_block() -> None: - m = tir.Var("m", dtype="int32") - n = tir.Var("n", dtype="int32") + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") shape = rx.const([16, 8], "int32") - b0 = rx.MatchShape(shape, [m, n], rx.Var("v0")) + b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) v0 = rx.Var("v0") val = rx.const(np.random.rand(24, 56)) @@ -195,7 +181,7 @@ def test_dataflow_block() -> None: assert block0_str.startswith("DataflowBlock(") assert "bindings=" in block0_str assert "VarBinding(" in block0_str - assert "MatchShape(" in block0_str + assert "MatchCast(" in block0_str assert '"v0"' in block0_str diff --git a/tests/python/relax/test_autotir_integration.py b/tests/python/relax/test_autotir_integration.py index cca5e8bbde..8e3a75615d 100644 --- a/tests/python/relax/test_autotir_integration.py +++ b/tests/python/relax/test_autotir_integration.py @@ -66,9 +66,9 @@ def tir_relu(x:T.handle, y:T.handle): def main(x:R.Tensor((m,n), "float32"), w:R.Tensor((n,k), "float32")) -> R.Tensor: with R.dataflow(): sh = R.call_packed("vm.builtin.shape_of", x) - x0 = R.match_shape(sh, (m, n)) + x0 = R.match_cast(sh, R.Tensor((m, n), "float32")) sh1 = R.call_packed("vm.builtin.shape_of", w) - x1 = R.match_shape(sh1, (n, k)) + x1 = R.match_cast(sh1, R.Tensor((n, k), "float32")) lv0 = R.call_tir(tir_matmul, (x, w), (m, k), dtype="float32") lv1 = R.call_tir(tir_relu, (lv0), (m, k), dtype="float32) R.output(lv1) diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py index 4ff524096a..263d09abc6 100644 --- a/tests/python/relax/test_blockbuilder.py +++ b/tests/python/relax/test_blockbuilder.py @@ -232,7 +232,7 @@ def test_binary_shape_type_deduction(): assert gv0.checked_type.dtype == "float16" -def test_emit_match_shape(): +def test_emit_match_cast(): m = tir.Var("m", dtype="int64") n = tir.Var("n", dtype="int64") x = rx.Var("tensor_value", R.Tensor("float32", ndim=-1)) @@ -242,55 +242,54 @@ def test_emit_match_shape(): with bb.function("func", [x, y]): with bb.dataflow(): # lv0: Tensor((m, n), "float32") = - # match_shape(x: Tensor(_, "float32"], [m, n)) - lv0 = bb.match_shape(x, [m, n]) + # match_cast(x: Tensor(_, "float32"], [m, n)) + lv0 = bb.match_cast(x, R.Tensor([m, n], "float32")) assert isinstance(lv0, rx.DataflowVar) assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float32")) - # lv1: Shape = match_shape(shape, [m, n]) - lv1 = bb.match_shape(y, [m, n]) - assert lv1.struct_info == rx.ShapeStructInfo(ndim=2) + # lv1: Shape = match_cast(shape, R.Shape([m, n])) + lv1 = bb.match_cast(y, R.Shape([m, n])) + assert lv1.struct_info == rx.ShapeStructInfo([m, n]) gv0 = bb.emit_output(lv1) bb.emit_func_output(gv0) func = bb.get()["func"] block = func.body.blocks[0] b0, b1 = block.bindings[:2] - assert isinstance(b0, rx.MatchShape) - assert isinstance(b1, rx.MatchShape) + assert isinstance(b0, rx.MatchCast) + assert isinstance(b1, rx.MatchCast) assert b0.value == x - assert b0.pattern[0] == m - assert b0.pattern[1] == n + assert b0.struct_info == rx.TensorStructInfo([m, n], "float32") assert b0.var == lv0 assert b1.value == y - assert b1.pattern[0] == m - assert b1.pattern[1] == n + assert b1.struct_info == rx.ShapeStructInfo([m, n]) assert b1.var == lv1 -def test_emit_match_shape_binding_in_dataflow_block(): +def test_emit_match_cast_binding_in_dataflow_block(): bb = rx.BlockBuilder() x = rx.Var("x", R.Tensor("float32", ndim=-1)) m = tir.Var("m", dtype="int64") gv = rx.Var("gv", R.Tensor("float32", ndim=-1)) - match_shape = rx.MatchShape(x, (m,), gv) + match_cast = rx.MatchCast(gv, x, R.Tensor((m,), "float32")) with bb.function("main", [x]): with bb.dataflow(): - bb.emit_normalized(match_shape) + bb.emit_normalized(match_cast) bb.emit_output(gv) bb.emit_func_output(x) func = bb.get()["main"] block = func.body.blocks[0] b0 = block.bindings[0] - assert isinstance(b0, rx.MatchShape) + assert isinstance(b0, rx.MatchCast) assert b0.value == x - assert b0.pattern[0] == m + assert isinstance(b0.struct_info, rx.TensorStructInfo) + assert b0.struct_info.shape[0] == m assert b0.var == gv diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py index 55d7ee072c..234752392f 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -22,6 +22,22 @@ from tvm.script import relax as R +def _check_equal(x, y, map_free_vars=False): + tvm.ir.assert_structural_equal(x, y, map_free_vars) + tvm.ir.assert_structural_equal(y, x, map_free_vars) + + xhash = tvm.ir.structural_hash(x, map_free_vars) + yhash = tvm.ir.structural_hash(y, map_free_vars) + + assert xhash == yhash + + +def _check_json_roundtrip(x): + xret = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, xret, map_free_vars=True) + return xret + + def test_var() -> None: v0 = rx.Var("v0") assert v0.name_hint == "v0" @@ -51,13 +67,13 @@ def test_dataflow_var() -> None: tvm.ir.assert_structural_equal(v1.struct_info, rx.TensorStructInfo(shape, "float16")) -def test_match_shape() -> None: - # match_shape([16, 8], [m, n]) +def test_match_cast() -> None: + # match_cast([16, 8], [m, n]) m = tir.Var("m", dtype="int64") n = tir.Var("n", dtype="int64") shape = rx.const([16, 8], "int32") var = rx.Var("v0", R.Shape()) - b0 = rx.MatchShape(shape, [m, n], var) + b0 = rx.MatchCast(var, shape, R.Tensor([m, n], "int32")) assert b0.value == shape assert b0.pattern[0] == m assert b0.pattern[1] == n @@ -65,11 +81,11 @@ def test_match_shape() -> None: assert b0.var.checked_type == rx.ShapeType() # var1: R.Tensor((m, n), "float32") = - # match_shape(var0: R.Tensor("float32", ndim=-1), [m, n]) + # match_cast(var0: R.Tensor("float32", ndim=-1), R.Tensor((m, n), "float32")) value = rx.Var("value", R.Tensor("float32", ndim=-1)) var = rx.Var("v1", R.Tensor([m, n], "float32")) - b1 = rx.MatchShape(value, [m, n], var) + b1 = rx.MatchCast(var, value, R.Tensor([m, n], "float32")) assert b1.value == value assert b1.pattern[0] == m assert b1.pattern[1] == n @@ -77,6 +93,17 @@ def test_match_shape() -> None: assert b1.var.checked_type == rx.DynTensorType(2, "float32") +def test_match_cast() -> None: + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + ivalue = rx.Var("input_value") + sinfo = rx.TensorStructInfo([n, m], "float32") + b0 = rx.MatchCast(rx.Var("v"), ivalue, sinfo) + assert b0.value.same_as(ivalue) + assert b0.struct_info == sinfo + _check_json_roundtrip(b0) + + def test_var_binding() -> None: v0 = rx.Var("v0") val = rx.const(np.random.rand(24, 56)) @@ -86,10 +113,10 @@ def test_var_binding() -> None: def test_binding_block() -> None: - m = tir.Var("m", dtype="int32") - n = tir.Var("n", dtype="int32") + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") shape = rx.const([16, 8], "int32") - b0 = rx.MatchShape(shape, [m, n], rx.Var("v0")) + b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) v0 = rx.Var("v0") val = rx.const(np.random.rand(24, 56)) @@ -101,10 +128,10 @@ def test_binding_block() -> None: def test_dataflow_block() -> None: - m = tir.Var("m", dtype="int32") - n = tir.Var("n", dtype="int32") + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") shape = rx.const([16, 8], "int32") - b0 = rx.MatchShape(shape, [m, n], rx.Var("v0")) + b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) v0 = rx.Var("v0") val = rx.const(np.random.rand(24, 56)) diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py index a0abaa72d7..95e057acbf 100644 --- a/tests/python/relax/test_expr_functor.py +++ b/tests/python/relax/test_expr_functor.py @@ -30,7 +30,7 @@ Function, GlobalVar, If, - MatchShape, + MatchCast, SeqExpr, ShapeExpr, Tuple, @@ -155,13 +155,11 @@ def visit_var_binding_(self, binding: VarBinding) -> None: self.visit_var_def(binding.var) self.log.pop_scope() - def visit_match_shape_(self, binding: MatchShape) -> None: - self.log.add("MatchShape") + def visit_match_cast_(self, binding: MatchCast) -> None: + self.log.add("MatchCast") self.log.push_scope() + self.visit_var_def(binding.var) self.visit_expr(binding.value) - self.visit_expr(ShapeExpr(binding.pattern)) - if binding.var: - self.visit_var_def(binding.var) self.log.pop_scope() def visit_binding_block_(self, block: BindingBlock) -> None: @@ -280,31 +278,18 @@ def visit_var_binding_(self, binding: VarBinding) -> None: self.builder_.emit_normalized(VarBinding(new_var, new_value)) - def visit_match_shape_(self, binding: MatchShape) -> None: - """Identical with ExprMutator::VisitBinding_(const MatchShapeNode* binding) on the C++ side.""" + def visit_match_cast_(self, binding: MatchCast) -> None: + """Identical with ExprMutator::VisitBinding_(const MatchCastNode* binding) on the C++ side.""" + new_var = self.visit_var_def(binding.var) new_value = self.visit_expr(binding.value) - new_pattern = self.visit_expr(ShapeExpr(binding.pattern)) - - if binding.var: - new_sinfo = None - if isinstance(new_value.struct_info, TensorStructInfo): - new_sinfo = relax.TensorStructInfo(new_pattern, dtype=new_value.struct_info) - else: - new_sinfo = new_value.struct_info - new_var = self.visit_var_def(binding.var) - temp = self.with_struct_info(new_var, new_sinfo) - if not temp.same_as(new_var): - new_var = temp - self.set_var_remap(binding.var.vid, new_var) - - self.log.add("MatchShape") - if binding.value.same_as(new_value) and binding.pattern.same_as(new_pattern): - if not binding.var or (binding.var and binding.var.same_as(new_var)): - self.builder_.emit_normalized(binding) - return + temp = self.with_struct_info(new_var, binding.struct_info) + if not temp.same_as(new_var): + new_var = temp + self.set_var_remap(binding.var.vid, new_var) - self.builder_.emit_normalized(MatchShape(new_value, new_pattern.values, new_var)) + self.log.add("MatchCast") + self.builder_.emit_normalized(MatchCast(new_var, new_value, binding.struct_info)) def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: """Identical with ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) on the C++ side.""" @@ -372,7 +357,6 @@ def test_var(): basic_check(x, "Var", "Var") -@pytest.mark.skip("Revisit PyMutator tests after struct info") def test_dataflow_var(): lv = relax.DataflowVar("lv", R.Tensor([n], "float32")) basic_check(lv, "DataflowVar", "DataflowVar") @@ -418,7 +402,7 @@ def test_call(): basic_check( call_node, "\n".join(["Call", "\tOp", "\tVar", "\tVar"]), - "\n".join(["Op", "Var", "Var", "Call"]), + "\n".join(["Op", "Var", "Var", "ShapeExpr", "Call"]), ) @@ -443,7 +427,7 @@ def test_tuple_getitem(): def test_binding_block(): bb._begin_binding_block() gv0 = bb.emit(relax.op.add(x, y)) - gv1 = bb.match_shape(y, [m, n]) + gv1 = bb.match_cast(y, R.Tensor([m, n], "float32")) b0 = bb._end_block() basic_check( b0, @@ -456,10 +440,9 @@ def test_binding_block(): "\t\t\tVar", "\t\t\tVar", "\t\tVarDef", - "\tMatchShape", - "\t\tVar", - "\t\tShapeExpr", + "\tMatchCast", "\t\tVarDef", + "\t\tVar", ] ), "\n".join( @@ -475,7 +458,7 @@ def test_binding_block(): "ShapeExpr", "ShapeExpr", "VarDef", - "MatchShape", + "MatchCast", "BindingBlock", ] ), @@ -485,7 +468,7 @@ def test_binding_block(): def test_dataflow_block(): bb._begin_dataflow_block() lv0 = bb.emit(relax.op.add(x, y)) - gv1 = bb.match_shape(y, [m, n]) + gv1 = bb.match_cast(y, R.Tensor([m, n], "float32")) b0 = bb._end_block() basic_check( b0, @@ -498,10 +481,9 @@ def test_dataflow_block(): "\t\t\tVar", "\t\t\tVar", "\t\tDataflowVarDef", - "\tMatchShape", - "\t\tVar", - "\t\tShapeExpr", + "\tMatchCast", "\t\tDataflowVarDef", + "\t\tVar", ] ), "\n".join( @@ -517,7 +499,7 @@ def test_dataflow_block(): "ShapeExpr", "ShapeExpr", "DataflowVarDef", - "MatchShape", + "MatchCast", "DataflowBlock", ] ), diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index 8f013923d3..07aa4a22aa 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -90,7 +90,7 @@ def f( assert len(o_call_packed.type_args) == 1 -def test_mismatch_shape_dims_and_ndim(): +def test_mismatch_cast_dims_and_ndim(): with pytest.raises(Exception): # TODO: replace with DiagnosticError once we have better error reporting. # with pytest.raises(tvm.error.DiagnosticError): @@ -160,16 +160,16 @@ def f(x: R.Tensor(("m", "n"), "float32")): return relax.call_tir("foo", (x,), (T.max(m),), dtype="float32") -def test_match_shape(): +def test_match_cast(): @R.function def f(x: R.Tensor(dtype="float32")): n, m = T.var("int64"), T.var("int64") - R.match_shape(R.shape_of(x), (n, m)) + _ = R.match_cast(R.shape_of(x), R.Shape((n, m))) y: R.Tensor((n, m), "float32") = R.add(x, x) return x match_sh = f.body.blocks[0].bindings[0] - pattern, value = match_sh.pattern, match_sh.value + value = match_sh.value check_call(value, "relax.shape_of", [f.params[0]]) @@ -378,15 +378,15 @@ def f(x: R.Tensor): assert f.body.body == t -def test_dataflow_match_shape(): +def test_dataflow_match_cast(): @R.function def f(x: R.Tensor): n, m = T.var("int64"), T.var("int64") with R.dataflow(): - x2: R.Tensor((n, m)) = R.match_shape(x, (n, m)) + x2: R.Tensor((n, m)) = R.match_cast(x, R.Tensor((n, m))) y = R.add(x2, x2) z = R.multiply(y, x) - R.match_shape(R.shape_of(z), (n, m)) + _ = R.match_cast(R.shape_of(z), R.Shape((n, m))) w: R.Tensor((n, m)) = R.add(z, x) R.output(y, w, x2) t: R.Tensor((n, m)) = R.multiply(y, w) @@ -396,7 +396,6 @@ def f(x: R.Tensor): x = f.params[0] df_block = f.body.blocks[0] x2_bind = df_block.bindings[0] - z_shape_bind = df_block.bindings[3] q_bind = f.body.blocks[1].bindings[1] assert x2_bind.var.name_hint == "x2" diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py index d58a2063f2..6cdb53c908 100644 --- a/tests/python/relax/test_printer.py +++ b/tests/python/relax/test_printer.py @@ -64,11 +64,11 @@ def foo( check_roundtrip(foo) -def test_match_shape(): +def test_match_cast(): @R.function def foo(x: R.Tensor(dtype="float32")): n, m = T.var("int64"), T.var("int64") - R.match_shape(R.shape_of(x), (n, m)) + _ = R.match_cast(R.shape_of(x), R.Shape((n, m))) y: R.Tensor((n, m), "float32") = R.add(x, x) return x @@ -141,15 +141,15 @@ def foo(x: R.Tensor(ndim=2)): check_roundtrip(foo) -def test_dataflow_match_shape(): +def test_dataflow_match_cast(): @R.function def foo(x: R.Tensor(ndim=2)): n, m = T.var("int64"), T.var("int64") with R.dataflow(): - x2: R.Tensor((n, m)) = R.match_shape(x, (n, m)) + x2: R.Tensor((n, m)) = R.match_cast(x, R.Tensor((n, m))) y = R.add(x2, x2) z = R.multiply(y, x) - R.match_shape(R.shape_of(z), (n, m)) + _ = R.match_cast(R.shape_of(z), R.Shape((n, m))) w: R.Tensor((n, m)) = R.add(z, x) R.output(y, w, x2) t: R.Tensor((n, m)) = R.multiply(y, w) diff --git a/tests/python/relax/test_structual_equal_hash.py b/tests/python/relax/test_structual_equal_hash.py index 43db38fe80..8c890d27c0 100644 --- a/tests/python/relax/test_structual_equal_hash.py +++ b/tests/python/relax/test_structual_equal_hash.py @@ -53,14 +53,14 @@ def generator(x, y): _check_equal(block0, block1) -def test_match_shape(): +def test_match_cast(): x = rx.Var("x", R.Tensor([10])) m = tir.Var("m", dtype="int64") def generator(x): bb = rx.BlockBuilder() bb._begin_binding_block() - bb.match_shape(x, [m * 2]) + bb.match_cast(x, R.Tensor([m * 2])) return bb._end_block() block0 = generator(x) @@ -108,13 +108,13 @@ def generator(): _check_equal(mod0, mod1) -def test_match_shape_symbolic(): +def test_match_cast_symbolic(): @tvm.script.ir_module class InputModule: @R.function def f(x: R.Tensor("float32", ndim=2)): n, m = T.var("int64"), T.var("int64") - x0 = R.match_shape(x, (n, m)) + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) return (x0, (n + 1, m)) _check_save_roundtrip(InputModule) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 6bdcd0701c..7368cdaf57 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -176,15 +176,16 @@ def main(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")): relax.transform.FailTestRewrite()(TestRewriteGlobalScopeVar) # raise error on rewriting/removing existing Symbolic Vars inside the dataflow block - # check all Symbolic Vars defined in R.match_shape + # check all Symbolic Vars defined in R.match_cast with pytest.raises(tvm.TVMError): @tvm.script.ir_module class TestRewriteSymbolicVar: @R.function def main(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")): + m, n = T.var("int64"), T.var("int64") with R.dataflow(): - lv0 = R.match_shape(x, (m, n)) + lv0 = R.match_cast(x, R.Tensor((m, n), "float32")) gv0 = R.add(lv0, y) R.output(gv0) return gv0 @@ -197,8 +198,9 @@ def main(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")): class TestRemoveSymbolicVar: @R.function def main(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")): + m, n, d = T.var("int64"), T.var("int64"), T.var("int64") with R.dataflow(): - lv0 = R.match_shape(x, (m, n, d)) + lv0 = R.match_cast(x, R.Tensor((m, n, d), "float32")) gv0 = R.add(lv0, y) R.output(gv0) return gv0 @@ -324,7 +326,7 @@ class TestVMShapeLower: @R.function def foo(x: R.Tensor(dtype="float32")): m, n = T.var("int64"), T.var("int64") - R.match_shape(x, (n, m)) + _ = R.match_cast(x, R.Tensor((n, m), "float32")) return (n * 2, m * 3) mod = TestVMShapeLower diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index bd80236406..1c659e13ad 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -135,14 +135,14 @@ def main(x: R.Tensor) -> R.Object: assert_structural_equal(new_mod, Expected) -def test_match_shape(): +def test_match_cast(): @tvm.script.ir_module - class TestMatchShape: + class TestMatchCast: @R.function def main(x: R.Tensor): q = x m, n = T.var("int64"), T.var("int64") - z = R.match_shape(q, (m, n)) + z = R.match_cast(q, R.Tensor((m, n))) w = z return w @@ -153,11 +153,11 @@ def main(x: R.Tensor): q = x # can't get rid of z because its shape_ is different from x's m, n = T.var("int64"), T.var("int64") - z = R.match_shape(x, (m, n)) + z = R.match_cast(x, R.Tensor((m, n))) w = z return z - new_mod = relax.transform.CanonicalizeBindings()(TestMatchShape) + new_mod = relax.transform.CanonicalizeBindings()(TestMatchCast) assert_structural_equal(new_mod, Expected) @@ -165,11 +165,11 @@ def test_same_shape(): @tvm.script.ir_module class TestSameShape: @R.function - def main(x: R.Tensor(("m", "n"))): + def main(x: R.Tensor(("m", "n"), "float32")): m, n = T.var("int64"), T.var("int64") y = x # trivial check - z = R.match_shape(x, (m, n)) + z = R.match_cast(x, R.Tensor((m, n), "float32")) w = z q = R.add(w, y) return R.add(q, w) @@ -177,7 +177,7 @@ def main(x: R.Tensor(("m", "n"))): @tvm.script.ir_module class Expected: @R.function - def main(x: R.Tensor(("m", "n"))): + def main(x: R.Tensor(("m", "n"), "float32")): m, n = T.var("int64"), T.var("int64") y = x # canonicalized into a var binding @@ -198,7 +198,7 @@ def main(x: R.Tensor(("m", "n"))): y = x # not trivial: introduces new shape vars o, p = T.var("int64"), T.var("int64") - z = R.match_shape(x, (o, p)) + z = R.match_cast(x, R.Tensor((o, p))) w = z q = R.add(w, y) return R.add(q, w) @@ -209,7 +209,7 @@ class Expected: def main(x: R.Tensor(("m", "n"))): y = x o, p = T.var("int64"), T.var("int64") - z = R.match_shape(x, (o, p)) + z = R.match_cast(x, R.Tensor((o, p))) w = z # the shape_ field on q will need to be updated q = R.add(z, x) @@ -219,33 +219,5 @@ def main(x: R.Tensor(("m", "n"))): assert_structural_equal(new_mod, Expected) -def test_unbound_match_shape(): - # ensure that match shapes that do not bind a var are handled correctly - @tvm.script.ir_module - class TestUnboundMatchShape: - @R.function - def main(x: R.Tensor): - y = x - z = y - m, n = T.var("int64"), T.var("int64") - R.match_shape(z, (m, n)) - w = z - return w - - @tvm.script.ir_module - class Expected: - @R.function - def main(x: R.Tensor): - y = x - z = x - m, n = T.var("int64"), T.var("int64") - R.match_shape(x, (m, n)) - w = x - return x - - new_mod = relax.transform.CanonicalizeBindings()(TestUnboundMatchShape) - assert_structural_equal(new_mod, Expected) - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py index 62406c19f9..d642a408de 100644 --- a/tests/python/relax/test_transform_fold_constant.py +++ b/tests/python/relax/test_transform_fold_constant.py @@ -207,7 +207,7 @@ def sub( @R.function def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2)): n, m = T.var("int64"), T.var("int64") - x0 = R.match_shape(x, (n, m)) + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) # this line cannot be folded because n is unknown lv0 = relax.call_tir(addone, (c0,), (n, 16), dtype="float32") # this line can be folded @@ -226,7 +226,7 @@ def expected( x: R.Tensor("float32", ndim=2), ) -> R.Tensor: n, m = T.var("int64"), T.var("int64") - x0 = R.match_shape(x, (n, m)) + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) # this line cannot be folded because n is unknown lv0 = relax.call_tir(addone, (c0,), (n, 16), dtype="float32") # this line can be folded diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py index 4f6ec1f8ad..9c352a8e42 100644 --- a/tests/python/relax/test_transform_normalize.py +++ b/tests/python/relax/test_transform_normalize.py @@ -454,6 +454,7 @@ def test_normalize_deeply_nested_seq(): u = relax.Var("u", R.Tensor([], "int32")) v = relax.Var("v", R.Tensor([], "int32")) w = relax.Var("w", R.Tensor([], "int32")) + _ = relax.Var("w", R.Tensor([], "int32")) seq = relax.SeqExpr( [ relax.BindingBlock( @@ -472,9 +473,13 @@ def test_normalize_deeply_nested_seq(): relax.BindingBlock( [ relax.VarBinding(u, relax.const(2)), - relax.MatchShape(u, [], None), + relax.MatchCast( + _, u, R.Tensor([], "int32") + ), relax.VarBinding(v, u), - relax.MatchShape(v, [], w), + relax.MatchCast( + w, v, R.Tensor([], "int32") + ), ] ) ], @@ -504,9 +509,9 @@ def test_normalize_deeply_nested_seq(): def expected(): x = relax.const(1) u = relax.const(2) - R.match_shape(u, ()) + _ = R.match_cast(u, R.Tensor((), "int32")) v = u - w = R.match_shape(v, ()) + w = R.match_cast(v, R.Tensor((), "int32")) z = w y = z return y diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py index bd87114394..b154e85b8c 100644 --- a/tests/python/relax/test_tvmscript_ir_builder.py +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -53,14 +53,14 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) assert func.body.body.name_hint == "out" -def test_match_shape(): +def test_match_cast(): """ @R.function def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): m = T.var("int64") n = T.var("int64") - R.match_shape(x, (m,)) - y1 = R.match_shape(x, (n,)) + _ = R.match_cast(x, R.Tensor((m,), "float32")) + y1 = R.match_cast(x, R.Tensor((n,), "float32")) return (m, n * 2) """ # create with Script IRBuilder @@ -71,8 +71,8 @@ def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): y = R.arg("y", R.tensor(ndim=-1, dtype="float32")) m = tir.Var("m", dtype="int64") n = tir.Var("n", dtype="int64") - R.emit_match_shape(x, (m,), emit_var=False) - y1 = R.emit_match_shape(y, (n,), emit_var=True) + _ = R.emit_match_cast(x, R.tensor((m,), "float32")) + y1 = R.emit_match_cast(y, R.tensor((n,), "float32")) IRBuilder.name("y1", y1) R.func_ret_value(relax.ShapeExpr([m, n * 2])) func = ir_builder.get() @@ -84,8 +84,8 @@ def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): n = tir.Var("n", dtype="int64") bb = relax.BlockBuilder() with bb.function("foo", (x, y)): - bb.emit_normalized(relax.MatchShape(x, (m,), var=None)) - y1 = bb.match_shape(y, (n,)) + _ = bb.match_cast(x, R.tensor((m,), "float32")) + y1 = bb.match_cast(y, R.tensor((n,), "float32")) bb.emit_func_output(relax.ShapeExpr([m, n * 2])) mod = bb.get() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 5966086835..dcdca379fe 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -188,24 +188,31 @@ def foo(x: R.Tensor((4, 4), "float32")): _check(foo, bb.get()["foo"]) -def test_match_shape(): +def test_match_cast(): @R.function - def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): + def foo(x: R.Tensor("float32"), y: R.Tensor("float32")): m = T.var("int64") n = T.var("int64") - R.match_shape(x, (m,)) - y1 = R.match_shape(y, (n,)) - return (m, n * 2) + x0 = R.match_cast(x, R.Tensor([m], "float32")) + with R.dataflow(): + y0 = R.match_cast(y, R.Tensor([n], "float32")) + gv = y0 + R.output(gv) + return (x0, (m, n * 2)) - x = relax.Var("x", R.Tensor("float32", ndim=-1)) - y = relax.Var("y", R.Tensor("float32", ndim=-1)) + x = relax.Var("x", R.Tensor("float32")) + y = relax.Var("y", R.Tensor("float32")) m = tir.Var("m", dtype="int64") n = tir.Var("n", dtype="int64") + y2 = relax.Var("y", R.Tensor([n], "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x, y)): - bb.emit_normalized(relax.MatchShape(x, (m,), var=None)) - y1 = bb.match_shape(y, (n,)) - bb.emit_func_output(relax.ShapeExpr([m, n * 2])) + x0 = bb.match_cast(x, R.Tensor([m], "float32")) + with bb.dataflow(): + y0 = bb.match_cast(y, R.Tensor([n], "float32")) + bb.emit_output(y0) + bb.emit_func_output(relax.Tuple([x0, relax.ShapeExpr([m, n * 2])])) + _check(foo, bb.get()["foo"]) @@ -230,14 +237,14 @@ def test_tuple_return_2(): @R.function def foo(x: R.Tensor("float32", ndim=2)): n, m = T.var("int64"), T.var("int64") - x0 = R.match_shape(x, (n, m)) + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) return (x0, (n + 1, m, 1)) x = relax.Var("x", R.Tensor("float32", ndim=2)) n, m = tir.Var("n", "int64"), tir.Var("m", "int64") bb = relax.BlockBuilder() with bb.function("foo", (x,)): - x0 = bb.match_shape(x, (n, m)) + x0 = bb.match_cast(x, R.Tensor((n, m), "float32")) bb.emit_func_output(relax.Tuple([x0, relax.ShapeExpr([n + 1, m, 1])])) _check(foo, bb.get()["foo"]) @@ -247,7 +254,7 @@ def test_tuple_binding(): @R.function def foo(x: R.Tensor("float32", ndim=2)): n, m = T.var("int64"), T.var("int64") - x0 = R.match_shape(x, (n, m)) + x0 = R.match_cast(x, R.Tensor((n, m), "float32")) t0 = (x, x0) t1 = (x, (n, m), t0) return t1 @@ -256,7 +263,7 @@ def foo(x: R.Tensor("float32", ndim=2)): n, m = tir.Var("n", "int64"), tir.Var("m", "int64") bb = relax.BlockBuilder() with bb.function("foo", (x,)): - x0 = bb.match_shape(x, (n, m)) + x0 = bb.match_cast(x, R.Tensor((n, m), "float32")) t0 = bb.emit(relax.Tuple([x, x0])) t1 = bb.emit(relax.Tuple([x, relax.ShapeExpr([n, m]), t0])) bb.emit_func_output(t1) @@ -295,11 +302,11 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) m = T.var("int64") n = T.var("int64") lv0 = R.call_tir("extern_func", gv1, (128, 128), dtype="float32") - lv1 = R.match_shape(lv0, (m, n)) + lv1 = R.match_cast(lv0, R.Tensor((m, n), "float32")) gv2 = R.call_tir("extern_func", lv0, (128, 128), dtype="float32") gv2 = R.call_tir("extern_func", gv2, (128, 128), dtype="float32") - gv3 = R.match_shape(gv2, (m, n)) - gv3 = R.match_shape(lv0, (m, n)) + gv3 = R.match_cast(gv2, R.Tensor((m, n), "float32")) + gv3 = R.match_cast(lv0, R.Tensor((m, n), "float32")) gv4 = gv3 gv5 = gv2 R.output(gv5, gv4) @@ -316,11 +323,11 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) gv1 = bb.emit(relax.call_tir("extern_func", gv0, (128, 128), dtype="float32")) with bb.dataflow(): lv0 = bb.emit(relax.call_tir("extern_func", gv1, (128, 128), dtype="float32")) - lv1 = bb.match_shape(lv0, (m, n)) + lv1 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) gv2 = bb.emit(relax.call_tir("extern_func", lv0, (128, 128), dtype="float32")) gv21 = bb.emit(relax.call_tir("extern_func", gv2, (128, 128), dtype="float32")) - gv3 = bb.match_shape(gv21, (m, n)) - gv31 = bb.match_shape(lv0, (m, n)) + gv3 = bb.match_cast(gv21, R.Tensor((m, n), "float32")) + gv31 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) gv32 = bb.emit_output(gv31) gv22 = bb.emit_output(gv21) gv4 = bb.emit(relax.call_tir("extern_func", gv22, (128, 128), dtype="float32")) @@ -736,7 +743,7 @@ def test_erase_to_well_defined(): def foo(x: R.Tensor): q = x m, n = T.var("int64"), T.var("int64") - z = R.match_shape(q, (m, n)) + z = R.match_cast(q, R.Tensor((m, n))) w = z return w diff --git a/tests/python/relax/test_vm.py b/tests/python/relax/test_vm.py index 792162a8aa..fd0f6416cc 100644 --- a/tests/python/relax/test_vm.py +++ b/tests/python/relax/test_vm.py @@ -399,7 +399,7 @@ class TestVMCompileStage2: @R.function def foo(x: R.Tensor(dtype="float32")) -> R.Shape: n, m = T.var("int64"), T.var("int64") - R.match_shape(x, (n, m)) + _ = R.match_cast(x, R.Tensor((n, m), "float32")) return (n * 2, m * 3) mod = TestVMCompileStage2 @@ -442,7 +442,7 @@ class TestVMCompileE2E: def foo(x: R.Tensor(dtype="float32")) -> R.Tensor: with R.dataflow(): n, m = T.var("int64"), T.var("int64") - R.match_shape(x, (n, m)) + _ = R.match_cast(x, R.Tensor((n, m), "float32")) y = R.call_tir("test.vm.tile", (x), (n, m * 2), dtype="float32") R.output(y) return y diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index a11673bf69..3dd9c00f95 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -422,13 +422,13 @@ def test_no_match_dtype(): assert not ty_pat.match(x) -def test_match_shape(): +def test_match_cast(): x = relay.var("x", shape=(10, 10), dtype="float32") ty_pat = has_shape((10, 10)) assert ty_pat.match(x) -def test_no_match_shape(): +def test_no_match_cast(): x = relay.var("x", shape=(10, 10), dtype="int32") ty_pat = has_shape((10, 5)) assert not ty_pat.match(x)