From 5f52b1559a9faeb6099b53fae19537476fe05148 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 14 Feb 2023 01:50:02 -0500 Subject: [PATCH] [Diagnostic] Support constructing Diagnostic Error through ObjectRef (#13977) This PR supports creating a diagnostic error from an arbitrary Object. Given that we are bringing the diagnostic error for general uses in the long term, in which case * not every Expr necessarily has a span, * we have well-implemented elegant location-aware printer for an object itself, * we may need to print some object other than Expr, or does even not have a span field, we support diagnostic error with arbitrary object to denote the location. Co-authored-by: Tianqi Chen --- include/tvm/ir/diagnostic.h | 27 +++++++++++++++++++++++++++ src/ir/diagnostic.cc | 30 +++++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h index 3b2407491f263..add96d713bf8e 100644 --- a/include/tvm/ir/diagnostic.h +++ b/include/tvm/ir/diagnostic.h @@ -56,6 +56,14 @@ class DiagnosticNode : public Object { DiagnosticLevel level; /*! \brief The span at which to report an error. */ Span span; + /*! + * \brief The object location at which to report an error. + * + * The object loc provides a location when span is not always + * available during transformation. The error reporter can + * still pick up loc->span if necessary. + */ + ObjectRef loc; /*! \brief The diagnostic message. */ String message; @@ -84,6 +92,18 @@ class Diagnostic : public ObjectRef { static DiagnosticBuilder Warning(Span span); static DiagnosticBuilder Note(Span span); static DiagnosticBuilder Help(Span span); + // variants uses object location + static DiagnosticBuilder Bug(ObjectRef loc); + static DiagnosticBuilder Error(ObjectRef loc); + static DiagnosticBuilder Warning(ObjectRef loc); + static DiagnosticBuilder Note(ObjectRef loc); + static DiagnosticBuilder Help(ObjectRef loc); + // variants uses object ptr. + static DiagnosticBuilder Bug(const Object* loc); + static DiagnosticBuilder Error(const Object* loc); + static DiagnosticBuilder Warning(const Object* loc); + static DiagnosticBuilder Note(const Object* loc); + static DiagnosticBuilder Help(const Object* loc); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Diagnostic, ObjectRef, DiagnosticNode); }; @@ -102,6 +122,11 @@ class DiagnosticBuilder { /*! \brief The span of the diagnostic. */ Span span; + /*! + * \brief The object location at which to report an error. + */ + ObjectRef loc; + template DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*) stream_ << val; @@ -115,6 +140,8 @@ class DiagnosticBuilder { DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {} + DiagnosticBuilder(DiagnosticLevel level, ObjectRef loc) : level(level), loc(loc) {} + operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); } private: diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 6687a28d8c84d..9245ec9c0b2fe 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -66,6 +66,34 @@ DiagnosticBuilder Diagnostic::Help(Span span) { return DiagnosticBuilder(DiagnosticLevel::kHelp, span); } +DiagnosticBuilder Diagnostic::Bug(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kBug, loc); +} + +DiagnosticBuilder Diagnostic::Error(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kError, loc); +} + +DiagnosticBuilder Diagnostic::Warning(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kWarning, loc); +} + +DiagnosticBuilder Diagnostic::Note(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kNote, loc); +} + +DiagnosticBuilder Diagnostic::Help(ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kHelp, loc); +} + +DiagnosticBuilder Diagnostic::Bug(const Object* loc) { return Bug(GetRef(loc)); } + +DiagnosticBuilder Diagnostic::Error(const Object* loc) { return Error(GetRef(loc)); } + +DiagnosticBuilder Diagnostic::Note(const Object* loc) { return Note(GetRef(loc)); } + +DiagnosticBuilder Diagnostic::Help(const Object* loc) { return Help(GetRef(loc)); } + /* Diagnostic Renderer */ TVM_REGISTER_NODE_TYPE(DiagnosticRendererNode); @@ -284,7 +312,7 @@ DiagnosticRenderer TerminalRenderer(std::ostream& out) { }); } -TVM_REGISTER_GLOBAL(DEFAULT_RENDERER).set_body_typed([]() { return TerminalRenderer(std::cout); }); +TVM_REGISTER_GLOBAL(DEFAULT_RENDERER).set_body_typed([]() { return TerminalRenderer(std::cerr); }); TVM_REGISTER_GLOBAL("diagnostics.GetRenderer").set_body_typed([]() { return GetRenderer(); });