Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

ExprMutator refactor & Normalizer #32

Merged
merged 14 commits into from
Nov 10, 2021
1 change: 1 addition & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ class GlobalVar : public RelayExpr {
TVM_DLL explicit GlobalVar(String name_hint);

TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode);
altanh marked this conversation as resolved.
Show resolved Hide resolved
};

// PrimExprs that are useful as runtime containers.
Expand Down
61 changes: 46 additions & 15 deletions include/tvm/relax/block_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,39 +44,43 @@ class BlockBuilder;
*/
class BlockBuilderNode : public Object {
public:
BlockBuilderNode(std::shared_ptr<NameTable> name_table) : name_table_(name_table) {}
BlockBuilderNode();

~BlockBuilderNode();

BlockBuilderNode() { name_table_ = std::make_shared<NameTable>(); }

/*! \brief Begin to build a DataflowBlock. */
void BeginDataflowBlock();

YuchenJin marked this conversation as resolved.
Show resolved Hide resolved
/*! \brief Begin to build a BindingBlock. */
void BeginBindingBlock();

/*!
* \brief End building a BindingBlock.
* \return The BindingBlock being built.
*/
BindingBlock EndBlock();

/*!
* \brief Check if the block being built is DataflowBlock or not.
* \return A boolean that indicates if the block being built is DataflowBlock or not.
*/
inline bool CurrentBlockIsDataFlow() { return CurrentFrame()->is_dataflow; }

/*!
* \brief Emits an Expr, and returns the variable it is bound to.
* \param expr The Expr to be emitted.
* \param name_hint Name hint for the bound variable.
* \return The new variable that \p expr is bound to.
*/
virtual Var Emit(const Expr& expr, std::string name_hint = "");

/*!
* \brief Emits a variable binding, and returns the bound Var.
* \param binding The variable binding.
* \return The bound variable.
*/
virtual Var Emit(const VarBinding& binding);

/*!
* \brief Emit a MatchShape.
* \param value The value of the MatchShape to be emitted.
Expand All @@ -85,49 +89,57 @@ class BlockBuilderNode : public Object {
* \return The variable bound to the MatchShape.
*/
Var EmitMatchShape(const Expr& value, const Array<PrimExpr>& pattern, std::string name_hint = "");

/*!
* \brief Emit a MatchShape binding.
* \param binding The MatchShape binding to be emitted.
* \return The variable bound to the MatchShape.
*/
Var EmitMatchShape(const MatchShape& binding);

/*!
* \brief Generate an output for the current dataflow block.
* \param output The output variable of the block.
* \param name_hint Name hint for the bound variable.
* \return The variable bound to \p output.
*/
Var EmitOutput(const Expr& output, std::string name_hint = "");

/*!
* \brief Generate an output for the current dataflow block.
* \param binding The output binding to output.
* \return The variable bound to \p output.
*/
Var EmitOutput(const VarBinding& binding);

/*!
* \brief Lookup a var in the binding table \p var_map_.
* \brief Lookup a var in the binding table \p binding_table_.
* \param var The input var.
* \return The Expr bound to the input \p var.
*/
Expr LookupVar(const Var& var);
Expr LookupBinding(const Var& var);

/*!
* \brief Check if two shape expressions can be proven equal at compile time.
* \param lhs The input lhs shape.
* \param rhs The input rhs shape.
* \return Whether we can prove lhs shape is the same as the rhs shape.
*/
bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs);

/*!
* \brief Normalize an Expr to complete its shape and type.
* \param expr The input expr.
* \return The expr with normalized shape and type.
* \brief Convert an expression to A-normal form, and try to eagerly infer types and shapes.
* \param expr The input expression.
* \return The normalized expression.
*/
Expr Normalize(const Expr& expr);

/*!
* \brief Create a BlockBuilder.
* \return The created BlockBuilder.
* \brief Get the name table for generating unique names.
*
* \return The name table.
*/
TVM_DLL static BlockBuilder Create();
NameTable* name_table();

void VisitAttrs(AttrVisitor* v) {}

Expand All @@ -150,26 +162,45 @@ class BlockBuilderNode : public Object {
Array<Binding> bindings;
bool is_dataflow;
};

/*!
* \brief Utility class for performing IR normalization (conversion to ANF, eager forward shape
* and type inference).
*/
class ExprNormalizer;

friend class BlockBuilder;

/*!
* \brief Get the current block frame.
* \return The current block frame.
*/
BlockFrame* CurrentFrame();

/*! \brief A stack to store block frames. */
std::stack<BlockFrame> block_stack_;

/*! \brief A diagnostic context for reporting errors. */
DiagnosticContext diag_ctx_ = DiagnosticContext::Default(IRModule({}, {}));

/*! \brief A binding table that maps var to value. */
// TODO(@yuchen, @altanh): make var_map_ scoped, and decide if it should be in the builder
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_map_;
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> binding_table_;

/*! \brief A name table to get unique names for IR construction. */
std::shared_ptr<NameTable> name_table_;
std::unique_ptr<NameTable> name_table_;

/*! \brief The internal normalizer used for ANF conversion. */
std::unique_ptr<ExprNormalizer> normalizer_;
};

class BlockBuilder : public ObjectRef {
public:
TVM_DLL explicit BlockBuilder(std::shared_ptr<NameTable> name_table);
/*!
* \brief Create a BlockBuilder.
* \return The created BlockBuilder.
*/
TVM_DLL static BlockBuilder Create();

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode);
};

Expand Down
10 changes: 10 additions & 0 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ShapeExpr : public Expr {
public:
TVM_DLL explicit ShapeExpr(Array<PrimExpr> values, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, Expr, ShapeExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShapeExprNode);
};

/*! \brief The variable class for all Relax bindings. */
Expand Down Expand Up @@ -131,6 +132,7 @@ class Var : public Expr {
TVM_DLL explicit Var(Id vid, runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
};

/*! \brief A sub-type of the variable node used to mark dataflow variables from
Expand Down Expand Up @@ -175,6 +177,7 @@ class DataflowVar : public Var {
runtime::Optional<Type> type_annotation, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowVarNode);
};

/*! \brief The base class of a variable binding in Relax. */
Expand Down Expand Up @@ -235,6 +238,7 @@ class MatchShape : public Binding {
TVM_DLL explicit MatchShape(Expr value, Array<PrimExpr> pattern,
Var var, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchShapeNode);
};

class VarBinding;
Expand Down Expand Up @@ -266,6 +270,7 @@ class VarBinding : public Binding {
public:
TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarBindingNode);
};

class BindingBlock;
Expand Down Expand Up @@ -296,6 +301,7 @@ class BindingBlock : public ObjectRef {
public:
TVM_DLL explicit BindingBlock(Array<Binding> bindings, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BindingBlockNode);
};

class DataflowBlock;
Expand All @@ -315,6 +321,7 @@ class DataflowBlock : public BindingBlock {
public:
TVM_DLL explicit DataflowBlock(Array<Binding> bindings, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DataflowBlockNode);
};

/*! \brief A sequence of blocks followed by an expression.
Expand Down Expand Up @@ -356,6 +363,7 @@ class SeqExpr : public Expr {
public:
TVM_DLL explicit SeqExpr(Array<BindingBlock> blocks, Expr body, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqExprNode);
};

/*! \brief A Relax function, eventually to replace the current Relay function definition. */
Expand Down Expand Up @@ -411,6 +419,7 @@ class Function : public Expr {
TVM_DLL explicit Function(runtime::Optional<GlobalVar> name, Array<Var> params, Expr body,
Type ret_type, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
};

/*! \brief The extern function, which can represent packed function. */
Expand Down Expand Up @@ -440,6 +449,7 @@ class ExternFunc : public Expr {
public:
TVM_DLL ExternFunc(String global_symbol, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, Expr, ExternFuncNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ExternFuncNode);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note for future iterations, perhaps it is worth to remove COW because the shape changing pattern issue

};

} // namespace relax
Expand Down
82 changes: 43 additions & 39 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,7 @@ void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
public:
ExprMutator() {
name_table_ = std::make_shared<NameTable>();
builder_ = BlockBuilder(name_table_);
}

/*!
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
Expr Mutate(const Expr& expr) {
return this->VisitExpr(expr);
builder_ = BlockBuilder::Create();
}

Expr VisitExpr(const Expr& expr) override;
Expand Down Expand Up @@ -218,47 +209,60 @@ class ExprMutator : public ExprFunctor<Expr(const Expr&)> {
virtual void VisitVarBinding(const VarBinding& binding);
virtual void VisitMatchShape(const MatchShape& binding);

/*!
* \brief Rewrite the var definition site.
* \param var The var to be visited.
* \return The var after post-order rewritten.
altanh marked this conversation as resolved.
Show resolved Hide resolved
* \note VisitExpr_(const VarNode*) will only visit the usage site of an Var
*/
virtual Var VisitVarDef(const Var& var);

virtual BindingBlock VisitBindingBlock(const BindingBlock& block);
virtual BindingBlock VisitDataflowBlock(const DataflowBlock& block);

protected:
Expr MutateWithPrologue(const Expr& expr, bool is_dataflow);
class ExprNormalizer;

/*! \brief Look up the value of a variable. If the variable is bound, then returns the bound
* value. Otherwise, returns the rewritten expression for the variable.
/*!
* \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If.
* \param expr The expr to be visited.
* \return The expr after visiting.
*/
Expr LookupVar(Var var);
Expr VisitWithNewScope(const Expr& expr);

inline void UpdateMemo(Expr pre, Expr post) {
if (const VarNode* var = pre.as<VarNode>()) {
var_memo_[var->vid] = post;
} else {
expr_memo_[pre] = post;
}
}
/*!
* \brief Look up the value bound to a variable.
* \param var The var to be looked up.
* \return The value bound to the input \p var.
*/
Expr LookupBinding(const Var& var);

inline Optional<Expr> LookupMemo(Expr pre) {
if (pre.as<VarNode>()) {
Id vid = Downcast<Var>(pre)->vid;
if (var_memo_.count(vid)) {
return var_memo_[vid];
}
} else {
if (expr_memo_.count(pre)) {
return expr_memo_[pre];
}
}
return NullOpt;
/*!
* \brief Post-order rewrite a node and normalize.
* \param T The node type to be rewritten.
* \param op The node to be rewritten.
* \return The node after post rewritten.
*/
template <typename T>
Expr VisitExprPostOrder_(const T* op) {
return builder_->Normalize(ExprMutator::VisitExpr_(op));
}

/*! \brief Variable memoization table using Id equality */
std::unordered_map<Id, Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;

/*! \brief Expr memoization table using pointer equality */
std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> expr_memo_;
/*!
* \brief Create a new var with specified shape and type if it's original shape or type does not
* match with the specified ones.
* \param var The var to be updated.
* \param shape The specified shape.
* \param type The specified type.
* \return The var filled with \p shape and \p type.
*/
Var WithShapeAndType(Var var, Optional<ObjectRef> shape, Type type);

std::shared_ptr<NameTable> name_table_;
/*! \brief Internal block builder to emit bindings during rewriting. */
BlockBuilder builder_;

/*! \brief Remap a var to a new var in use-site. */
std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
};

// TODO(@yuchen, @altan): Refactor to enforce dataflow mutator only rewrite stuff in dataflow blocks
Expand Down
Loading