diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h index b9dcd351739c5..9b0c21d1ada6f 100644 --- a/include/tvm/ir/diagnostic.h +++ b/include/tvm/ir/diagnostic.h @@ -45,6 +45,7 @@ namespace tvm { using tvm::parser::SourceMap; +using tvm::runtime::TypedPackedFunc; static const char* kTVM_INTERNAL_ERROR_MESSAGE = "An internal invariant was violated during the execution of TVM" \ "please read TVM's error reporting guidelines at discuss.tvm.ai/thread"; @@ -178,6 +179,38 @@ class DiagnosticBuilder { */ class DiagnosticContext; +/*! \brief Display diagnostics in a given display format. + * + * A diagnostic renderer is responsible for converting the + * raw diagnostics into consumable output. + * + * For example the terminal renderer will render a sequence + * of compiler diagnostics to std::out and std::err in + * a human readable form. + */ +class DiagnosticRendererNode : public Object { +public: + IRModule module; + TypedPackedFunc renderer; + + static constexpr const char* _type_key = "DiagnosticRenderer"; + TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticRendererNode, Object); +}; + +class DiagnosticRenderer : public ObjectRef { + public: + TVM_DLL DiagnosticRenderer(const IRModule& mod, TypedPackedFunc renderer); + + void Render(const DiagnosticContext& ctx); + + DiagnosticRendererNode* operator->() { + CHECK(get() != nullptr); + return static_cast(get_mutable()); + } + + TVM_DEFINE_OBJECT_REF_METHODS(DiagnosticRenderer, ObjectRef, DiagnosticRendererNode); +}; + class DiagnosticContextNode : public Object { public: /*! \brief The source to report against. */ @@ -186,6 +219,9 @@ class DiagnosticContextNode : public Object { /*! \brief The set of diagnostics to report. */ Array diagnostics; + /*! \brief The renderer set for the context. */ + DiagnosticRenderer renderer; + void VisitAttrs(AttrVisitor* v) { v->Visit("source_map", &source_map); v->Visit("diagnostics", &diagnostics); @@ -199,18 +235,17 @@ class DiagnosticContextNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticContextNode, Object); }; -class DiagnosticRenderer; - class DiagnosticContext : public ObjectRef { public: - TVM_DLL DiagnosticContext(const SourceMap& source_map); + TVM_DLL DiagnosticContext(const SourceMap& source_map, const DiagnosticRenderer& renderer); + TVM_DLL static DiagnosticContext Default(const SourceMap& source_map); /*! \brief Emit a diagnostic. */ void Emit(const Diagnostic& diagnostic); /*! \brief Emit a diagnostic. */ void EmitFatal(const Diagnostic& diagnostic); /*! \brief Render the errors and raise a DiagnosticError exception. */ - void Render(const DiagnosticRenderer& renderer); + void Render(); DiagnosticContextNode* operator->() { CHECK(get() != nullptr); @@ -220,35 +255,7 @@ class DiagnosticContext : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(DiagnosticContext, ObjectRef, DiagnosticContextNode); }; -/*! \brief Display diagnostics in a given display format. - * - * A diagnostic renderer is responsible for converting the - * raw diagnostics into consumable output. - * - * For example the terminal renderer will render a sequence - * of compiler diagnostics to std::out and std::err in - * a human readable form. - */ -class DiagnosticRenderer { -public: - IRModule module; - - DiagnosticRenderer(const IRModule& mod) - module(mod) {} - - virtual void Render(const DiagnosticContext& ctx) = 0; -}; - -class TerminalRenderer : public DiagnosticRenderer { - IRModule module; - DiagnosticContext ctx; - std::ostream& ostream; - - TerminalRenderer(const IRModule& mod, std::ostream& ostream) : - DiagnosticRenderer(mod), ostream(ostream) {} - - void Render(const DiagnosticContext& ctx) override; -}; +DiagnosticRenderer TerminalRenderer(const IRModule& mod, std::ostream& ostream); } // namespace tvm #endif // TVM_IR_DIAGNOSTIC_H_ diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index cbe74058b8f5a..81e040d916e32 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -28,6 +28,8 @@ namespace tvm { using tvm::parser::Source; +TVM_REGISTER_OBJECT_TYPE(DiagnosticNode); + Diagnostic::Diagnostic(DiagnosticLevel level, Span span, const std::string& message) { auto n = make_object(); n->level = level; @@ -56,9 +58,22 @@ DiagnosticBuilder Diagnostic::Help(Span span) { return DiagnosticBuilder(DiagnosticLevel::kHelp, span); } -DiagnosticContext::DiagnosticContext(const SourceMap& source_map) { +TVM_REGISTER_OBJECT_TYPE(DiagnosticRendererNode); + +void DiagnosticRenderer::Render(const DiagnosticContext& ctx) { + (*this)->renderer((*this)->module, ctx); +} + +TVM_REGISTER_OBJECT_TYPE(DiagnosticContextNode); + +void DiagnosticContext::Render() { + (*this)->renderer.Render(*this); +} + +DiagnosticContext::DiagnosticContext(const SourceMap& source_map, const DiagnosticRenderer& renderer) { auto n = make_object(); n->source_map = source_map; + n->renderer = renderer; data_ = std::move(n); } @@ -70,11 +85,7 @@ void DiagnosticContext::Emit(const Diagnostic& diagnostic) { /*! \brief Emit a diagnostic. */ void DiagnosticContext::EmitFatal(const Diagnostic& diagnostic) { Emit(diagnostic); - Render(std::cout); -} - -void DiagnosticContext::Render(const DiagnosticRenderer& renderer) { - renderer->Render(); + Render(); } /*! \brief Generate an error message at a specific line and column with the @@ -125,17 +136,24 @@ void ReportAt(const DiagnosticContext& context, std::ostream& out, const Span& s out << std::endl; } + // TODO(@jroesch): eventually modularize the rendering interface to provide control of how to // format errors. -void TerminalRenderer::Render(const DiagnosticContext& ctx) override { - for (auto diagnostic : this->ctx->diagnostics) { - ReportAt(this->ctx, ostream, diagnostic->span, diagnostic->message); - } +DiagnosticRenderer TerminalRenderer(const IRModule& mod, std::ostream& out) { + return DiagnosticRenderer(mod, [&](const IRModule& mod, const DiagnosticContext& ctx) { + for (auto diagnostic : ctx->diagnostics) { + ReportAt(ctx, out, diagnostic->span, diagnostic->message); + } - if (this->ctx->diagnostics.size()) { + if (ctx->diagnostics.size()) { LOG(FATAL) << "DiagnosticError: one or more error diagnostics were " - << "emitted, please check diagnostic render for output."; + << "emitted, please check diagnostic render for output."; } + }); } +TVM_REGISTER_GLOBAL("diagnostics.DefaultRenderer").set_body_typed([](const IRModule& mod) { + return TerminalRenderer(mod, std::cout); +}); + } diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 7078b4e6c13c7..c1b655be88d4a 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1654,7 +1654,7 @@ std::pair InitDiagnosticContext(const std::string& fi // TODO(@jroesch): I think we need to have a module local source map. auto source_map = SourceMap::Global(); source_map.Add(source); - return { source, DiagnosticContext(source_map) }; + return { source, DiagnosticContext::Default(source_map) }; } IRModule ParseModule(std::string file_name, std::string file_content) { @@ -1670,7 +1670,7 @@ IRModule ParseModule(std::string file_name, std::string file_content) { // NB(@jroesch): it is very important that we render any errors before we procede // if there were any errors which allow the parser to procede we must render them // here. - parser.diag_ctx->Render(std::cout); + parser.diag_ctx->Render(); return tvm::relay::transform::InferType()(mod); } @@ -1690,7 +1690,7 @@ Expr ParseExpr(std::string file_name, std::string file_content) { // NB(@jroesch): it is very important that we render any errors before we procede // if there were any errors which allow the parser to procede we must render them // here. - parser.diag_ctx->Render(std::cout); + parser.diag_ctx->Render(); return expr; }