Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
Refactor ExprVisitor/Mutator to consider Expr in StructInfo.
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Dec 23, 2022
1 parent 2e8a3d7 commit 19a6ad0
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 51 deletions.
105 changes: 101 additions & 4 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@
#include <tvm/node/functor.h>
#include <tvm/relax/block_builder.h>
#include <tvm/relax/expr.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/struct_info_functor.h>
#include <tvm/relay/op.h>
#include <tvm/tir/function.h>

Expand Down Expand Up @@ -244,6 +243,23 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
* \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
*/
virtual void VisitVarDef(const Var& var);

/*!
* \brief Visit struct_info may recursively contain Expr/PrimExpr.
*
* By default, this function recurse into struct info such as
* TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr
* accordingly. It does not recurse into FunctionStructInfo as it does
* not contain Expr defined in the current scope.
*
* Pass writers can overload this function to change to other behaviors.
* For example, if we are not interested in Expr in StructInfo, we can
* override this function by a no-op.
*
* \param struct_info Input struct info field.
*/
virtual void VisitExprDepStructInfoField(const StructInfo& struct_info);

// specific leaf level visitor functions
virtual void VisitVarDef_(const VarNode* var);
virtual void VisitVarDef_(const DataflowVarNode* var);
Expand All @@ -258,6 +274,30 @@ class ExprVisitor : public ExprFunctor<void(const Expr&)> {
tvm::NodeFunctor<void(const ObjectRef& n, ExprVisitor* self, const VarBindingNode* binding)>;
// initialize the vtable.
static VisitBindingVTable InitVisitBindingVTable();
/*!
* \brief Private internal struct info field visitor.
*
* Support default visiting of struct info field and recursive into
* their Expr fields.
*
* We use component instead of sub-classing so there can be other
* joint inheritance between ExprVisitor and StructInfoVisitor.
*/
class DefaultStructInfoFieldVisitor : public StructInfoVisitor {
public:
explicit DefaultStructInfoFieldVisitor(ExprVisitor* parent);

// Override defaults in struct info visitor.
void VisitStructInfoExprField(const Expr& expr) final;
void VisitStructInfoExprField(const PrimExpr& expr) final;
void VisitStructInfo_(const FuncStructInfoNode* op) final;

private:
ExprVisitor* parent_;
};
// This visitor is not visible to child classes and only
// used to supportd default visiting behavior.
DefaultStructInfoFieldVisitor default_struct_info_field_visitor_{this};
};

void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
Expand Down Expand Up @@ -309,6 +349,64 @@ class ExprMutatorBase : public ExprFunctor<Expr(const Expr&)> {
* Can be overloaded to transform the shape expressions.
*/
virtual PrimExpr VisitPrimExpr(const PrimExpr& expr);

/*!
* \brief Visit struct_info that may recursively contain Expr/PrimExpr.
*
* By default, this function recurse into struct info such as
* TensorStructInfo and ShapeStructInfo and call VisitExpr/VisitPrimExpr
* accordingly. It does not recurse into FunctionStructInfo as it does
* not contain Expr defined in the current scope.
*
* Pass writers can overload this function to change to other behaviors.
* For example, if in Expr in StructInfo won't change, we can
* override this function by an identity function.
*
* \param struct_info Input struct info field.
* \return The updated struct info.
*/
virtual StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info);

protected:
/*!
* \brief Check whether VisitExprDepStructInfoField change struct_info.
* \return Whether struct info changed.
* \note This function is used by mutator implementations to check if
* previous Expr update will trigger a change in struct_info.
* If change is detected, the implementation can generate a fresh
* node without struct_info, and trigger normalizer to re-derive.
*/
bool VisitAndCheckStructInfoFieldUnchanged(const ObjectRef& struct_info) {
if (const StructInfoNode* sinfo = struct_info.as<StructInfoNode>()) {
return this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)).same_as(struct_info);
} else {
return true;
}
}

private:
/*!
* \brief Private internal struct info field visitor to support
* Default visiting of struct info field and recursive into their Expr fields.
*
* We use component instead of sub-classing so there can be other
* joint inheritance between ExprMutator and StructInfoMutator.
*/
class DefaultStructInfoFieldMutator : public StructInfoMutator {
public:
explicit DefaultStructInfoFieldMutator(ExprMutatorBase* parent);

// Override defaults in struct info visitor.
Expr VisitStructInfoExprField(const Expr& expr) final;
PrimExpr VisitStructInfoExprField(const PrimExpr& expr) final;
StructInfo VisitStructInfo_(const FuncStructInfoNode* op) final;

private:
ExprMutatorBase* parent_;
};
// This visitor is not visible to child classes and only
// used to supportd default visiting behavior.
DefaultStructInfoFieldMutator default_struct_info_field_mutator_{this};
};

/*!
Expand All @@ -324,7 +422,6 @@ class ExprMutator : public ExprMutatorBase {

ExprMutator(Optional<IRModule> mod = NullOpt) { builder_ = BlockBuilder::Create(mod); }
Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const TupleNode* op) override;
Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const DataflowVarNode* op) override;
Expr VisitExpr_(const FunctionNode* op) override;
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
Var,
ShapeExpr,
GlobalVar,
PrimExpr,
BindingBlock,
Tuple,
BaseFunc,
Expand Down
Loading

0 comments on commit 19a6ad0

Please sign in to comment.