Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

[TVMScript] B5-6: TIR IRBuilder #231

Merged
merged 6 commits into from
Aug 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 190 additions & 0 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,17 @@ class PrimFuncFrame : public TIRFrame {
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
};

/*!
* \brief A frame that represents the assert statement. Proceeds if the condition is true,
* otherwise aborts with the message.
*
* \sa AssertFrame
*/
class AssertFrameNode : public TIRFrameNode {
public:
/*! \brief The PrimExpr to test. */
PrimExpr condition;
/*! \brief The output error message when the assertion failed. */
PrimExpr message;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand All @@ -186,17 +194,33 @@ class AssertFrameNode : public TIRFrameNode {
TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to AssertFrameNode.
*
* \sa AssertFrameNode
*/
class AssertFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode);
};

/*!
* \brief A frame represents the let binding expression, which binds a var.
*
* \sa LetFrameNode
*/
class LetFrameNode : public TIRFrameNode {
public:
/*! \brief The variable we bind to */
tvm::tir::Var var;
/*! \brief The value we bind var to */
PrimExpr value;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand All @@ -209,21 +233,41 @@ class LetFrameNode : public TIRFrameNode {
TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to LetFrameNode.
*
* \sa LetFrameNode
*/
class LetFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode);
};

/*!
* \brief A frame represents the allocate.
*
* \sa AllocateFrame
*/
class AllocateFrameNode : public TIRFrameNode {
public:
/*! \brief The extents of the allocate. */
Array<PrimExpr> extents;
/*! \brief The data type of the buffer. */
DataType dtype;
/*! \brief The storage scope. */
String storage_scope;
/*! \brief The condition. */
PrimExpr condition;
/*! \brief Additional annotation hints. */
Map<String, ObjectRef> annotations;
/*! \brief The buffer. */
tvm::tir::Buffer buffer;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand All @@ -240,20 +284,39 @@ class AllocateFrameNode : public TIRFrameNode {
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to AllocateFrameNode.
*
* \sa AllocateFrameNode
*/
class AllocateFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode);
};

/*!
* \brief A frame represents the allocate constant.
*
* \sa AllocateConstFrame
*/
class AllocateConstFrameNode : public TIRFrameNode {
public:
/*! \brief The data type of the buffer. */
DataType dtype;
/*! \brief The extents of the allocate. */
Array<PrimExpr> extents;
/*! \brief The data associated with the constant. */
tvm::runtime::NDArray data;
/*! \brief The buffer */
tvm::tir::Buffer buffer;
/*! \brief Additional annotations about the allocation. */
Map<String, ObjectRef> annotations;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand All @@ -269,19 +332,35 @@ class AllocateConstFrameNode : public TIRFrameNode {
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to AllocateConstFrameNode.
*
* \sa AllocateConstFrameNode
*/
class AllocateConstFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame,
AllocateConstFrameNode);
};

/*!
* \brief The LaunchThreadFrameNode.
* \note It is used only inside a PrimFunc.
*/
class LaunchThreadFrameNode : public TIRFrameNode {
public:
/*! \brief The extent of environment thread. */
PrimExpr extent;
/*! \brief The attribute key, could be either virtual_thread or thread_extent. */
String attr_key;
/*! \brief The iteration variable. */
tvm::tir::IterVar iter_var;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand All @@ -295,19 +374,36 @@ class LaunchThreadFrameNode : public TIRFrameNode {
TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to LaunchThreadFrameNode.
*
* \sa LaunchThreadFrameNode
*/
class LaunchThreadFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame,
LaunchThreadFrameNode);
};

/*!
* \brief A frame that represents realization.
*
* \sa RealizeFrame
*/
class RealizeFrameNode : public TIRFrameNode {
public:
/*! \brief The region of buffer access. */
tvm::tir::BufferRegion buffer_slice;
/*! \brief The storage scope associated with this realization. */
String storage_scope;
/*! \brief The condition expression. */
PrimExpr condition;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand All @@ -321,18 +417,35 @@ class RealizeFrameNode : public TIRFrameNode {
TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to RealizeFrameNode.
*
* \sa RealizeFrameNode
*/
class RealizeFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
};

/*!
* \brief A frame that represents attribute node.
*
* \sa AttrFrame
*/
class AttrFrameNode : public TIRFrameNode {
public:
/*! \brief The node to annotate the attribute. */
ObjectRef node;
/*! \brief Attribute type key. */
String attr_key;
/*! \brief The value of the attribute. */
PrimExpr value;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand All @@ -346,16 +459,31 @@ class AttrFrameNode : public TIRFrameNode {
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to AttrFrameNode.
*
* \sa AttrFrameNode
*/
class AttrFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode);
};

/*!
* \brief A frame that represents while loop.
*
* \sa WhileFrame
*/
class WhileFrameNode : public TIRFrameNode {
public:
/*! \brief The termination condition of while. */
PrimExpr condition;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand All @@ -367,18 +495,35 @@ class WhileFrameNode : public TIRFrameNode {
TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to WhileFrameNode.
*
* \sa WhileFrameNode
*/
class WhileFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode);
};

/*!
* \brief A frame that represents if statement.
*
* \sa IfFrame
*/
class IfFrameNode : public TIRFrameNode {
public:
/*! \brief The condition of the if statement. */
PrimExpr condition;
/*! \brief The statements in the true branch. */
Optional<Array<tvm::tir::Stmt>> then_stmts;
/*! \brief The stetements in the false branch. */
Optional<Array<tvm::tir::Stmt>> else_stmts;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand All @@ -392,39 +537,84 @@ class IfFrameNode : public TIRFrameNode {
TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to IfFrameNode.
*
* \sa IfFrameNode
*/
class IfFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode);
};

/*!
* \brief A frame that represents then.
*
* \sa ThenFrame
*/
class ThenFrameNode : public TIRFrameNode {
public:
static constexpr const char* _type_key = "script.ir_builder.tir.ThenFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when entering RAII scope.
* \sa tvm::support::With
*/
void EnterWithScope() final;
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to ThenFrameNode.
*
* \sa ThenFrameNode
*/
class ThenFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode);
};

/*!
* \brief A frame that represents else.
*
* \sa ElseFrame
*/
class ElseFrameNode : public TIRFrameNode {
public:
static constexpr const char* _type_key = "script.ir_builder.tir.ElseFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when entering RAII scope.
* \sa tvm::support::With
*/
void EnterWithScope() final;
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to ElseFrameNode.
*
* \sa ElseFrameNode
*/
class ElseFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode);
Expand Down
Loading