Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

[REFACTOR] StructInfo M3: MatchShape=>MatchCast #323

Merged
merged 5 commits into from
Dec 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions include/tvm/relax/block_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,13 @@ class BlockBuilderNode : public Object {
virtual Var Emit(Expr expr, String name_hint = "") = 0;

/*!
* \brief Emit a MatchShape.
* \param value The value of the MatchShape to be emitted.
* \param pattern The pattern of the MatchShape to be emitted.
* \brief Emit a MatchCast.
* \param value The input value.
* \param struct_info The struct info to be matched.
* \param name_hint Name hint for the bound variable.
* \return The variable bound to the MatchShape.
* \return The variable bound to the MatchCast.
*/
virtual Var EmitMatchShape(Expr value, Array<PrimExpr> pattern, String name_hint = "") = 0;
virtual Var EmitMatchCast(Expr value, StructInfo struct_info, String name_hint = "") = 0;

/*!
* \brief Generate an output for the current dataflow block.
Expand Down
58 changes: 32 additions & 26 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,12 +531,10 @@ class Constant : public Expr {
/*! \brief The base class of a variable binding in Relax. */
class BindingNode : public Object {
public:
/*! \brief The return variable to bound to. */
Var var;
mutable Span span;

void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
bool SEqualReduce(const BindingNode* other, SEqualReducer equal) const { return true; }
void SHashReduce(SHashReducer hash_reduce) const {}

static constexpr const char* _type_key = "relax.expr.Binding";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand All @@ -555,51 +553,61 @@ class Binding : public ObjectRef {
using ContainerType = BindingNode;
};

/*! \brief Symbolic shape match, binds the variable of the lhs with the rhs. */
class MatchShape;
class MatchShapeNode : public BindingNode {
/*!
* \brief Runtime-match the value to the struct info.
*
* This operation does runtime check, populates the un-defined symbolic shape vars
* and vars in struct_info in first occurance, and insert equality assertions in
* other cases.
*/
class MatchCastNode : public BindingNode {
public:
/*! \brief The input value to match cast. */
Expr value;
Array<PrimExpr> pattern;
Var var;
/*! \brief The struct info pattern to match to. */
StructInfo struct_info;

void VisitAttrs(AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("pattern", &pattern);
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("struct_info", &struct_info);
v->Visit("span", &span);
}

bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const {
bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const {
// NOTE: pattern can contain ShapeExpr which defines the vars
return equal(value, other->value) && equal.DefEqual(pattern, other->pattern) &&
equal.DefEqual(var, other->var);
return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) &&
equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const {
// NOTE: pattern can contain ShapeExpr which defines the vars
hash_reduce(value);
hash_reduce.DefHash(pattern);
hash_reduce.DefHash(var);
hash_reduce.DefHash(struct_info);
hash_reduce(value);
}

static constexpr const char* _type_key = "relax.expr.MatchShape";
static constexpr const char* _type_key = "relax.expr.MatchCast";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(MatchShapeNode, BindingNode);
TVM_DECLARE_FINAL_OBJECT_INFO(MatchCastNode, BindingNode);
};

class MatchShape : public Binding {
/*!
* \brief Managed reference to MatchCastNode.
* \sa MatchCastNode
*/
class MatchCast : public Binding {
public:
TVM_DLL explicit MatchShape(Expr value, Array<PrimExpr> pattern, Var var, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchShapeNode);
TVM_DLL explicit MatchCast(Var var, Expr value, StructInfo struct_info, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(MatchCast, Binding, MatchCastNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchCastNode);
};

class VarBinding;
class VarBindingNode : public BindingNode {
public:
Var var;
/*! \brief The binding value. */
Expr value;

void VisitAttrs(AttrVisitor* v) {
Expand Down Expand Up @@ -628,8 +636,6 @@ class VarBinding : public Binding {
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode);
};

class BindingBlock;

class BindingBlockNode : public Object {
public:
mutable Span span;
Expand Down
139 changes: 118 additions & 21 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@
#include <tvm/node/functor.h>
#include <tvm/relax/block_builder.h>
#include <tvm/relax/expr.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/struct_info_functor.h>
#include <tvm/relay/op.h>
#include <tvm/tir/function.h>

Expand Down Expand Up @@ -213,7 +212,7 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
virtual void VisitBinding(const Binding& binding);
// specific leaf level visitor functions
virtual void VisitBinding_(const VarBindingNode* binding);
virtual void VisitBinding_(const MatchShapeNode* binding);
virtual void VisitBinding_(const MatchCastNode* binding);
// second level dispatching based on binding value type.
// these dispatching functions get called from first-level dispatch on VarBinding
virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val);
Expand Down Expand Up @@ -244,6 +243,23 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
* \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
*/
virtual void VisitVarDef(const Var& var);

/*!
* \brief Visit struct_info may recursively contain Expr/PrimExpr.
*
* By default, this function recurse into struct info such as
* TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr
* accordingly. It does not recurse into FunctionStructInfo as it does
* not contain Expr defined in the current scope.
*
* Pass writers can overload this function to change to other behaviors.
* For example, if we are not interested in Expr in StructInfo, we can
* override this function by a no-op.
*
* \param struct_info Input struct info field.
*/
virtual void VisitExprDepStructInfoField(const StructInfo& struct_info);

// specific leaf level visitor functions
virtual void VisitVarDef_(const VarNode* var);
virtual void VisitVarDef_(const DataflowVarNode* var);
Expand All @@ -258,6 +274,30 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
tvm::NodeFunctor<void(const ObjectRef& n, ExprVisitor* self, const VarBindingNode* binding)>;
// initialize the vtable.
static VisitBindingVTable InitVisitBindingVTable();
/*!
* \brief Private internal struct info field visitor.
*
* Support default visiting of struct info field and recursive into
* their Expr fields.
*
* We use component instead of sub-classing so there can be other
* joint inheritance between ExprVisitor and StructInfoVisitor.
*/
class DefaultStructInfoFieldVisitor : public StructInfoVisitor {
public:
explicit DefaultStructInfoFieldVisitor(ExprVisitor* parent);

// Override defaults in struct info visitor.
void VisitStructInfoExprField(const Expr& expr) final;
void VisitStructInfoExprField(const PrimExpr& expr) final;
void VisitStructInfo_(const FuncStructInfoNode* op) final;

private:
ExprVisitor* parent_;
};
// This visitor is not visible to child classes and only
// used to supportd default visiting behavior.
DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this};
};

void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
Expand Down Expand Up @@ -309,6 +349,64 @@ class ExprMutatorBase : public ExprFunctor<Expr(const Expr&)> {
* Can be overloaded to transform the shape expressions.
*/
virtual PrimExpr VisitPrimExpr(const PrimExpr& expr);

/*!
* \brief Visit struct_info that may recursively contain Expr/PrimExpr.
*
* By default, this function recurse into struct info such as
* TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr
* accordingly. It does not recurse into FunctionStructInfo as it does
* not contain Expr defined in the current scope.
*
* Pass writers can overload this function to change to other behaviors.
* For example, if in Expr in StructInfo won't change, we can
* override this function by an identity function.
*
* \param struct_info Input struct info field.
* \return The updated struct info.
*/
virtual StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info);

protected:
/*!
* \brief Check whether VisitExprDepStructInfoField change struct_info.
* \return Whether struct info changed.
* \note This function is used by mutator implementations to check if
* previous Expr update will trigger a change in struct_info.
* If change is detected, the implementation can generate a fresh
* node without struct_info, and trigger normalizer to re-derive.
*/
bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef& struct_info) {
if (const StructInfoNode* sinfo = struct_info.as<StructInfoNode>()) {
return this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)).same_as(struct_info);
} else {
return true;
}
}

private:
/*!
* \brief Private internal struct info field visitor to support
* Default visiting of struct info field and recursive into their Expr fields.
*
* We use component instead of sub-classing so there can be other
* joint inheritance between ExprMutator and StructInfoMutator.
*/
class DefaultStructInfoFieldMutator : public StructInfoMutator {
public:
explicit DefaultStructInfoFieldMutator(ExprMutatorBase* parent);

// Override defaults in struct info visitor.
Expr VisitStructInfoExprField(const Expr& expr) final;
PrimExpr VisitStructInfoExprField(const PrimExpr& expr) final;
StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final;

private:
ExprMutatorBase* parent_;
};
// This visitor is not visible to child classes and only
// used to supportd default visiting behavior.
DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this};
};

/*!
Expand All @@ -324,7 +422,6 @@ class ExprMutator : public ExprMutatorBase {

ExprMutator(Optional<IRModule> mod = NullOpt) { builder_ = BlockBuilder::Create(mod); }
Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const TupleNode* op) override;
Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const DataflowVarNode* op) override;
Expr VisitExpr_(const FunctionNode* op) override;
Expand All @@ -338,7 +435,7 @@ class ExprMutator : public ExprMutatorBase {
virtual void VisitBinding(const Binding& binding);
// specific leaf level visitor functions
virtual void VisitBinding_(const VarBindingNode* binding);
virtual void VisitBinding_(const MatchShapeNode* binding);
virtual void VisitBinding_(const MatchCastNode* binding);
// second level dispatching based on binding value type.
// these dispatching functions get called from first-level dispatch on VarBinding
virtual void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val);
Expand Down Expand Up @@ -484,9 +581,9 @@ class PyExprVisitorNode : public Object, public ExprVisitor {
/*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)`
* function. */
PackedFunc f_visit_var_binding_{nullptr};
/*! \brief The packed function to the `VisitBinding_(const MatchShapeNode* binding)`
/*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)`
* function. */
PackedFunc f_visit_match_shape_{nullptr};
PackedFunc f_visit_match_cast_{nullptr};
/*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)`
* function. */
PackedFunc f_visit_binding_block{nullptr};
Expand Down Expand Up @@ -523,8 +620,8 @@ class PyExprVisitorNode : public Object, public ExprVisitor {
void VisitBinding_(const VarBindingNode* binding)
PY_EXPR_VISITOR_DEFAULT(GetRef<VarBinding>(binding), f_visit_var_binding_,
ExprVisitor::VisitBinding_(binding));
void VisitBinding_(const MatchShapeNode* binding)
PY_EXPR_VISITOR_DEFAULT(GetRef<MatchShape>(binding), f_visit_match_shape_,
void VisitBinding_(const MatchCastNode* binding)
PY_EXPR_VISITOR_DEFAULT(GetRef<MatchCast>(binding), f_visit_match_cast_,
ExprVisitor::VisitBinding_(binding));

void VisitBindingBlock(const BindingBlock& block)
Expand Down Expand Up @@ -602,7 +699,7 @@ class PyExprVisitor : public ObjectRef {
* \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`.
* \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode*
* binding)`.
* \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode*
* \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode*
* binding)`.
* \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock&
* block)`.
Expand All @@ -624,7 +721,7 @@ class PyExprVisitor : public ObjectRef {
PackedFunc f_visit_extern_func_, PackedFunc f_visit_global_var_, PackedFunc f_visit_function_,
PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_,
PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding,
PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_,
PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_,
PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_,
PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_,
PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) {
Expand All @@ -649,7 +746,7 @@ class PyExprVisitor : public ObjectRef {
n->f_visit_op_ = f_visit_op_;
n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_;
n->f_visit_var_binding_ = f_visit_var_binding_;
n->f_visit_match_shape_ = f_visit_match_shape_;
n->f_visit_match_cast_ = f_visit_match_cast_;
n->f_visit_binding_block_ = f_visit_binding_block_;
n->f_visit_dataflow_block_ = f_visit_dataflow_block_;
n->f_visit_var_def_ = f_visit_var_def_;
Expand Down Expand Up @@ -702,9 +799,9 @@ class PyExprMutatorNode : public Object, public ExprMutator {
/*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)`
* function. */
PackedFunc f_visit_var_binding_{nullptr};
/*! \brief The packed function to the `VisitBinding_(const MatchShapeNode* binding)`
/*! \brief The packed function to the `VisitBinding_(const MatchCastNode* binding)`
* function. */
PackedFunc f_visit_match_shape_{nullptr};
PackedFunc f_visit_match_cast_{nullptr};
/*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)`
* function. */
PackedFunc f_visit_binding_block{nullptr};
Expand Down Expand Up @@ -748,9 +845,9 @@ class PyExprMutatorNode : public Object, public ExprMutator {
ExprMutator::VisitBinding_(binding);
}

void VisitBinding_(const MatchShapeNode* binding) {
if (f_visit_match_shape_ != nullptr)
f_visit_match_shape_(GetRef<MatchShape>(binding));
void VisitBinding_(const MatchCastNode* binding) {
if (f_visit_match_cast_ != nullptr)
f_visit_match_cast_(GetRef<MatchCast>(binding));
else
ExprMutator::VisitBinding_(binding);
}
Expand Down Expand Up @@ -866,7 +963,7 @@ class PyExprMutator : public ObjectRef {
* \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`.
* \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode*
* binding)`.
* \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode*
* \param f_visit_match_cast_ The packed function of `VisitBinding_(const MatchCastNode*
* binding)`.
* \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock&
* block)`.
Expand All @@ -889,7 +986,7 @@ class PyExprMutator : public ObjectRef {
PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_,
PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_,
PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding,
PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_,
PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_cast_,
PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_,
PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_,
PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) {
Expand All @@ -911,7 +1008,7 @@ class PyExprMutator : public ObjectRef {
n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_;
n->f_visit_binding = f_visit_binding;
n->f_visit_var_binding_ = f_visit_var_binding_;
n->f_visit_match_shape_ = f_visit_match_shape_;
n->f_visit_match_cast_ = f_visit_match_cast_;
n->f_visit_binding_block = f_visit_binding_block;
n->f_visit_binding_block_ = f_visit_binding_block_;
n->f_visit_dataflow_block_ = f_visit_dataflow_block_;
Expand Down
Loading