diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index ffb225c512cd..1c470fae51ee 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -45,6 +45,12 @@ using tvm::runtime::String; */ class BaseExprNode : public Object { public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + static constexpr const char* _type_key = "BaseExpr"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -135,11 +141,6 @@ class PrimExpr : public BaseExpr { */ class RelayExprNode : public BaseExprNode { public: - /*! - * \brief Span that points to the original source code. - * Reserved debug information. - */ - mutable Span span; /*! * \brief Stores the result of type inference(type checking). * @@ -263,8 +264,9 @@ class IntImm : public PrimExpr { * \brief Constructor. * \param dtype The data type of the value. * \param value The internal value. + * \param span The location of this object in the source code. */ - TVM_DLL IntImm(DataType dtype, int64_t value); + TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode); }; @@ -307,8 +309,9 @@ class FloatImm : public PrimExpr { * \brief Constructor. * \param dtype The data type of the value. * \param value The internal value. + * \param span The location in the source code. */ - TVM_DLL FloatImm(DataType dtype, double value); + TVM_DLL FloatImm(DataType dtype, double value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode); }; @@ -321,7 +324,7 @@ class FloatImm : public PrimExpr { */ class Bool : public IntImm { public: - explicit Bool(bool value) : IntImm(DataType::Bool(), value) {} + explicit Bool(bool value, Span span = Span()) : IntImm(DataType::Bool(), value, span) {} Bool operator!() const { return Bool((*this)->value == 0); } operator bool() const { return (*this)->value != 0; } @@ -358,7 +361,7 @@ class Integer : public IntImm { /*! * \brief Construct integer from int value. */ - Integer(int value) : IntImm(DataType::Int(32), value) {} // NOLINT(*) + Integer(int value, Span span = Span()) : IntImm(DataType::Int(32), value, span) {} // NOLINT(*) /*! * \brief Construct integer from int imm. * \param other The other value. diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index e150ff38041b..69741bbdca62 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -77,6 +77,11 @@ class BufferNode : public Object { int offset_factor; /*! \brief buffer type */ BufferType buffer_type; + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; /*! \brief constructor */ BufferNode() {} @@ -135,7 +140,7 @@ class Buffer : public ObjectRef { // A default value will be picked. TVM_DLL Buffer(Var ptr, DataType dtype, Array shape, Array strides, PrimExpr elem_offset, String name, String scope, int data_alignment, - int offset_factor, BufferType buffer_type); + int offset_factor, BufferType buffer_type, Span span = Span()); /*! * \brief Return a new buffer that is equivalent with current one @@ -183,11 +188,12 @@ class Buffer : public ObjectRef { * \param shape The shape of the buffer, * \param dtype The content data type. * \param name The name of the buffer + * \param span The location of this object in the source code. * \return The created buffer. * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer"); + String name = "buffer", Span span = Span()); /*! * \brief Base node for data producers. diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index eee0deecdc70..f2ae58554ab1 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -74,7 +74,7 @@ class StringImmNode : public PrimExprNode { */ class StringImm : public PrimExpr { public: - TVM_DLL StringImm(String value); + TVM_DLL StringImm(String value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode); }; @@ -111,7 +111,7 @@ class CastNode : public PrimExprNode { */ class Cast : public PrimExpr { public: - TVM_DLL Cast(DataType dtype, PrimExpr value); + TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode); }; @@ -158,7 +158,7 @@ class AddNode : public BinaryOpNode { */ class Add : public PrimExpr { public: - TVM_DLL Add(PrimExpr a, PrimExpr b); + TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode); }; @@ -174,7 +174,7 @@ class SubNode : public BinaryOpNode { */ class Sub : public PrimExpr { public: - TVM_DLL Sub(PrimExpr a, PrimExpr b); + TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode); }; @@ -190,7 +190,7 @@ class MulNode : public BinaryOpNode { */ class Mul : public PrimExpr { public: - TVM_DLL Mul(PrimExpr a, PrimExpr b); + TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode); }; @@ -209,7 +209,7 @@ class DivNode : public BinaryOpNode { */ class Div : public PrimExpr { public: - TVM_DLL Div(PrimExpr a, PrimExpr b); + TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode); }; @@ -228,7 +228,7 @@ class ModNode : public BinaryOpNode { */ class Mod : public PrimExpr { public: - TVM_DLL Mod(PrimExpr a, PrimExpr b); + TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode); }; @@ -244,7 +244,7 @@ class FloorDivNode : public BinaryOpNode { */ class FloorDiv : public PrimExpr { public: - TVM_DLL FloorDiv(PrimExpr a, PrimExpr b); + TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode); }; @@ -260,7 +260,7 @@ class FloorModNode : public BinaryOpNode { */ class FloorMod : public PrimExpr { public: - TVM_DLL FloorMod(PrimExpr a, PrimExpr b); + TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode); }; @@ -276,7 +276,7 @@ class MinNode : public BinaryOpNode { */ class Min : public PrimExpr { public: - TVM_DLL Min(PrimExpr a, PrimExpr b); + TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode); }; @@ -292,7 +292,7 @@ class MaxNode : public BinaryOpNode { */ class Max : public PrimExpr { public: - TVM_DLL Max(PrimExpr a, PrimExpr b); + TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode); }; @@ -339,7 +339,7 @@ class EQNode : public CmpOpNode { */ class EQ : public PrimExpr { public: - TVM_DLL EQ(PrimExpr a, PrimExpr b); + TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode); }; @@ -355,7 +355,7 @@ class NENode : public CmpOpNode { */ class NE : public PrimExpr { public: - TVM_DLL NE(PrimExpr a, PrimExpr b); + TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode); }; @@ -371,7 +371,7 @@ class LTNode : public CmpOpNode { */ class LT : public PrimExpr { public: - TVM_DLL LT(PrimExpr a, PrimExpr b); + TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode); }; @@ -387,7 +387,7 @@ struct LENode : public CmpOpNode { */ class LE : public PrimExpr { public: - TVM_DLL LE(PrimExpr a, PrimExpr b); + TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode); }; @@ -403,7 +403,7 @@ class GTNode : public CmpOpNode { */ class GT : public PrimExpr { public: - TVM_DLL GT(PrimExpr a, PrimExpr b); + TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode); }; @@ -419,7 +419,7 @@ class GENode : public CmpOpNode { */ class GE : public PrimExpr { public: - TVM_DLL GE(PrimExpr a, PrimExpr b); + TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode); }; @@ -457,7 +457,7 @@ class AndNode : public PrimExprNode { */ class And : public PrimExpr { public: - TVM_DLL And(PrimExpr a, PrimExpr b); + TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode); }; @@ -495,7 +495,7 @@ class OrNode : public PrimExprNode { */ class Or : public PrimExpr { public: - TVM_DLL Or(PrimExpr a, PrimExpr b); + TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode); }; @@ -529,7 +529,7 @@ class NotNode : public PrimExprNode { */ class Not : public PrimExpr { public: - TVM_DLL Not(PrimExpr a); + TVM_DLL Not(PrimExpr a, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode); }; @@ -578,7 +578,7 @@ class SelectNode : public PrimExprNode { */ class Select : public PrimExpr { public: - TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value); + TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode); }; @@ -627,7 +627,7 @@ class BufferLoadNode : public PrimExprNode { */ class BufferLoad : public PrimExpr { public: - TVM_DLL explicit BufferLoad(Buffer buffer, Array indices); + TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); }; @@ -674,7 +674,7 @@ class ProducerLoadNode : public PrimExprNode { */ class ProducerLoad : public PrimExpr { public: - TVM_DLL explicit ProducerLoad(DataProducer producer, Array indices); + TVM_DLL explicit ProducerLoad(DataProducer producer, Array indices, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode); }; @@ -732,7 +732,8 @@ class LoadNode : public PrimExprNode { */ class Load : public PrimExpr { public: - TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate); + TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode); }; @@ -783,7 +784,7 @@ class RampNode : public PrimExprNode { */ class Ramp : public PrimExpr { public: - TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes); + TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode); }; @@ -821,7 +822,7 @@ class BroadcastNode : public PrimExprNode { */ class Broadcast : public PrimExpr { public: - TVM_DLL Broadcast(PrimExpr value, int lanes); + TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode); }; @@ -866,7 +867,7 @@ class LetNode : public PrimExprNode { */ class Let : public PrimExpr { public: - TVM_DLL Let(Var var, PrimExpr value, PrimExpr body); + TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode); }; @@ -911,7 +912,7 @@ class CallNode : public PrimExprNode { */ class Call : public PrimExpr { public: - TVM_DLL Call(DataType dtype, RelayExpr op, Array args); + TVM_DLL Call(DataType dtype, RelayExpr op, Array args, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode); }; @@ -953,9 +954,9 @@ class ShuffleNode : public PrimExprNode { */ class Shuffle : public PrimExpr { public: - TVM_DLL Shuffle(Array vectors, Array indices); - TVM_DLL static PrimExpr Concat(Array vectors); - TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index); + TVM_DLL Shuffle(Array vectors, Array indices, Span span = Span()); + TVM_DLL static PrimExpr Concat(Array vectors, Span span = Span()); + TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode); }; @@ -981,6 +982,11 @@ class CommReducerNode : public Object { Array identity_element; /*! \brief Function call operator to combine a and b */ Array operator()(Array a, Array b) const; + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; void VisitAttrs(AttrVisitor* v) { v->Visit("lhs", &lhs); @@ -1014,7 +1020,7 @@ class CommReducerNode : public Object { class CommReducer : public ObjectRef { public: TVM_DLL CommReducer(Array lhs, Array rhs, Array result, - Array identity_element); + Array identity_element, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(CommReducer, ObjectRef, CommReducerNode); }; @@ -1077,7 +1083,7 @@ class ReduceNode : public PrimExprNode { class Reduce : public PrimExpr { public: TVM_DLL Reduce(CommReducer combiner, Array src, Array rdom, PrimExpr condition, - int value_index, Array init); + int value_index, Array init, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode); }; @@ -1106,7 +1112,7 @@ class AnyNode : public PrimExprNode { */ class Any : public PrimExpr { public: - TVM_DLL Any(); + TVM_DLL Any(Span span = Span()); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode); }; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index caddd99eeb2c..64dbb5cf8ec3 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -140,10 +140,11 @@ class PrimFunc : public BaseFunc { * \param ret_type The return type of the function. * \param buffer_map The buffer map for parameter buffer unpacking. * \param attrs Additional function attributes. + * \param span The location of this object in the source code. */ TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), Map buffer_map = Map(), - DictAttrs attrs = NullValue()); + DictAttrs attrs = NullValue(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode); diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 16800d57bda8..661c30110062 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -37,6 +37,15 @@ namespace tir { /*! \brief Base node of all statements. */ class StmtNode : public Object { public: + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; + + StmtNode() = default; + explicit StmtNode(Span span) : span(span) {} + static constexpr const char* _type_key = "tir.Stmt"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -89,7 +98,7 @@ class LetStmtNode : public StmtNode { */ class LetStmt : public Stmt { public: - TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body); + TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode); }; @@ -144,7 +153,7 @@ class AttrStmtNode : public StmtNode { */ class AttrStmt : public Stmt { public: - TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body); + TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode); }; @@ -191,7 +200,7 @@ class AssertStmtNode : public StmtNode { */ class AssertStmt : public Stmt { public: - TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body); + TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode); }; @@ -254,7 +263,8 @@ class StoreNode : public StmtNode { */ class Store : public Stmt { public: - TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate); + TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode); }; @@ -305,7 +315,8 @@ class BufferStoreNode : public StmtNode { */ class BufferStore : public Stmt { public: - TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices); + TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices, + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); }; @@ -352,8 +363,9 @@ class BufferRealizeNode : public StmtNode { } BufferRealizeNode() = default; - BufferRealizeNode(Buffer buffer, Array bounds, PrimExpr condition, Stmt body) - : buffer(buffer), bounds(bounds), condition(condition), body(body) {} + BufferRealizeNode(Buffer buffer, Array bounds, PrimExpr condition, Stmt body, + Span span = Span()) + : StmtNode(span), buffer(buffer), bounds(bounds), condition(condition), body(body) {} static constexpr const char* _type_key = "tir.BufferRealize"; TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode); @@ -365,7 +377,8 @@ class BufferRealizeNode : public StmtNode { */ class BufferRealize : public Stmt { public: - TVM_DLL explicit BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body); + TVM_DLL explicit BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body, + Span span = Span()); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode); }; @@ -416,7 +429,8 @@ class ProducerStoreNode : public StmtNode { */ class ProducerStore : public Stmt { public: - TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array indices); + TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array indices, + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode); }; @@ -472,7 +486,8 @@ class ProducerRealizeNode : public StmtNode { */ class ProducerRealize : public Stmt { public: - TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body); + TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode); }; @@ -540,7 +555,7 @@ class AllocateNode : public StmtNode { class Allocate : public Stmt { public: TVM_DLL Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, - Stmt body); + Stmt body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode); }; @@ -579,8 +594,9 @@ class SeqStmt : public Stmt { /*! * \brief Construct SeqStmt. * \param seq The sequence. + * \param span The location of this object in the source code. */ - TVM_DLL explicit SeqStmt(Array seq); + TVM_DLL explicit SeqStmt(Array seq, Span span = Span()); /*! \return get the size of the sequence */ size_t size() const { return operator->()->size(); } @@ -678,7 +694,8 @@ class IfThenElseNode : public StmtNode { */ class IfThenElse : public Stmt { public: - TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt()); + TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt(), + Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode); }; @@ -712,9 +729,9 @@ class EvaluateNode : public StmtNode { */ class Evaluate : public Stmt { public: - TVM_DLL explicit Evaluate(PrimExpr value); + TVM_DLL explicit Evaluate(PrimExpr value, Span span = Span()); - explicit Evaluate(int value) : Evaluate(PrimExpr(value)) {} + explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {} TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode); }; @@ -799,7 +816,7 @@ class ForNode : public StmtNode { class For : public Stmt { public: TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api, - Stmt body); + Stmt body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); }; @@ -829,7 +846,8 @@ class PrefetchNode : public StmtNode { } PrefetchNode() = default; - PrefetchNode(Buffer buffer, Array bounds) : buffer(buffer), bounds(bounds) {} + PrefetchNode(Buffer buffer, Array bounds, Span span = Span()) + : StmtNode(span), buffer(buffer), bounds(bounds) {} static constexpr const char* _type_key = "tir.Prefetch"; TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode); @@ -841,7 +859,7 @@ class PrefetchNode : public StmtNode { */ class Prefetch : public Stmt { public: - TVM_DLL explicit Prefetch(Buffer buffer, Array bounds); + TVM_DLL explicit Prefetch(Buffer buffer, Array bounds, Span span = Span()); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode); }; @@ -973,9 +991,10 @@ inline bool IsPragmaKey(const std::string& attr_key) { /*! * \brief Create a type annotation expression * \param dtype The data type + * \param span The location of this object in the source code. * \return Expr a expression with dtype. */ -TVM_DLL PrimExpr TypeAnnotation(DataType dtype); +TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span()); // overload printing of for type. TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type); diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index f1651c118010..a2240939ddea 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -91,14 +91,17 @@ class Var : public PrimExpr { * \brief Constructor * \param name_hint variable name * \param dtype data type + * \param span The location of this object in the source code. */ - TVM_DLL explicit Var(String name_hint = "v", DataType dtype = DataType::Int(32)); + TVM_DLL explicit Var(String name_hint = "v", DataType dtype = DataType::Int(32), + Span span = Span()); /*! * \brief Constructor which provides a more detailed type annotation. * \param name_hint variable name. * \param type_annotation The type annotation. + * \param span The location of this object in the source code. */ - TVM_DLL explicit Var(String name_hint, Type type_annotation); + TVM_DLL explicit Var(String name_hint, Type type_annotation, Span span = Span()); /*! * \brief Make a new copy of var with same type, append suffix * \param suffix The suffix to be appended. @@ -138,8 +141,10 @@ class SizeVar : public Var { * \brief constructor * \param name_hint variable name * \param t data type + * \param span The location of this object in the source code. */ - TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32)); + TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32), + Span span = Span()); /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. @@ -246,6 +251,11 @@ class IterVarNode : public Object { * set this if this is binded already to a known thread tag. */ String thread_tag; + /*! + * \brief Span that points to the original source code. + * Reserved debug information. + */ + mutable Span span; void VisitAttrs(AttrVisitor* v) { v->Visit("dom", &dom); @@ -278,7 +288,8 @@ class IterVarNode : public Object { */ class IterVar : public ObjectRef { public: - TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag = ""); + TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag = "", + Span span = Span()); /*! * \return the corresponding var in the IterVar. */ diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 60d92e901764..ca46981acdb9 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -325,10 +325,13 @@ class Var(PrimExprWithOp): dtype : Union[str, tvm.irType] The data type + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, name, dtype): - self.__init_handle_by_constructor__(_ffi_api.Var, name, dtype) + def __init__(self, name, dtype, span=None): + self.__init_handle_by_constructor__(_ffi_api.Var, name, dtype, span) @tvm._ffi.register_object("tir.SizeVar") @@ -343,11 +346,14 @@ class SizeVar(Var): dtype : int The data type + + span : Optional[Span] + The location of this itervar in the source code. """ # pylint: disable=super-init-not-called - def __init__(self, name, dtype): - self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype) + def __init__(self, name, dtype, span=None): + self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype, span) @tvm._ffi.register_object("tir.IterVar") @@ -370,6 +376,9 @@ class IterVar(Object, ExprOp): thread_tag : str The thread type tag. + span : Optional[Span] + The location of this itervar in the source code. + See Also -------- te.thread_axis: Create thread axis IterVar. @@ -386,7 +395,7 @@ class IterVar(Object, ExprOp): Parallelized = 7 Tensorized = 8 - def __init__(self, dom, var, iter_type, thread_tag=""): + def __init__(self, dom, var, iter_type, thread_tag="", span=None): if dom is not None: if isinstance(dom, (list, tuple)): if len(dom) != 2: @@ -399,7 +408,7 @@ def __init__(self, dom, var, iter_type, thread_tag=""): name = var if var is not None else "iter" dtype = "int32" if dom is None else dom.extent.dtype var = Var(name, dtype=dtype) if not isinstance(var, Var) else var - self.__init_handle_by_constructor__(_ffi_api.IterVar, dom, var, iter_type, thread_tag) + self.__init_handle_by_constructor__(_ffi_api.IterVar, dom, var, iter_type, thread_tag, span) @tvm._ffi.register_object("tir.CommReducer") @@ -419,11 +428,14 @@ class CommReducer(Object): identity_element : List[PrimExpr] The identity elements. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, lhs, rhs, result, identity_element): + def __init__(self, lhs, rhs, result, identity_element, span=None): self.__init_handle_by_constructor__( - _ffi_api.CommReducer, lhs, rhs, result, identity_element + _ffi_api.CommReducer, lhs, rhs, result, identity_element, span ) @@ -450,11 +462,14 @@ class Reduce(PrimExprWithOp): init : list of Expr The initial value for output. This can be an int, float or ProducerLoad + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, combiner, src, rdom, condition, value_index, init=None): + def __init__(self, combiner, src, rdom, condition, value_index, init=None, span=None): self.__init_handle_by_constructor__( - _ffi_api.Reduce, combiner, src, rdom, condition, value_index, init + _ffi_api.Reduce, combiner, src, rdom, condition, value_index, init, span ) @@ -469,10 +484,13 @@ class FloatImm(ConstExpr): value : float The constant value. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, dtype, value): - self.__init_handle_by_constructor__(tvm.ir._ffi_api.FloatImm, dtype, value) + def __init__(self, dtype, value, span=None): + self.__init_handle_by_constructor__(tvm.ir._ffi_api.FloatImm, dtype, value, span) @tvm._ffi.register_object @@ -486,10 +504,13 @@ class IntImm(ConstExpr): value : int The constant value. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, dtype, value): - self.__init_handle_by_constructor__(tvm.ir._ffi_api.IntImm, dtype, value) + def __init__(self, dtype, value, span=None): + self.__init_handle_by_constructor__(tvm.ir._ffi_api.IntImm, dtype, value, span) def __hash__(self): return self.value @@ -518,10 +539,13 @@ class StringImm(ConstExpr): ---------- value : str The value of the function. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, value): - self.__init_handle_by_constructor__(_ffi_api.StringImm, value) + def __init__(self, value, span=None): + self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span) def __eq__(self, other): if isinstance(other, ConstExpr): @@ -545,10 +569,13 @@ class Cast(PrimExprWithOp): value : PrimExpr The value of the function. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, dtype, value): - self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value) + def __init__(self, dtype, value, span=None): + self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value, span) @tvm._ffi.register_object("tir.Add") @@ -562,10 +589,13 @@ class Add(BinaryOpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.Add, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.Add, a, b, span) @tvm._ffi.register_object("tir.Sub") @@ -579,10 +609,13 @@ class Sub(BinaryOpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.Sub, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.Sub, a, b, span) @tvm._ffi.register_object("tir.Mul") @@ -596,10 +629,13 @@ class Mul(BinaryOpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.Mul, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.Mul, a, b, span) @tvm._ffi.register_object("tir.Div") @@ -613,10 +649,13 @@ class Div(BinaryOpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.Div, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.Div, a, b, span) @tvm._ffi.register_object("tir.Mod") @@ -630,10 +669,13 @@ class Mod(BinaryOpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.Mod, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.Mod, a, b, span) @tvm._ffi.register_object("tir.FloorDiv") @@ -647,10 +689,13 @@ class FloorDiv(BinaryOpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) @tvm._ffi.register_object("tir.FloorMod") @@ -664,10 +709,13 @@ class FloorMod(BinaryOpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b, span) @tvm._ffi.register_object("tir.Min") @@ -681,10 +729,13 @@ class Min(BinaryOpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.Min, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.Min, a, b, span) @tvm._ffi.register_object("tir.Max") @@ -698,10 +749,13 @@ class Max(BinaryOpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.Max, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.Max, a, b, span) @tvm._ffi.register_object("tir.EQ") @@ -715,10 +769,13 @@ class EQ(CmpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.EQ, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.EQ, a, b, span) @tvm._ffi.register_object("tir.NE") @@ -732,10 +789,13 @@ class NE(CmpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.NE, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.NE, a, b, span) @tvm._ffi.register_object("tir.LT") @@ -749,10 +809,13 @@ class LT(CmpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.LT, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.LT, a, b, span) @tvm._ffi.register_object("tir.LE") @@ -766,10 +829,13 @@ class LE(CmpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.LE, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span) @tvm._ffi.register_object("tir.GT") @@ -783,10 +849,13 @@ class GT(CmpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.GT, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.GT, a, b, span) @tvm._ffi.register_object("tir.GE") @@ -800,10 +869,13 @@ class GE(CmpExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.GE, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.GE, a, b, span) @tvm._ffi.register_object("tir.And") @@ -817,10 +889,13 @@ class And(LogicalExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.And, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.And, a, b, span) @tvm._ffi.register_object("tir.Or") @@ -834,10 +909,13 @@ class Or(LogicalExpr): b : PrimExpr The right hand operand. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a, b): - self.__init_handle_by_constructor__(_ffi_api.Or, a, b) + def __init__(self, a, b, span=None): + self.__init_handle_by_constructor__(_ffi_api.Or, a, b, span) @tvm._ffi.register_object("tir.Not") @@ -848,10 +926,13 @@ class Not(LogicalExpr): ---------- a : PrimExpr The input value + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, a): - self.__init_handle_by_constructor__(_ffi_api.Not, a) + def __init__(self, a, span=None): + self.__init_handle_by_constructor__(_ffi_api.Not, a, span) @tvm._ffi.register_object("tir.Select") @@ -876,10 +957,14 @@ class Select(PrimExprWithOp): false_value : PrimExpr The value to take when condition is false. + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, condition, true_value, false_value): - self.__init_handle_by_constructor__(_ffi_api.Select, condition, true_value, false_value) + def __init__(self, condition, true_value, false_value, span=None): + self.__init_handle_by_constructor__( + _ffi_api.Select, condition, true_value, false_value, span + ) @tvm._ffi.register_object("tir.Load") @@ -899,11 +984,17 @@ class Load(PrimExprWithOp): predicate : PrimExpr The load predicate. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, dtype, buffer_var, index, predicate=None): - args = [] if predicate is None else [predicate] - self.__init_handle_by_constructor__(_ffi_api.Load, dtype, buffer_var, index, *args) + def __init__(self, dtype, buffer_var, index, predicate=None, span=None): + if predicate is None: + predicate = _ffi_api.const_true(dtype) + self.__init_handle_by_constructor__( + _ffi_api.Load, dtype, buffer_var, index, predicate, span + ) @tvm._ffi.register_object("tir.BufferLoad") @@ -917,10 +1008,13 @@ class BufferLoad(PrimExprWithOp): indices : List[PrimExpr] The buffer indices. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, buffer, indices): - self.__init_handle_by_constructor__(_ffi_api.BufferLoad, buffer, indices) + def __init__(self, buffer, indices, span=None): + self.__init_handle_by_constructor__(_ffi_api.BufferLoad, buffer, indices, span) @tvm._ffi.register_object("tir.ProducerLoad") @@ -934,10 +1028,13 @@ class ProducerLoad(PrimExprWithOp): indices : List[PrimExpr] The buffer indices. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, producer, indices): - self.__init_handle_by_constructor__(_ffi_api.ProducerLoad, producer, indices) + def __init__(self, producer, indices, span=None): + self.__init_handle_by_constructor__(_ffi_api.ProducerLoad, producer, indices, span) @tvm._ffi.register_object("tir.Ramp") @@ -954,10 +1051,13 @@ class Ramp(PrimExprWithOp): lanes : int The lanes of the expression. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, base, stride, lanes): - self.__init_handle_by_constructor__(_ffi_api.Ramp, base, stride, lanes) + def __init__(self, base, stride, lanes, span=None): + self.__init_handle_by_constructor__(_ffi_api.Ramp, base, stride, lanes, span) @tvm._ffi.register_object("tir.Broadcast") @@ -971,10 +1071,13 @@ class Broadcast(PrimExprWithOp): lanes : int The lanes of the expression. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, value, lanes): - self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes) + def __init__(self, value, lanes, span=None): + self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes, span) @tvm._ffi.register_object("tir.Shuffle") @@ -988,10 +1091,13 @@ class Shuffle(PrimExprWithOp): indices : Array of indices The indices + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, vectors, indices): - self.__init_handle_by_constructor__(_ffi_api.Shuffle, vectors, indices) + def __init__(self, vectors, indices, span=None): + self.__init_handle_by_constructor__(_ffi_api.Shuffle, vectors, indices, span) class CallEffectKind: @@ -1020,9 +1126,12 @@ class Call(PrimExprWithOp): args : list of Expr The input arguments to the call + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, dtype, op, args): + def __init__(self, dtype, op, args, span=None): if isinstance(op, str): if not op.startswith("tir."): raise ValueError( @@ -1034,7 +1143,7 @@ def __init__(self, dtype, op, args): % op ) op = Op.get(op) - self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args) + self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args, span) @tvm._ffi.register_object("tir.Let") @@ -1051,15 +1160,22 @@ class Let(PrimExprWithOp): body : PrimExpr The body expression. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, var, value, body): - self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body) + def __init__(self, var, value, body, span=None): + self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body, span) @tvm._ffi.register_object("tir.Any") class Any(PrimExpr): - """Any node.""" + """Any node. + + span : Optional[Span] + The location of this itervar in the source code. + """ - def __init__(self): - self.__init_handle_by_constructor__(_ffi_api.Any) + def __init__(self, span=None): + self.__init_handle_by_constructor__(_ffi_api.Any, span) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index b02ebba18765..79d18d8970b5 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -45,9 +45,12 @@ class PrimFunc(BaseFunc): attrs: Optional[tvm.Attrs] Attributes of the function, can be None + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None): + def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, span=None): param_list = [] buffer_map = {} if buffer_map is None else buffer_map for x in params: @@ -62,10 +65,10 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None): raise TypeError("params can only contain Var or Buffer") self.__init_handle_by_constructor__( - _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs + _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span ) - def with_body(self, new_body): + def with_body(self, new_body, span=None): """Create a new PrimFunc with the same set signatures but a new body. Parameters @@ -73,9 +76,12 @@ def with_body(self, new_body): new_body : Stmt The new body. + span : Optional[Span] + The location of this itervar in the source code. + Returns ------- new_func : PrimFunc The created new function. """ - return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs) + return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 573bc0e7d970..cba4ce337b1d 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -50,10 +50,13 @@ class LetStmt(Stmt): body : Stmt The body statement. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, var, value, body): - self.__init_handle_by_constructor__(_ffi_api.LetStmt, var, value, body) + def __init__(self, var, value, body, span=None): + self.__init_handle_by_constructor__(_ffi_api.LetStmt, var, value, body, span) @tvm._ffi.register_object("tir.AssertStmt") @@ -70,10 +73,13 @@ class AssertStmt(Stmt): body : Stmt The body statement. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, condition, message, body): - self.__init_handle_by_constructor__(_ffi_api.AssertStmt, condition, message, body) + def __init__(self, condition, message, body, span=None): + self.__init_handle_by_constructor__(_ffi_api.AssertStmt, condition, message, body, span) @tvm._ffi.register_object("tir.For") @@ -99,6 +105,9 @@ class For(Stmt): body : Stmt The body statement. + + span : Optional[Span] + The location of this itervar in the source code. """ Serial = 0 @@ -106,9 +115,9 @@ class For(Stmt): Vectorized = 2 Unrolled = 3 - def __init__(self, loop_var, min_val, extent, for_type, device_api, body): + def __init__(self, loop_var, min_val, extent, for_type, device_api, body, span=None): self.__init_handle_by_constructor__( - _ffi_api.For, loop_var, min_val, extent, for_type, device_api, body + _ffi_api.For, loop_var, min_val, extent, for_type, device_api, body, span ) @@ -129,11 +138,17 @@ class Store(Stmt): predicate : PrimExpr The store predicate. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, buffer_var, value, index, predicate=None): - args = [] if predicate is None else [predicate] - self.__init_handle_by_constructor__(_ffi_api.Store, buffer_var, value, index, *args) + def __init__(self, buffer_var, value, index, predicate=None, span=None): + if predicate is None: + predicate = _ffi_api.const_true(value.dtype) + self.__init_handle_by_constructor__( + _ffi_api.Store, buffer_var, value, index, predicate, span + ) @tvm._ffi.register_object("tir.BufferStore") @@ -150,10 +165,13 @@ class BufferStore(Stmt): indices : List[PrimExpr] The indices location to be stored. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, buffer, value, indices): - self.__init_handle_by_constructor__(_ffi_api.BufferStore, buffer, value, indices) + def __init__(self, buffer, value, indices, span=None): + self.__init_handle_by_constructor__(_ffi_api.BufferStore, buffer, value, indices, span) @tvm._ffi.register_object("tir.BufferRealize") @@ -173,10 +191,15 @@ class BufferRealize(Stmt): body : Stmt The body of the statement. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, buffer, bounds, condition, body): - self.__init_handle_by_constructor__(_ffi_api.BufferRealize, buffer, bounds, condition, body) + def __init__(self, buffer, bounds, condition, body, span=None): + self.__init_handle_by_constructor__( + _ffi_api.BufferRealize, buffer, bounds, condition, body, span + ) @tvm._ffi.register_object("tir.ProducerStore") @@ -193,10 +216,13 @@ class ProducerStore(Stmt): indices : list of Expr The index arguments of the store. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, producer, value, indices): - self.__init_handle_by_constructor__(_ffi_api.ProducerStore, producer, value, indices) + def __init__(self, producer, value, indices, span=None): + self.__init_handle_by_constructor__(_ffi_api.ProducerStore, producer, value, indices, span) @tvm._ffi.register_object("tir.Allocate") @@ -219,11 +245,14 @@ class Allocate(Stmt): body : Stmt The body statement. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, buffer_var, dtype, extents, condition, body): + def __init__(self, buffer_var, dtype, extents, condition, body, span=None): self.__init_handle_by_constructor__( - _ffi_api.Allocate, buffer_var, dtype, extents, condition, body + _ffi_api.Allocate, buffer_var, dtype, extents, condition, body, span ) @@ -244,10 +273,13 @@ class AttrStmt(Stmt): body : Stmt The body statement. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, node, attr_key, value, body): - self.__init_handle_by_constructor__(_ffi_api.AttrStmt, node, attr_key, value, body) + def __init__(self, node, attr_key, value, body, span=None): + self.__init_handle_by_constructor__(_ffi_api.AttrStmt, node, attr_key, value, body, span) @tvm._ffi.register_object("tir.ProducerRealize") @@ -267,11 +299,14 @@ class ProducerRealize(Stmt): body : Stmt The realize body + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, producer, bounds, condition, body): + def __init__(self, producer, bounds, condition, body, span=None): self.__init_handle_by_constructor__( - _ffi_api.ProducerRealize, producer, bounds, condition, body + _ffi_api.ProducerRealize, producer, bounds, condition, body, span ) @@ -283,10 +318,13 @@ class SeqStmt(Stmt): ---------- seq : List[Stmt] The statements + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, seq): - self.__init_handle_by_constructor__(_ffi_api.SeqStmt, seq) + def __init__(self, seq, span=None): + self.__init_handle_by_constructor__(_ffi_api.SeqStmt, seq, span) def __getitem__(self, i): return self.seq[i] @@ -309,10 +347,15 @@ class IfThenElse(Stmt): else_case : Stmt The statement to execute if condition is false. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, condition, then_case, else_case): - self.__init_handle_by_constructor__(_ffi_api.IfThenElse, condition, then_case, else_case) + def __init__(self, condition, then_case, else_case, span=None): + self.__init_handle_by_constructor__( + _ffi_api.IfThenElse, condition, then_case, else_case, span + ) @tvm._ffi.register_object("tir.Evaluate") @@ -323,10 +366,13 @@ class Evaluate(Stmt): ---------- value : PrimExpr The expression to be evalued. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, value): - self.__init_handle_by_constructor__(_ffi_api.Evaluate, value) + def __init__(self, value, span=None): + self.__init_handle_by_constructor__(_ffi_api.Evaluate, value, span) @tvm._ffi.register_object("tir.Prefetch") @@ -340,10 +386,13 @@ class Prefetch(Stmt): bounds : list of Range The bounds to be prefetched. + + span : Optional[Span] + The location of this itervar in the source code. """ - def __init__(self, buffer, bounds): - self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds) + def __init__(self, buffer, bounds, span=None): + self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds, span) def stmt_seq(*args): diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 67e5cea93011..0b7049ec212b 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -55,20 +55,23 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) { return Downcast(ref); } -IntImm::IntImm(DataType dtype, int64_t value) { - ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar."; - ICHECK(dtype.is_int() || dtype.is_uint()) << "ValueError: IntImm supports only int or uint type."; +IntImm::IntImm(DataType dtype, int64_t value, Span span) { + ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " << dtype + << " was supplied."; + ICHECK(dtype.is_int() || dtype.is_uint()) + << "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied."; if (dtype.is_uint()) { ICHECK_GE(value, 0U); } ObjectPtr node = make_object(); node->dtype = dtype; node->value = value; + node->span = span; data_ = std::move(node); } -TVM_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value) { - return IntImm(dtype, value); +TVM_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value, Span span) { + return IntImm(dtype, value, span); }); TVM_REGISTER_NODE_TYPE(IntImmNode); @@ -83,16 +86,17 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } }); -FloatImm::FloatImm(DataType dtype, double value) { +FloatImm::FloatImm(DataType dtype, double value, Span span) { ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar."; ObjectPtr node = make_object(); node->dtype = dtype; node->value = value; + node->span = span; data_ = std::move(node); } -TVM_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double value) { - return FloatImm(dtype, value); +TVM_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double value, Span span) { + return FloatImm(dtype, value, span); }); TVM_REGISTER_NODE_TYPE(FloatImmNode); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index fd55f2418628..faa483d019c0 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1154,7 +1154,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) { } void CodeGenLLVM::VisitStmt_(const StoreNode* op) { - ICHECK(is_one(op->predicate)); + ICHECK(is_one(op->predicate)) << op->predicate; DataType t = op->value.dtype(); bool is_volatile = volatile_buf_.count(op->buffer_var.get()); llvm::Value* buffer = MakeValue(op->buffer_var); diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 08b2224e9912..7db49093e596 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -45,9 +45,9 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { return array; } -Buffer decl_buffer(Array shape, DataType dtype, String name) { - return Buffer(Var(name, PointerType(PrimType(dtype))), dtype, shape, Array(), - PrimExpr(), name, "", 0, 0, kDefault); +Buffer decl_buffer(Array shape, DataType dtype, String name, Span span) { + return Buffer(Var(name, PointerType(PrimType(dtype)), span), dtype, shape, Array(), + PrimExpr(), name, "", 0, 0, kDefault, span); } // Split the given expression w.r.t the add operator @@ -382,7 +382,7 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, PrimExpr elem_offset, String name, String scope, int data_alignment, - int offset_factor, BufferType buffer_type) { + int offset_factor, BufferType buffer_type, Span span) { ICHECK(IsPointerType(data->type_annotation, dtype)) << "Buffer data field expect to have the right pointer type annotation" << " annotation=" << data->type_annotation << ", dtype=" << dtype; @@ -416,6 +416,7 @@ Buffer::Buffer(Var data, DataType dtype, Array shape, Array n->strides.push_back(Var("stride", n->shape[i].dtype())); } } + n->span = std::move(span); data_ = std::move(n); } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 825bac86919c..2d2a29943383 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -34,7 +34,7 @@ namespace tvm { namespace tir { #define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ - Name::Name(PrimExpr a, PrimExpr b) { \ + Name::Name(PrimExpr a, PrimExpr b, Span span) { \ using T = Name::ContainerType; \ ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ @@ -43,11 +43,12 @@ namespace tir { node->dtype = a.dtype(); \ node->a = std::move(a); \ node->b = std::move(b); \ + node->span = std::move(span); \ data_ = std::move(node); \ } #define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \ - Name::Name(PrimExpr a, PrimExpr b) { \ + Name::Name(PrimExpr a, PrimExpr b, Span span) { \ using T = Name::ContainerType; \ ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ @@ -56,22 +57,25 @@ namespace tir { node->dtype = DataType::Bool(a.dtype().lanes()); \ node->a = std::move(a); \ node->b = std::move(b); \ + node->span = std::move(span); \ data_ = std::move(node); \ } // Var -Var::Var(String name_hint, DataType dtype) { +Var::Var(String name_hint, DataType dtype, Span span) { auto n = make_object(); n->name_hint = std::move(name_hint); n->dtype = std::move(dtype); + n->span = std::move(span); data_ = std::move(n); } -Var::Var(String name_hint, Type type_annotation) { +Var::Var(String name_hint, Type type_annotation, Span span) { auto n = make_object(); n->name_hint = std::move(name_hint); n->dtype = GetRuntimeDataType(type_annotation); n->type_annotation = std::move(type_annotation); + n->span = std::move(span); data_ = std::move(n); } @@ -87,11 +91,12 @@ Var Var::copy_with_suffix(const String& suffix) const { return Var(new_ptr); } -TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMArgValue type) { +TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMArgValue type, + Span span) { if (type.IsObjectRef()) { - return Var(name_hint, type.operator Type()); + return Var(name_hint, type.operator Type(), span); } else { - return Var(name_hint, type.operator DataType()); + return Var(name_hint, type.operator DataType(), span); } }); @@ -106,15 +111,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // SizeVar -SizeVar::SizeVar(String name_hint, DataType dtype) { +SizeVar::SizeVar(String name_hint, DataType dtype, Span span) { auto n = make_object(); n->name_hint = std::move(name_hint); n->dtype = std::move(dtype); + n->span = std::move(span); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t) { - return SizeVar(s, t); +TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t, Span span) { + return SizeVar(s, t, span); }); TVM_REGISTER_NODE_TYPE(SizeVarNode); @@ -126,18 +132,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // IterVar -IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag) { +IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span) { ObjectPtr n = make_object(); n->dom = dom; n->var = var; n->iter_type = t; n->thread_tag = thread_tag; + n->span = std::move(span); data_ = std::move(n); } TVM_REGISTER_GLOBAL("tir.IterVar") - .set_body_typed([](Range dom, Var var, int iter_type, String thread_tag) { - return IterVar(dom, var, static_cast(iter_type), thread_tag); + .set_body_typed([](Range dom, Var var, int iter_type, String thread_tag, Span span) { + return IterVar(dom, var, static_cast(iter_type), thread_tag, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -159,14 +166,17 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(IterVarNode); // StringImm -StringImm::StringImm(String value) { +StringImm::StringImm(String value, Span span) { ObjectPtr node = make_object(); node->dtype = DataType::Handle(); node->value = std::move(value); + node->span = std::move(span); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](String value) { return StringImm(value); }); +TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](String value, Span span) { + return StringImm(value, span); +}); TVM_REGISTER_NODE_TYPE(StringImmNode); @@ -177,17 +187,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Cast -Cast::Cast(DataType t, PrimExpr value) { +Cast::Cast(DataType t, PrimExpr value, Span span) { ICHECK(value.defined()); ICHECK_EQ(t.lanes(), value.dtype().lanes()); ObjectPtr node = make_object(); node->dtype = t; node->value = std::move(value); + node->span = std::move(span); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Cast").set_body_typed([](DataType dtype, PrimExpr value) { - return Cast(dtype, value); +TVM_REGISTER_GLOBAL("tir.Cast").set_body_typed([](DataType dtype, PrimExpr value, Span span) { + return Cast(dtype, value, span); }); TVM_REGISTER_NODE_TYPE(CastNode); @@ -203,7 +214,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Add TVM_DEFINE_BINOP_CONSTRUCTOR(Add); -TVM_REGISTER_GLOBAL("tir.Add").set_body_typed([](PrimExpr a, PrimExpr b) { return Add(a, b); }); +TVM_REGISTER_GLOBAL("tir.Add").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return Add(a, b, span); +}); TVM_REGISTER_NODE_TYPE(AddNode); @@ -220,7 +233,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Sub TVM_DEFINE_BINOP_CONSTRUCTOR(Sub); -TVM_REGISTER_GLOBAL("tir.Sub").set_body_typed([](PrimExpr a, PrimExpr b) { return Sub(a, b); }); +TVM_REGISTER_GLOBAL("tir.Sub").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return Sub(a, b, span); +}); TVM_REGISTER_NODE_TYPE(SubNode); @@ -237,7 +252,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Mul TVM_DEFINE_BINOP_CONSTRUCTOR(Mul); -TVM_REGISTER_GLOBAL("tir.Mul").set_body_typed([](PrimExpr a, PrimExpr b) { return Mul(a, b); }); +TVM_REGISTER_GLOBAL("tir.Mul").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return Mul(a, b, span); +}); TVM_REGISTER_NODE_TYPE(MulNode); @@ -254,7 +271,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Div TVM_DEFINE_BINOP_CONSTRUCTOR(Div); -TVM_REGISTER_GLOBAL("tir.Div").set_body_typed([](PrimExpr a, PrimExpr b) { return Div(a, b); }); +TVM_REGISTER_GLOBAL("tir.Div").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return Div(a, b, span); +}); TVM_REGISTER_NODE_TYPE(DivNode); @@ -271,7 +290,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Mod TVM_DEFINE_BINOP_CONSTRUCTOR(Mod); -TVM_REGISTER_GLOBAL("tir.Mod").set_body_typed([](PrimExpr a, PrimExpr b) { return Mod(a, b); }); +TVM_REGISTER_GLOBAL("tir.Mod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return Mod(a, b, span); +}); TVM_REGISTER_NODE_TYPE(ModNode); @@ -288,8 +309,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // FloorDiv TVM_DEFINE_BINOP_CONSTRUCTOR(FloorDiv); -TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b) { - return FloorDiv(a, b); +TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return FloorDiv(a, b, span); }); TVM_REGISTER_NODE_TYPE(FloorDivNode); @@ -303,8 +324,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // FloorMod TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod); -TVM_REGISTER_GLOBAL("tir.FloorMod").set_body_typed([](PrimExpr a, PrimExpr b) { - return FloorMod(a, b); +TVM_REGISTER_GLOBAL("tir.FloorMod").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return FloorMod(a, b, span); }); TVM_REGISTER_NODE_TYPE(FloorModNode); @@ -318,7 +339,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Min TVM_DEFINE_BINOP_CONSTRUCTOR(Min); -TVM_REGISTER_GLOBAL("tir.Min").set_body_typed([](PrimExpr a, PrimExpr b) { return Min(a, b); }); +TVM_REGISTER_GLOBAL("tir.Min").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return Min(a, b, span); +}); TVM_REGISTER_NODE_TYPE(MinNode); @@ -335,7 +358,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Max TVM_DEFINE_BINOP_CONSTRUCTOR(Max); -TVM_REGISTER_GLOBAL("tir.Max").set_body_typed([](PrimExpr a, PrimExpr b) { return Max(a, b); }); +TVM_REGISTER_GLOBAL("tir.Max").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return Max(a, b, span); +}); TVM_REGISTER_NODE_TYPE(MaxNode); @@ -352,7 +377,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // EQ TVM_DEFINE_CMPOP_CONSTRUCTOR(EQ); -TVM_REGISTER_GLOBAL("tir.EQ").set_body_typed([](PrimExpr a, PrimExpr b) { return EQ(a, b); }); +TVM_REGISTER_GLOBAL("tir.EQ").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return EQ(a, b, span); +}); TVM_REGISTER_NODE_TYPE(EQNode); @@ -369,7 +396,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // NE TVM_DEFINE_CMPOP_CONSTRUCTOR(NE); -TVM_REGISTER_GLOBAL("tir.NE").set_body_typed([](PrimExpr a, PrimExpr b) { return NE(a, b); }); +TVM_REGISTER_GLOBAL("tir.NE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return NE(a, b, span); +}); TVM_REGISTER_NODE_TYPE(NENode); @@ -386,7 +415,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // LT TVM_DEFINE_CMPOP_CONSTRUCTOR(LT); -TVM_REGISTER_GLOBAL("tir.LT").set_body_typed([](PrimExpr a, PrimExpr b) { return LT(a, b); }); +TVM_REGISTER_GLOBAL("tir.LT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return LT(a, b, span); +}); TVM_REGISTER_NODE_TYPE(LTNode); @@ -403,7 +434,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // LE TVM_DEFINE_CMPOP_CONSTRUCTOR(LE); -TVM_REGISTER_GLOBAL("tir.LE").set_body_typed([](PrimExpr a, PrimExpr b) { return LE(a, b); }); +TVM_REGISTER_GLOBAL("tir.LE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return LE(a, b, span); +}); TVM_REGISTER_NODE_TYPE(LENode); @@ -420,7 +453,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // GT TVM_DEFINE_CMPOP_CONSTRUCTOR(GT); -TVM_REGISTER_GLOBAL("tir.GT").set_body_typed([](PrimExpr a, PrimExpr b) { return GT(a, b); }); +TVM_REGISTER_GLOBAL("tir.GT").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return GT(a, b, span); +}); TVM_REGISTER_NODE_TYPE(GTNode); @@ -437,7 +472,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // GE TVM_DEFINE_CMPOP_CONSTRUCTOR(GE); -TVM_REGISTER_GLOBAL("tir.GE").set_body_typed([](PrimExpr a, PrimExpr b) { return GE(a, b); }); +TVM_REGISTER_GLOBAL("tir.GE").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return GE(a, b, span); +}); TVM_REGISTER_NODE_TYPE(GENode); @@ -452,7 +489,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // And -And::And(PrimExpr a, PrimExpr b) { +And::And(PrimExpr a, PrimExpr b, Span span) { ICHECK(a.defined()) << "ValueError: a is undefined"; ICHECK(b.defined()) << "ValueError: b is undefined"; ICHECK(a.dtype().is_bool()); @@ -463,10 +500,13 @@ And::And(PrimExpr a, PrimExpr b) { node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); + node->span = std::move(span); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.And").set_body_typed([](PrimExpr a, PrimExpr b) { return And(a, b); }); +TVM_REGISTER_GLOBAL("tir.And").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return And(a, b, span); +}); TVM_REGISTER_NODE_TYPE(AndNode); @@ -481,7 +521,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Or -Or::Or(PrimExpr a, PrimExpr b) { +Or::Or(PrimExpr a, PrimExpr b, Span span) { ICHECK(a.defined()) << "ValueError: a is undefined"; ICHECK(b.defined()) << "ValueError: b is undefined"; ICHECK(a.dtype().is_bool()); @@ -492,10 +532,13 @@ Or::Or(PrimExpr a, PrimExpr b) { node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); + node->span = std::move(span); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Or").set_body_typed([](PrimExpr a, PrimExpr b) { return Or(a, b); }); +TVM_REGISTER_GLOBAL("tir.Or").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return Or(a, b, span); +}); TVM_REGISTER_NODE_TYPE(OrNode); @@ -510,17 +553,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Not -Not::Not(PrimExpr a) { +Not::Not(PrimExpr a, Span span) { ICHECK(a.defined()) << "ValueError: a is undefined"; ICHECK(a.dtype().is_bool()); ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); + node->span = std::move(span); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Not").set_body_typed([](PrimExpr a) { return Not(a); }); +TVM_REGISTER_GLOBAL("tir.Not").set_body_typed([](PrimExpr a, Span span) { return Not(a, span); }); TVM_REGISTER_NODE_TYPE(NotNode); @@ -532,7 +576,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Select -Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { +Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { ICHECK(condition.defined()) << "ValueError: condition is undefined"; ICHECK(true_value.defined()) << "ValueError: true_value is undefined"; ICHECK(false_value.defined()) << "ValueError: true_value is undefined"; @@ -545,12 +589,13 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { node->condition = std::move(condition); node->true_value = std::move(true_value); node->false_value = std::move(false_value); + node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.Select") - .set_body_typed([](PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { - return Select(condition, true_value, false_value); + .set_body_typed([](PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { + return Select(condition, true_value, false_value, span); }); TVM_REGISTER_NODE_TYPE(SelectNode); @@ -568,7 +613,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Load -Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate) { +Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, Span span) { ICHECK(buffer_var.defined()); ICHECK(predicate.defined()); ICHECK(index.defined()); @@ -580,6 +625,7 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate) { node->buffer_var = std::move(buffer_var); node->index = std::move(index); node->predicate = std::move(predicate); + node->span = std::move(span); data_ = std::move(node); } @@ -587,9 +633,11 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate) { TVM_REGISTER_GLOBAL("tir.Load").set_body([](TVMArgs args, TVMRetValue* ret) { DataType t = args[0]; if (args.size() == 3) { - *ret = Load(t, args[1], args[2], const_true(t.lanes())); + *ret = Load(t, args[1], args[2], const_true(t.lanes()), Span()); + } else if (args.size() == 4) { + *ret = Load(t, args[1], args[2], args[3], Span()); } else { - *ret = Load(t, args[1], args[2], args[3]); + *ret = Load(t, args[1], args[2], args[3], args[4]); } }); @@ -608,7 +656,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Ramp -Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes) { +Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) { ICHECK(base.defined()); ICHECK(stride.defined()); ICHECK(base.dtype().is_scalar()); @@ -621,12 +669,14 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes) { node->base = base; node->stride = stride; node->lanes = lanes; + node->span = std::move(span); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Ramp").set_body_typed([](PrimExpr base, PrimExpr stride, int lanes) { - return Ramp(base, stride, lanes); -}); +TVM_REGISTER_GLOBAL("tir.Ramp") + .set_body_typed([](PrimExpr base, PrimExpr stride, int lanes, Span span) { + return Ramp(base, stride, lanes, span); + }); TVM_REGISTER_NODE_TYPE(RampNode); @@ -641,7 +691,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Broadcast -Broadcast::Broadcast(PrimExpr value, int lanes) { +Broadcast::Broadcast(PrimExpr value, int lanes, Span span) { ICHECK(value.defined()); ICHECK(value.dtype().is_scalar()); ICHECK_GT(lanes, 1); @@ -650,11 +700,12 @@ Broadcast::Broadcast(PrimExpr value, int lanes) { node->dtype = value.dtype().with_lanes(lanes); node->value = std::move(value); node->lanes = lanes; + node->span = std::move(span); data_ = node; } -TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, int lanes) { - return Broadcast(value, lanes); +TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, int lanes, Span span) { + return Broadcast(value, lanes, span); }); TVM_REGISTER_NODE_TYPE(BroadcastNode); @@ -668,7 +719,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Let -Let::Let(Var var, PrimExpr value, PrimExpr body) { +Let::Let(Var var, PrimExpr value, PrimExpr body, Span span) { ICHECK(value.defined()); ICHECK(body.defined()); ICHECK_EQ(value.dtype(), var.dtype()); @@ -678,11 +729,13 @@ Let::Let(Var var, PrimExpr value, PrimExpr body) { node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); + node->span = std::move(span); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimExpr body) { - return Let(var, value, body); +TVM_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimExpr body, + Span span) { + return Let(var, value, body, span); }); TVM_REGISTER_NODE_TYPE(LetNode); @@ -698,7 +751,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Call -Call::Call(DataType dtype, RelayExpr op, Array args) { +Call::Call(DataType dtype, RelayExpr op, Array args, Span span) { for (size_t i = 0; i < args.size(); ++i) { ICHECK(args[i].defined()); } @@ -707,11 +760,12 @@ Call::Call(DataType dtype, RelayExpr op, Array args) { node->dtype = dtype; node->op = std::move(op); node->args = std::move(args); + node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, RelayExpr op, Array args) { + .set_body_typed([](DataType type, RelayExpr op, Array args, Span span) { Array prim_expr_args; for (const auto& it : args) { ICHECK(it->IsInstance() || it->IsInstance()); @@ -721,7 +775,7 @@ TVM_REGISTER_GLOBAL("tir.Call") prim_expr_args.push_back(Downcast(it)); } } - return Call(type, op, prim_expr_args); + return Call(type, op, prim_expr_args, span); }); TVM_REGISTER_NODE_TYPE(CallNode); @@ -746,7 +800,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Shuffle -Shuffle::Shuffle(Array vectors, Array indices) { +Shuffle::Shuffle(Array vectors, Array indices, Span span) { ICHECK_NE(vectors.size(), 0U); ICHECK_NE(indices.size(), 0U); @@ -763,10 +817,11 @@ Shuffle::Shuffle(Array vectors, Array indices) { node->dtype = base_type.with_lanes(static_cast(indices.size())); node->vectors = std::move(vectors); node->indices = std::move(indices); + node->span = std::move(span); data_ = node; } -PrimExpr Shuffle::Concat(Array vectors) { +PrimExpr Shuffle::Concat(Array vectors, Span span) { ICHECK_NE(vectors.size(), 0); if (vectors.size() == 1) { return vectors[0]; @@ -778,16 +833,16 @@ PrimExpr Shuffle::Concat(Array vectors) { indices.push_back(IntImm(DataType::Int(32), index++)); } } - return Shuffle(vectors, indices); + return Shuffle(vectors, indices, span); } -PrimExpr Shuffle::ExtractElement(PrimExpr vector, int index) { - return Shuffle({vector}, {Integer(index)}); +PrimExpr Shuffle::ExtractElement(PrimExpr vector, int index, Span span) { + return Shuffle({vector}, {Integer(index)}, span); } TVM_REGISTER_GLOBAL("tir.Shuffle") - .set_body_typed([](Array vectors, Array indices) { - return Shuffle(vectors, indices); + .set_body_typed([](Array vectors, Array indices, Span span) { + return Shuffle(vectors, indices, span); }); TVM_REGISTER_NODE_TYPE(ShuffleNode); @@ -814,12 +869,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // CommReducer CommReducer::CommReducer(Array lhs, Array rhs, Array result, - Array identity_element) { + Array identity_element, Span span) { auto node = make_object(); node->lhs = lhs; node->rhs = rhs; node->result = result; node->identity_element = identity_element; + node->span = std::move(span); data_ = std::move(node); } @@ -839,8 +895,8 @@ Array CommReducerNode::operator()(Array a, Array b TVM_REGISTER_GLOBAL("tir.CommReducer") .set_body_typed([](Array lhs, Array rhs, Array result, - Array identity_element) { - return CommReducer(lhs, rhs, result, identity_element); + Array identity_element, Span span) { + return CommReducer(lhs, rhs, result, identity_element, span); }); TVM_REGISTER_GLOBAL("tir.CommReducerCombine") @@ -857,7 +913,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Reduce Reduce::Reduce(CommReducer combiner, Array source, Array axis, - PrimExpr condition, int value_index, Array init) { + PrimExpr condition, int value_index, Array init, Span span) { for (size_t i = 0; i < axis.size(); ++i) { ICHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis"; } @@ -884,13 +940,14 @@ Reduce::Reduce(CommReducer combiner, Array source, Array axis n->axis = std::move(axis); n->condition = condition; n->value_index = value_index; + n->span = std::move(span); data_ = std::move(n); } TVM_REGISTER_GLOBAL("tir.Reduce") .set_body_typed([](CommReducer combiner, Array source, Array axis, - PrimExpr condition, int value_index, Array init) { - return Reduce(combiner, source, axis, condition, value_index, init); + PrimExpr condition, int value_index, Array init, Span span) { + return Reduce(combiner, source, axis, condition, value_index, init, span); }); TVM_REGISTER_NODE_TYPE(ReduceNode); @@ -908,13 +965,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Any -Any::Any() { +Any::Any(Span span) { auto n = make_object(); n->dtype = DataType::Int(32); + n->span = std::move(span); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("tir.Any").set_body_typed([]() { return Any(); }); +TVM_REGISTER_GLOBAL("tir.Any").set_body_typed([](Span span) { return Any(span); }); TVM_REGISTER_NODE_TYPE(AnyNode); @@ -922,17 +980,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; }); // BufferLoad -BufferLoad::BufferLoad(Buffer buffer, Array indices) { +BufferLoad::BufferLoad(Buffer buffer, Array indices, Span span) { ObjectPtr node = make_object(); node->dtype = buffer->dtype; node->buffer = std::move(buffer); node->indices = std::move(indices); + node->span = std::move(span); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.BufferLoad").set_body_typed([](Buffer buffer, Array indices) { - return BufferLoad(buffer, indices); -}); +TVM_REGISTER_GLOBAL("tir.BufferLoad") + .set_body_typed([](Buffer buffer, Array indices, Span span) { + return BufferLoad(buffer, indices, span); + }); TVM_REGISTER_NODE_TYPE(BufferLoadNode); @@ -950,17 +1010,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // ProducerLoad -ProducerLoad::ProducerLoad(DataProducer producer, Array indices) { +ProducerLoad::ProducerLoad(DataProducer producer, Array indices, Span span) { ObjectPtr node = make_object(); node->dtype = producer->GetDataType(); node->producer = std::move(producer); node->indices = std::move(indices); + node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.ProducerLoad") - .set_body_typed([](DataProducer producer, Array indices) { - return ProducerLoad(producer, indices); + .set_body_typed([](DataProducer producer, Array indices, Span span) { + return ProducerLoad(producer, indices, span); }); TVM_REGISTER_NODE_TYPE(ProducerLoadNode); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 1149e039cae4..ef7f4f8e16dd 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -30,7 +30,7 @@ namespace tir { // Get the function type of a PrimFunc PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, - Map buffer_map, DictAttrs attrs) { + Map buffer_map, DictAttrs attrs, Span span) { // Assume void-return type for now // TODO(tvm-team) consider type deduction from body. if (!ret_type.defined()) { @@ -43,6 +43,7 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, n->buffer_map = std::move(buffer_map); n->attrs = std::move(attrs); n->checked_type_ = n->func_type_annotation(); + n->span = std::move(span); data_ = std::move(n); } @@ -73,8 +74,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_GLOBAL("tir.PrimFunc") .set_body_typed([](Array params, Stmt body, Type ret_type, - Map buffer_map, DictAttrs attrs) { - return PrimFunc(params, body, ret_type, buffer_map, attrs); + Map buffer_map, DictAttrs attrs, Span span) { + return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }); } // namespace tir diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index dbbc99c3abed..86960d9bd999 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -29,7 +29,7 @@ namespace tvm { namespace tir { // LetStmt -LetStmt::LetStmt(Var var, PrimExpr value, Stmt body) { +LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { ICHECK(value.defined()); ICHECK(body.defined()); ICHECK_EQ(value.dtype(), var.dtype()); @@ -38,12 +38,14 @@ LetStmt::LetStmt(Var var, PrimExpr value, Stmt body) { node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); + node->span = std::move(span); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.LetStmt").set_body_typed([](Var var, PrimExpr value, Stmt body) { - return LetStmt(var, value, body); -}); +TVM_REGISTER_GLOBAL("tir.LetStmt") + .set_body_typed([](Var var, PrimExpr value, Stmt body, Span span) { + return LetStmt(var, value, body, span); + }); TVM_REGISTER_NODE_TYPE(LetStmtNode); @@ -58,18 +60,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // AttrStmt -AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body) { +AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); n->value = std::move(value); n->body = std::move(body); + n->span = std::move(span); data_ = std::move(n); } TVM_REGISTER_GLOBAL("tir.AttrStmt") - .set_body_typed([](ObjectRef node, String attr_key, PrimExpr value, Stmt body) { - return AttrStmt(node, attr_key, value, body); + .set_body_typed([](ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { + return AttrStmt(node, attr_key, value, body, span); }); TVM_REGISTER_NODE_TYPE(AttrStmtNode); @@ -87,7 +90,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // AssertStmt -AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body) { +AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span) { ICHECK(condition.defined()); ICHECK(message.dtype() == DataType::Int(32) || message.as()) << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; @@ -96,18 +99,19 @@ AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body) { node->condition = std::move(condition); node->message = std::move(message); node->body = std::move(body); + node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_NODE_TYPE(AssertStmtNode); TVM_REGISTER_GLOBAL("tir.AssertStmt") - .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) { + .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body, Span span) { if (const auto* str = message.as()) { auto msg = StringImm(str->data); - return AssertStmt(condition, msg, body); + return AssertStmt(condition, msg, body, span); } else { - return AssertStmt(condition, Downcast(message), body); + return AssertStmt(condition, Downcast(message), body, span); } }); @@ -125,7 +129,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // For For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api, - Stmt body) { + Stmt body, Span span) { ICHECK(min.defined()); ICHECK(extent.defined()); ICHECK(min.dtype().is_scalar()); @@ -140,13 +144,15 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAP node->for_type = for_type; node->device_api = device_api; node->body = std::move(body); + node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.For").set_body_typed([](Var loop_var, PrimExpr min, PrimExpr extent, - int for_type, int device_api, Stmt body) { + int for_type, int device_api, Stmt body, + Span span) { return For(loop_var, min, extent, static_cast(for_type), - static_cast(device_api), body); + static_cast(device_api), body, span); }); TVM_REGISTER_NODE_TYPE(ForNode); @@ -188,7 +194,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Store -Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) { +Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, Span span) { ICHECK(value.defined()); ICHECK(index.defined()); ICHECK(predicate.defined()); @@ -200,15 +206,18 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) node->value = std::move(value); node->index = std::move(index); node->predicate = std::move(predicate); + node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) { PrimExpr value = args[1]; if (args.size() == 3) { - *ret = Store(args[0], value, args[2], const_true(value.dtype().lanes())); + *ret = Store(args[0], value, args[2], const_true(value.dtype().lanes()), Span()); + } else if (args.size() == 4) { + *ret = Store(args[0], value, args[2], args[3], Span()); } else { - *ret = Store(args[0], value, args[2], args[3]); + *ret = Store(args[0], value, args[2], args[3], args[4]); } }); @@ -230,17 +239,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // ProducerStore -ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array indices) { +ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array indices, + Span span) { ObjectPtr node = make_object(); node->producer = std::move(producer); node->value = std::move(value); node->indices = std::move(indices); + node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.ProducerStore") - .set_body_typed([](DataProducer producer, PrimExpr value, Array indices) { - return ProducerStore(producer, value, indices); + .set_body_typed([](DataProducer producer, PrimExpr value, Array indices, Span span) { + return ProducerStore(producer, value, indices, span); }); TVM_REGISTER_NODE_TYPE(ProducerStoreNode); @@ -262,7 +273,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, - Stmt body) { + Stmt body, Span span) { // TODO(tvm-team): Add invariant check to make sure // IsPointerPType(buffer_var->type_annotation, dtype) // once we fix the allocate tvm script printing. @@ -280,6 +291,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, Prim node->extents = std::move(extents); node->condition = std::move(condition); node->body = std::move(body); + node->span = std::move(span); data_ = std::move(node); } @@ -300,7 +312,9 @@ int32_t AllocateNode::constant_allocation_size(const Array& extents) { TVM_REGISTER_GLOBAL("tir.Allocate") .set_body_typed([](Var buffer_var, DataType type, Array extents, PrimExpr condition, - Stmt body) { return Allocate(buffer_var, type, extents, condition, body); }); + Stmt body, Span span) { + return Allocate(buffer_var, type, extents, condition, body, span); + }); TVM_REGISTER_NODE_TYPE(AllocateNode); @@ -324,7 +338,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // ProducerRealize ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, - Stmt body) { + Stmt body, Span span) { for (size_t i = 0; i < bounds.size(); ++i) { ICHECK(bounds[i]->min.defined()); ICHECK(bounds[i]->extent.defined()); @@ -340,12 +354,14 @@ ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr node->bounds = std::move(bounds); node->condition = std::move(condition); node->body = std::move(body); + node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.ProducerRealize") - .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body) { - return ProducerRealize(producer, bounds, condition, body); + .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body, + Span span) { + return ProducerRealize(producer, bounds, condition, body, span); }); TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); @@ -379,13 +395,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Prefetch -Prefetch::Prefetch(Buffer buffer, Array bounds) { - data_ = make_object(buffer, bounds); +Prefetch::Prefetch(Buffer buffer, Array bounds, Span span) { + data_ = make_object(buffer, bounds, span); } -TVM_REGISTER_GLOBAL("tir.Prefetch").set_body_typed([](Buffer buffer, Array bounds) { - return Prefetch(buffer, bounds); -}); +TVM_REGISTER_GLOBAL("tir.Prefetch") + .set_body_typed([](Buffer buffer, Array bounds, Span span) { + return Prefetch(buffer, bounds, span); + }); TVM_REGISTER_NODE_TYPE(PrefetchNode); @@ -406,14 +423,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // SeqStmt -SeqStmt::SeqStmt(Array seq) { +SeqStmt::SeqStmt(Array seq, Span span) { auto node = make_object(); node->seq = std::move(seq); + node->span = std::move(span); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array seq) { - return SeqStmt(std::move(seq)); +TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array seq, Span span) { + return SeqStmt(std::move(seq), span); }); TVM_REGISTER_NODE_TYPE(SeqStmtNode); @@ -427,7 +445,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // IfThenElse -IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case) { +IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case, Span span) { ICHECK(condition.defined()); ICHECK(then_case.defined()); // else_case may be null. @@ -435,14 +453,15 @@ IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case) { node->condition = std::move(condition); node->then_case = std::move(then_case); node->else_case = std::move(else_case); + node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_NODE_TYPE(IfThenElseNode); TVM_REGISTER_GLOBAL("tir.IfThenElse") - .set_body_typed([](PrimExpr condition, Stmt then_case, Stmt else_case) { - return IfThenElse(condition, then_case, else_case); + .set_body_typed([](PrimExpr condition, Stmt then_case, Stmt else_case, Span span) { + return IfThenElse(condition, then_case, else_case, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -477,15 +496,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Evaluate -Evaluate::Evaluate(PrimExpr value) { +Evaluate::Evaluate(PrimExpr value, Span span) { ICHECK(value.defined()); ObjectPtr node = make_object(); node->value = std::move(value); + node->span = std::move(span); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value) { return Evaluate(value); }); +TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span) { + return Evaluate(value, span); +}); TVM_REGISTER_NODE_TYPE(EvaluateNode); @@ -498,17 +520,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // BufferStore -BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices) { +BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices, Span span) { ObjectPtr node = make_object(); node->buffer = std::move(buffer); node->value = std::move(value); node->indices = std::move(indices); + node->span = std::move(span); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.BufferStore") - .set_body_typed([](Buffer buffer, PrimExpr value, Array indices) { - return BufferStore(buffer, value, indices); + .set_body_typed([](Buffer buffer, PrimExpr value, Array indices, Span span) { + return BufferStore(buffer, value, indices, span); }); TVM_REGISTER_NODE_TYPE(BufferStoreNode); @@ -529,14 +552,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // BufferRealize -BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body) { - data_ = make_object(buffer, bounds, condition, body); +BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body, + Span span) { + data_ = make_object(buffer, bounds, condition, body, span); } TVM_REGISTER_GLOBAL("tir.BufferRealize") - .set_body_typed([](Buffer buffer, Array bounds, PrimExpr condition, Stmt body) { - return BufferRealize(buffer, bounds, condition, body); - }); + .set_body_typed([](Buffer buffer, Array bounds, PrimExpr condition, Stmt body, + Span span) { return BufferRealize(buffer, bounds, condition, body, span); }); TVM_REGISTER_NODE_TYPE(BufferRealizeNode); @@ -568,9 +591,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "}\n"; }); -PrimExpr TypeAnnotation(DataType dtype) { +PrimExpr TypeAnnotation(DataType dtype, Span span) { static auto op = Op::Get("tir.type_annotation"); - return tir::Call(dtype, op, {}); + return tir::Call(dtype, op, {}, span); } TVM_REGISTER_OP("tir.type_annotation") diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 71321d2a3b02..1a6df556876d 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -922,4 +922,8 @@ TVM_REGISTER_GLOBAL("tir._OpIfThenElse") return if_then_else(cond, true_value, false_value); }); +TVM_REGISTER_GLOBAL("tir.const_true").set_body_typed([](DataType t) { + return const_true(t.lanes()); +}); + } // namespace tvm