diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h index 3b2407491f263..add96d713bf8e 100644 --- a/include/tvm/ir/diagnostic.h +++ b/include/tvm/ir/diagnostic.h @@ -56,6 +56,14 @@ class DiagnosticNode : public Object { DiagnosticLevel level; /*! \brief The span at which to report an error. */ Span span; + /*! + * \brief The object location at which to report an error. + * + * The object loc provides a location when span is not always + * available during transformation. The error reporter can + * still pick up loc->span if necessary. + */ + ObjectRef loc; /*! \brief The diagnostic message. */ String message; @@ -84,6 +92,18 @@ class Diagnostic : public ObjectRef { static DiagnosticBuilder Warning(Span span); static DiagnosticBuilder Note(Span span); static DiagnosticBuilder Help(Span span); + // variants uses object location + static DiagnosticBuilder Bug(ObjectRef loc); + static DiagnosticBuilder Error(ObjectRef loc); + static DiagnosticBuilder Warning(ObjectRef loc); + static DiagnosticBuilder Note(ObjectRef loc); + static DiagnosticBuilder Help(ObjectRef loc); + // variants uses object ptr. + static DiagnosticBuilder Bug(const Object* loc); + static DiagnosticBuilder Error(const Object* loc); + static DiagnosticBuilder Warning(const Object* loc); + static DiagnosticBuilder Note(const Object* loc); + static DiagnosticBuilder Help(const Object* loc); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Diagnostic, ObjectRef, DiagnosticNode); }; @@ -102,6 +122,11 @@ class DiagnosticBuilder { /*! \brief The span of the diagnostic. */ Span span; + /*! + * \brief The object location at which to report an error. + */ + ObjectRef loc; + template DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*) stream_ << val; @@ -115,6 +140,8 @@ class DiagnosticBuilder { DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {} + DiagnosticBuilder(DiagnosticLevel level, ObjectRef loc) : level(level), loc(loc) {} + operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); } private: diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 6687a28d8c84d..9245ec9c0b2fe 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -66,6 +66,34 @@ DiagnosticBuilder Diagnostic::Help(Span span) { return DiagnosticBuilder(DiagnosticLevel::kHelp, span); } +DiagnosticBuilder Diagnostic::Bug(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kBug, loc); +} + +DiagnosticBuilder Diagnostic::Error(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kError, loc); +} + +DiagnosticBuilder Diagnostic::Warning(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kWarning, loc); +} + +DiagnosticBuilder Diagnostic::Note(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kNote, loc); +} + +DiagnosticBuilder Diagnostic::Help(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kHelp, loc); +} + +DiagnosticBuilder Diagnostic::Bug(const Object* loc) { return Bug(GetRef(loc)); } + +DiagnosticBuilder Diagnostic::Error(const Object* loc) { return Error(GetRef(loc)); } + +DiagnosticBuilder Diagnostic::Note(const Object* loc) { return Note(GetRef(loc)); } + +DiagnosticBuilder Diagnostic::Help(const Object* loc) { return Help(GetRef(loc)); } + /* Diagnostic Renderer */ TVM_REGISTER_NODE_TYPE(DiagnosticRendererNode); @@ -284,7 +312,7 @@ DiagnosticRenderer TerminalRenderer(std::ostream& out) { }); } -TVM_REGISTER_GLOBAL(DEFAULT_RENDERER).set_body_typed([]() { return TerminalRenderer(std::cout); }); +TVM_REGISTER_GLOBAL(DEFAULT_RENDERER).set_body_typed([]() { return TerminalRenderer(std::cerr); }); TVM_REGISTER_GLOBAL("diagnostics.GetRenderer").set_body_typed([]() { return GetRenderer(); });