Skip to content

Commit

Permalink
[tir] Add line level debug info (apache#13012)
Browse files Browse the repository at this point in the history
* TIR debug info

* Fix location emission

* Comments 1/N (docs, cleanups)

* Remove leaky macro usage

* Add unit test

* Remove dead code

* Add accuracy test

Co-authored-by: driazati <[email protected]>
2 people authored and fzi-peccia committed Mar 27, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent f2bbb7e commit 391b659
Showing 16 changed files with 762 additions and 89 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -271,3 +271,9 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py

# Used in CI to communicate between Python and Jenkins
.docker-image-names/

# Printed TIR code on disk
*.tir

# GDB history file
.gdb_history
7 changes: 7 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
@@ -496,6 +496,13 @@ TVM_DLL Pass LowerAsyncDMA();
*/
TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);

/*!
* \brief Add TIR-printer output as debug information to all ops in the module
* \return The pass.
*/

TVM_DLL Pass InstallDebugSpans();

/*!
* \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and
* "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g.,
12 changes: 12 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
@@ -1039,3 +1039,15 @@ def InstrumentProfileIntrinsics():
The result pass
"""
return _ffi_api.InstrumentProfileIntrinsics() # type: ignore


def InstallDebugSpans():
"""Add line information from the TIR printer as spans on each statement and
expression.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InstallDebugSpans() # type: ignore
8 changes: 8 additions & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
@@ -603,6 +604,9 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
});

transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) {
transform::PassContext pass_ctx = transform::PassContext::Current();
bool enable_debug = pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value();

Array<tvm::transform::Pass> host_pass_list;

runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
@@ -621,6 +625,10 @@ transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_ho
host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
host_pass_list.push_back(tir::transform::CombineContextCall());

if (enable_debug) {
host_pass_list.push_back(tir::transform::InstallDebugSpans());
}

return transform::Sequential(host_pass_list);
}

1 change: 1 addition & 0 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
@@ -440,6 +440,7 @@ Pass GetPass(const String& pass_name) {
// ordering problem needs to be handled in the future.
IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
for (const Pass& pass : passes) {
VLOG(0) << "Running pass " << pass->Info()->name;
ICHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
if (!pass_ctx.PassEnabled(pass_info)) {
40 changes: 22 additions & 18 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
@@ -280,6 +280,9 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta)
: show_meta_(show_meta), meta_(meta), meta_collector_(meta) {}

/*! \brief Output a newline */
virtual Doc NewLine();

/*! \brief Print the node */
Doc Print(const ObjectRef& node);

@@ -290,24 +293,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
*/
bool GetVarName(::tvm::tir::Var v, std::string* s);

private:
/*! \brief whether show meta data */
bool show_meta_;
/*! \brief meta data context */
TextMetaDataContext* meta_;
/*! \brief meta collector */
MetaCollector meta_collector_;
/*! \brief Map from Var to Doc */
std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<DataProducer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_producer_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;

friend class tvm::TextPrinter;

protected:
Doc VisitExpr_(const IntImmNode* op) override;
Doc VisitExpr_(const FloatImmNode* op) override;
Doc VisitExpr_(const StringImmNode* op) override;
@@ -363,6 +349,24 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const BlockRealizeNode* op) override;
Doc VisitStmtDefault_(const Object* op) override;

private:
/*! \brief whether show meta data */
bool show_meta_;
/*! \brief meta data context */
TextMetaDataContext* meta_;
/*! \brief meta collector */
MetaCollector meta_collector_;
/*! \brief Map from Var to Doc */
std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<DataProducer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_producer_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;

friend class tvm::TextPrinter;

Doc VisitType_(const PrimTypeNode* node) override;
Doc VisitType_(const PointerTypeNode* node) override;
Doc VisitType_(const TupleTypeNode* node) override;
53 changes: 27 additions & 26 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
@@ -124,7 +124,7 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
for (const auto& it : op->attrs->dict) {
attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
}
attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
attr_doc << NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
doc << Doc::Indent(2, attr_doc);
}

@@ -136,8 +136,8 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
const Buffer buf = op->buffer_map[v];
buffer_docs.push_back(BufferNode2Doc(buf.get(), Print(buf)));
}
buffer_doc << Doc::NewLine() << "buffers = {";
buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine()));
buffer_doc << NewLine() << "buffers = {";
buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << NewLine()));
doc << Doc::Indent(2, buffer_doc) << "}";
}

@@ -149,26 +149,28 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
buffer_map_doc.push_back(Print(v) << ": " << Print(buf));
}
doc << Doc::Indent(
2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
2, NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
}

doc << PrintBody(op->body);
return doc;
}

Doc TIRTextPrinter::NewLine() { return Doc::NewLine(); }

Doc TIRTextPrinter::PrintIRModule(const IRModule& module) {
const auto* op = module.operator->();
Doc doc;

Doc body;
body << Doc::NewLine();
body << NewLine();
std::vector<Doc> functions;
for (auto it = op->functions.begin(); it != op->functions.end(); ++it) {
if ((*it).second.as<PrimFuncNode>()) {
functions.push_back(Print((*it).second));
}
}
body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine());
body << TIRTextPrinter::PrintSep(functions, NewLine() << NewLine());
doc << Doc::Indent(0, body);
return doc;
}
@@ -451,7 +453,7 @@ Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) {

Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) {
Doc doc;
doc << "let " << Print(op->var) << " = " << Print(op->value) << Doc::NewLine() << Print(op->body);
doc << "let " << Print(op->var) << " = " << Print(op->value) << NewLine() << Print(op->body);
return doc;
}

@@ -463,14 +465,14 @@ Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) {
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
doc << ";" << NewLine() << Print(op->body);
}
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) {
Doc doc;
doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << Doc::NewLine()
doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << NewLine()
<< Print(op->body);
return doc;
}
@@ -529,7 +531,7 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
doc << ";" << NewLine() << Print(op->body);
}
return doc;
}
@@ -542,19 +544,19 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) {
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
doc << ";" << NewLine() << Print(op->body);
}
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) {
Doc doc;
doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->data) << ", "
<< PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << Doc::NewLine();
<< PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << NewLine();
if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
doc << ";" << NewLine() << Print(op->body);
}
return doc;
}
@@ -572,9 +574,9 @@ Doc TIRTextPrinter::VisitStmt_(const SeqStmtNode* op) {
std::vector<Doc> stmts;
Doc seq_doc, doc;
for (Stmt stmt : op->seq) {
seq_doc << Doc::NewLine() << Print(stmt);
seq_doc << NewLine() << Print(stmt);
}
doc << " {" << Doc::Indent(2, seq_doc) << Doc::NewLine() << "}";
doc << " {" << Doc::Indent(2, seq_doc) << NewLine() << "}";
return doc;
}

@@ -657,37 +659,36 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) {
Doc block_attr_doc;
// print predicate, binding, read/write tensor region, annotations
if (!is_one(op->predicate)) {
block_attr_doc << Doc::NewLine() << "where(" << Print(op->predicate) << ")";
block_attr_doc << NewLine() << "where(" << Print(op->predicate) << ")";
}
for (size_t i = 0; i < block_op->iter_vars.size(); ++i)
block_attr_doc << Doc::NewLine() << "bind(" << Print(block_op->iter_vars[i]->var) << ", "
block_attr_doc << NewLine() << "bind(" << Print(block_op->iter_vars[i]->var) << ", "
<< Print(op->iter_values[i]) << ")";
block_attr_doc << Doc::NewLine() << "tir.reads(" << Print(block_op->reads) << ")";
block_attr_doc << Doc::NewLine() << "tir.writes(" << Print(block_op->writes) << ")";
block_attr_doc << NewLine() << "tir.reads(" << Print(block_op->reads) << ")";
block_attr_doc << NewLine() << "tir.writes(" << Print(block_op->writes) << ")";
if (!block_op->annotations.empty()) {
std::vector<Doc> attr_docs;
for (const auto& it : block_op->annotations) {
attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
}
block_attr_doc << Doc::NewLine() << "tir.attrs({" << PrintSep(attr_docs, Doc::Text(", "))
<< "})";
block_attr_doc << NewLine() << "tir.attrs({" << PrintSep(attr_docs, Doc::Text(", ")) << "})";
}
// print body
Doc body;
body << Doc::NewLine();
body << NewLine();
for (const auto& alloc_buf : block_op->alloc_buffers) {
body << AllocBuf(alloc_buf) << " = alloc_buffer(" << PrintDType(alloc_buf->dtype)
<< Print(alloc_buf->shape) << ")" << Doc::NewLine();
<< Print(alloc_buf->shape) << ")" << NewLine();
}
for (const auto& match_buf : block_op->match_buffers) {
body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")"
<< Doc::NewLine();
<< NewLine();
}
if (block_op->init.defined()) {
Doc init_block;
init_block << "with init()";
init_block << PrintBody(block_op->init.value());
body << init_block << Doc::NewLine();
body << init_block << NewLine();
}
body << Print(block_op->body);
doc << Doc::Indent(2, block_attr_doc << body);
@@ -826,7 +827,7 @@ Doc TIRTextPrinter::PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) {
Doc doc;
if (body->IsInstance<SeqStmtNode>()) return Print(body);
doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}";
doc << " {" << Doc::Indent(2, NewLine() << Print(body)) << NewLine() << "}";
return doc;
}

97 changes: 97 additions & 0 deletions src/printer/tir_text_printer_debug.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tir_text_printer.cc
* \brief Printer to print out the IR text format
* that can be parsed by a parser.
*/

#include "tir_text_printer_debug.h"

#include <optional>
#include <string>

namespace tvm {
namespace tir {

std::optional<std::string> span_text(const Span& span) {
if (!span.defined()) {
return std::nullopt;
}

std::string source("main.tir");
if (span->source_name.defined() && span->source_name->name.get()) {
source = span->source_name->name;
}
return source + ":" + std::to_string(span->line) + ":" + std::to_string(span->column);
}

template <typename ObjectPtr>
void add_all_relevant_lines(const std::vector<std::tuple<const ObjectPtr*, size_t>>& data,
size_t current_line, Doc* output) {
ICHECK(output) << "output must be a valid Doc";
for (const auto& item : data) {
if (std::get<1>(item) != current_line - 1) {
// Item is not relevant for this line, skip it
continue;
}

// Print out the item's span info if present
auto text = span_text(std::get<0>(item)->span);
if (text.has_value()) {
*output << *text;
} else {
*output << "missing";
}
*output << ", ";
}
}

Doc TIRTextPrinterDebug::NewLine() {
current_line_ += 1;

if (!show_spans_) {
return TIRTextPrinter::NewLine();
}

Doc output;

output << " [";

add_all_relevant_lines(exprs_by_line_, current_line_, &output);
add_all_relevant_lines(stmts_by_line_, current_line_, &output);

output << "]" << TIRTextPrinter::NewLine();

return output;
}

Doc TIRTextPrinterDebug::VisitStmt(const tvm::tir::Stmt& n) {
stmts_by_line_.push_back(std::make_tuple(n.get(), current_line_));
return TIRTextPrinter::VisitStmt(n);
}

Doc TIRTextPrinterDebug::VisitExpr(const PrimExpr& e) {
exprs_by_line_.push_back(std::make_tuple(e.get(), current_line_));
return TIRTextPrinter::VisitExpr(e);
}

} // namespace tir
} // namespace tvm
70 changes: 70 additions & 0 deletions src/printer/tir_text_printer_debug.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file text_printer.h
* \brief Printer to print out the unified IR text format
* that can be parsed by a parser.
*/

#ifndef TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
#define TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_

#include <tuple>
#include <vector>

#include "text_printer.h"

namespace tvm {
namespace tir {

class TIRTextPrinterDebug : public TIRTextPrinter {
public:
explicit TIRTextPrinterDebug(bool show_spans)
: TIRTextPrinter(false, &meta_), current_line_(1), show_spans_(show_spans) {}

std::vector<std::tuple<const PrimExprNode*, size_t>> GetExprsByLine() const {
return exprs_by_line_;
}

std::vector<std::tuple<const StmtNode*, size_t>> GetStmtsByLine() const { return stmts_by_line_; }

private:
Doc NewLine() override;

Doc VisitStmt(const tvm::tir::Stmt& n) override;
Doc VisitExpr(const PrimExpr& e) override;

TextMetaDataContext meta_;

// Line that the printer is currently printing
size_t current_line_;

// Whether to include spans relevant to each line before a newline or not
bool show_spans_;

// Record of all stmts and exprs and their corresponding line
std::vector<std::tuple<const StmtNode*, size_t>> stmts_by_line_;
std::vector<std::tuple<const PrimExprNode*, size_t>> exprs_by_line_;
};

} // namespace tir
} // namespace tvm

#endif // TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
84 changes: 49 additions & 35 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
@@ -183,57 +183,63 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, b
InitGlobalContext(dynamic_lookup);
}

void CodeGenCPU::AddFunction(const PrimFunc& f) {
CodeGenLLVM::AddFunction(f);
if (f_tvm_register_system_symbol_ != nullptr) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
export_system_symbols_.emplace_back(
std::make_pair(global_symbol.value().operator std::string(), function_));
}
AddDebugInformation(f, function_);
}

// Following Glow |DebugInfo::generateFunctionDebugInfo|, https://git.io/fjadv
void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) {
llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const PrimFunc& f) {
#if TVM_LLVM_VERSION >= 50
ICHECK(!f_llvm->getSubprogram());
llvm::SmallVector<llvm::Metadata*, 4> paramTys;
// Functions in TIR can only return void or an int.
ICHECK(f_llvm->getReturnType() == t_void_ || f_llvm->getReturnType() == t_int_)
<< "Unexpected return type";
auto ret_type_tir = f_llvm->getReturnType() == t_int_ ? DataType::Int(32) : DataType::Void();
llvm::DIType* returnTy =
GetDebugType(GetTypeFromRuntimeDataType(ret_type_tir), f_llvm->getReturnType());
paramTys.push_back(returnTy);
for (size_t i = 0; i < f_llvm->arg_size(); ++i) {
paramTys.push_back(
GetDebugType(GetType(f_tir->params[i]), f_llvm->getFunctionType()->getParamType(i)));

paramTys.push_back(GetDebugType(f->ret_type));
for (const auto& param : f->params) {
paramTys.push_back(GetDebugType(GetType(param)));
}

auto* DIFunctionTy = dbg_info_->di_builder_->createSubroutineType(
dbg_info_->di_builder_->getOrCreateTypeArray(paramTys));

bool local_to_unit = llvm::GlobalValue::isLocalLinkage(f_llvm->getLinkage());
bool local_to_unit = llvm::GlobalVariable::isLocalLinkage(llvm::GlobalValue::InternalLinkage);

// TODO(driazati): determine the IRModule name instead of hardcoding 'main.tir'
#if TVM_LLVM_VERSION >= 80
auto SPFlags =
llvm::DISubprogram::toSPFlags(local_to_unit, /*IsDefinition=*/true, /*IsOptimized=*/true);
auto SPFlags = llvm::DISubprogram::toSPFlags(local_to_unit, /*IsDefinition=*/true,
/*IsOptimized=*/true);
auto* DIFunction = dbg_info_->di_builder_->createFunction(
/*Scope=*/dbg_info_->file_, /*Name=*/f_llvm->getName(), /*LinkageName=*/"",
/*Scope=*/dbg_info_->file_, /*Name=*/"main.tir", /*LinkageName=*/"",
/*File=*/dbg_info_->file_, /*LineNo=*/0, /*Ty=*/DIFunctionTy,
/*ScopeLine=*/0, /*Flags=*/llvm::DINode::FlagZero, /*SPFlags=*/SPFlags);
#else
auto* DIFunction = dbg_info_->di_builder_->createFunction(
/*Scope=*/dbg_info_->file_, /*Name=*/f_llvm->getName(), /*LinkageName=*/"",
/*Scope=*/dbg_info_->file_, /*Name=*/"main.tir", /*LinkageName=*/"",
/*File=*/dbg_info_->file_, /*LineNo=*/0, /*Ty=*/DIFunctionTy,
/*isLocalToUnit=*/local_to_unit, /*isDefinition=*/true, /*ScopeLine=*/0,
/*Flags=*/llvm::DINode::FlagPrototyped, /*isOptimized=*/true);
#endif
return DIFunction;
#else
return nullptr;
#endif
}

void CodeGenCPU::AddFunction(const PrimFunc& f) {
#if TVM_LLVM_VERSION >= 50
di_subprogram_ = CreateDebugFunction(f);
#endif
EmitDebugLocation(f->span);
CodeGenLLVM::AddFunction(f);
if (f_tvm_register_system_symbol_ != nullptr) {
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
export_system_symbols_.emplace_back(
std::make_pair(global_symbol.value().operator std::string(), function_));
}
AddDebugInformation(f, function_);
}

ICHECK(DIFunction);
f_llvm->setSubprogram(DIFunction);
ICHECK_EQ(f_llvm->getSubprogram(), DIFunction);
// Following Glow |DebugInfo::generateFunctionDebugInfo|, https://git.io/fjadv
void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) {
#if TVM_LLVM_VERSION >= 50
ICHECK(di_subprogram_);
f_llvm->setSubprogram(di_subprogram_);
ICHECK_EQ(f_llvm->getSubprogram(), di_subprogram_);

IRBuilder builder(&f_llvm->getEntryBlock());
if (!f_llvm->getEntryBlock().empty()) {
@@ -246,11 +252,11 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) {
auto* paramAlloca = builder.CreateAlloca(f_llvm->getFunctionType()->getParamType(i));
std::string paramName = "arg" + std::to_string(i + 1);
auto param = dbg_info_->di_builder_->createParameterVariable(
DIFunction, paramName, i + 1, dbg_info_->file_, 0,
di_subprogram_, paramName, i + 1, dbg_info_->file_, 0,
GetDebugType(GetType(f_tir->params[i]), f_llvm->getFunctionType()->getParamType(i)),
/*alwaysPreserve=*/true);
auto* store = builder.CreateStore(f_llvm->arg_begin() + i, paramAlloca);
auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, DIFunction);
auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, di_subprogram_);
dbg_info_->di_builder_->insertDeclare(paramAlloca, param,
dbg_info_->di_builder_->createExpression(),
llvm::DebugLoc(di_loc), store);
@@ -260,6 +266,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) {
if (!scope) {
return;
}

for (auto& BB : *f_llvm) {
for (auto& I : BB) {
if (I.getDebugLoc()) {
@@ -272,6 +279,9 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) {
#endif
}

llvm::DIType* CodeGenCPU::GetDebugType(const Type& ty_tir) {
return GetDebugType(ty_tir, GetLLVMType(ty_tir));
}
llvm::DIType* CodeGenCPU::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) {
if (ty_llvm == t_void_) {
return nullptr;
@@ -541,6 +551,7 @@ llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) {
}

void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
EmitDebugLocation(op);
/*! \brief maintain states that should be guarded when step into compute scope */
struct ComputeScopeStates {
explicit ComputeScopeStates(CodeGenCPU* parent) : parent_(parent) {}
@@ -1447,6 +1458,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
}

void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) {
EmitDebugLocation(op);
llvm::Value* cond = MakeValue(op->condition);
std::ostringstream os;
os << "Assert fail: " << op->condition;
@@ -1475,6 +1487,7 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) {
}

void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
EmitDebugLocation(op);
if (op->attr_key == tir::attr::coproc_uop_scope) {
const StringImmNode* value = op->value.as<StringImmNode>();
ICHECK(value != nullptr);
@@ -1517,6 +1530,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
}

void CodeGenCPU::VisitStmt_(const ForNode* op) {
EmitDebugLocation(op);
ICHECK(is_zero(op->min));
if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) {
CodeGenLLVM::VisitStmt_(op);
2 changes: 2 additions & 0 deletions src/target/llvm/codegen_cpu.h
Original file line number Diff line number Diff line change
@@ -164,6 +164,7 @@ class CodeGenCPU : public CodeGenLLVM {
// if not directly finalize function and pass on return code.
// return the end block after the check
llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
llvm::DISubprogram* CreateDebugFunction(const PrimFunc& f);
// Context for injection lookup
llvm::GlobalVariable* gv_mod_ctx_{nullptr};
llvm::GlobalVariable* gv_tvm_func_call_{nullptr};
@@ -194,6 +195,7 @@ class CodeGenCPU : public CodeGenLLVM {

// Get the DWARF type corresponding to the LLVM type |ty|. The current API in practice only
// generates |int32|, and |int8*|.
llvm::DIType* GetDebugType(const Type& ty_tir);
llvm::DIType* GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm);
// Adds the DWARF debug information for |function| to |dbg_info_|.
void AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm);
55 changes: 47 additions & 8 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
@@ -298,6 +298,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
}
#endif

EmitDebugLocation(f->span);
if (ret_void) {
builder_->CreateRetVoid();
} else {
@@ -556,6 +557,7 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const PrimExpr& expr) const {
//
void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer_var, PrimExpr index,
DataType access_dtype) {
EmitDebugLocation(index->span);
if (alias_var_set_.count(buffer_var) != 0) {
// Mark all possibly aliased pointer as same type.
llvm::MDNode* meta = md_tbaa_alias_set_;
@@ -663,12 +665,13 @@ std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Modul
debug_info->di_builder_ = llvm::make_unique<llvm::DIBuilder>(*module);
#endif
// TODO(tulloch): pass this information through relay::Span classes to the IRModule instance?
debug_info->file_ = debug_info->di_builder_->createFile("model.tvm", "/tmp/");
debug_info->file_ = debug_info->di_builder_->createFile("main.tir", ".");
const int runtime_version = 0;
const bool is_optimized = false;
const char* compiler_flags = "";
debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit(
llvm::dwarf::DW_LANG_C, debug_info->file_, "TVM", 0, "", 0, "",
llvm::DICompileUnit::DebugEmissionKind::FullDebug,
/* SplitDebugInlining */ true,
/* DebugInfoForProfiling */ true);
/*Lang=*/llvm::dwarf::DW_LANG_C, /*File=*/debug_info->file_, /*Producer=*/"TVM", is_optimized,
compiler_flags, runtime_version);
return debug_info;
}

@@ -789,6 +792,7 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {

void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride,
const Var& loop_var, const Stmt& body) {
EmitDebugLocation(body->span);
llvm::BasicBlock* pre_block = builder_->GetInsertBlock();
std::string loop_var_name = loop_var->name_hint;
llvm::LLVMContext* ctx = llvm_target_->GetContext();
@@ -802,8 +806,8 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va
loop_value->addIncoming(begin, pre_block);
ICHECK(!var_map_.count(loop_var.get()));
var_map_[loop_var.get()] = loop_value;
builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end,
md_very_likely_branch_);
auto lt = CreateLT(loop_var.dtype(), loop_value, end);
builder_->CreateCondBr(lt, for_body, for_end, md_very_likely_branch_);
builder_->SetInsertPoint(for_body);
this->VisitStmt(body);
var_map_.erase(loop_var.get());
@@ -916,6 +920,7 @@ llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const {

void CodeGenLLVM::CreatePrintf(const std::string& format,
llvm::ArrayRef<llvm::Value*> format_args) {
EmitDebugLocation();
llvm::Function* func_printf = module_->getFunction("printf");
if (func_printf == nullptr) {
llvm::FunctionType* ftype = llvm::FunctionType::get(t_int32_, true);
@@ -946,6 +951,7 @@ void CodeGenLLVM::CreatePrintf(const std::string& format,
}

llvm::Value* CodeGenLLVM::CreateLookupReturnAddress(unsigned int level) {
EmitDebugLocation();
llvm::Value* level_val = llvm::ConstantInt::get(t_int32_, level);
llvm::Function* builtin =
llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::returnaddress);
@@ -1755,6 +1761,7 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
}

void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
EmitDebugLocation(op);
DataType value_dtype = op->value.dtype();
Var buffer_var = op->buffer->data;

@@ -1781,6 +1788,7 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
}

void CodeGenLLVM::VisitStmt_(const ForNode* op) {
EmitDebugLocation(op);
ICHECK(is_zero(op->min));
analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
if (op->kind == ForKind::kUnrolled) {
@@ -1794,6 +1802,7 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
}

void CodeGenLLVM::VisitStmt_(const WhileNode* op) {
EmitDebugLocation(op);
llvm::LLVMContext* ctx = llvm_target_->GetContext();
auto* while_cond = llvm::BasicBlock::Create(*ctx, "while_cond", function_);
auto* while_body = llvm::BasicBlock::Create(*ctx, "while_body", function_);
@@ -1808,6 +1817,7 @@ void CodeGenLLVM::VisitStmt_(const WhileNode* op) {
}

void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
EmitDebugLocation(op);
llvm::Value* cond = MakeValue(op->condition);
llvm::LLVMContext* ctx = llvm_target_->GetContext();
auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_);
@@ -1831,6 +1841,7 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
}

void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) {
EmitDebugLocation(op);
auto data = op->data.value();
auto array = NDArrayToLLVMArray(llvm_target_->GetContext(), data);
std::string symbol_name = op->buffer_var->name_hint;
@@ -1842,6 +1853,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) {
}

void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
EmitDebugLocation(op);
ICHECK_EQ(op->extents.size(), 1)
<< "LLVM codegen only supports flat 1-d buffer allocation, but allocation of "
<< op->buffer_var->name_hint << " is " << op->extents << "-d";
@@ -1892,6 +1904,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
}

void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
EmitDebugLocation(op);
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag.length() != 0) {
@@ -1917,11 +1930,14 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
}

void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) {
EmitDebugLocation(op);
// auto a_cu =
With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
this->VisitStmt(op->body);
}

void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
EmitDebugLocation(op);
const VarNode* v = op->var.get();
ICHECK(!var_map_.count(v));
if (v->dtype.is_handle()) {
@@ -1941,12 +1957,35 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
}

void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) {
EmitDebugLocation(op);
for (Stmt stmt : op->seq) {
this->VisitStmt(stmt);
}
}

void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); }
void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) {
EmitDebugLocation(op);
MakeValue(op->value);
}

void CodeGenLLVM::EmitDebugLocation(const Span& span) {
#if TVM_LLVM_VERSION >= 50
if (di_subprogram_ == nullptr) {
// debug info is not always generated outside of CPU codegen
return;
}
if (!span.defined()) {
VLOG(0) << "Cannot emit debug location for undefined span";
return;
}
llvm::LLVMContext* ctx = llvm_target_->GetContext();
auto loc = llvm::DebugLoc(llvm::DILocation::get(*ctx, span->line, span->column, di_subprogram_));
builder_->SetCurrentDebugLocation(loc);
#endif
}

void CodeGenLLVM::EmitDebugLocation() { builder_->SetCurrentDebugLocation(nullptr); }
void CodeGenLLVM::EmitDebugLocation(const StmtNode* op) { EmitDebugLocation(op->span); }

} // namespace codegen
} // namespace tvm
10 changes: 8 additions & 2 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
@@ -37,6 +37,7 @@
#else
#include <llvm/IR/Operator.h>
#endif
#include <llvm/IR/DebugInfoMetadata.h>
#include <llvm/IR/GlobalValue.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
@@ -70,6 +71,7 @@
#include "../../runtime/thread_storage_scope.h"
#include "../../tir/transforms/ir_utils.h"
#include "codegen_params.h"
#include "llvm_instance.h"

namespace llvm {
class Argument;
@@ -92,8 +94,6 @@ class MDBuilder;
namespace tvm {
namespace codegen {

class LLVMTarget;

using namespace tir;

/*!
@@ -523,6 +523,8 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
ExprDeepEqual deep_equal_;
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
// debug info for function being compiled
llvm::DISubprogram* di_subprogram_;
// Cache potential common path ops to slightly improve lookup time.
// global symbol table.
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
@@ -533,6 +535,10 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
const Op& builtin_lookup_param_ = builtin::lookup_param();
const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered();

void EmitDebugLocation();
void EmitDebugLocation(const Span& span);
void EmitDebugLocation(const StmtNode* op);

/*! \brief Helper struct for debug infos. */
struct DebugInfo {
~DebugInfo(); // Because of the std::unique_ptr.
150 changes: 150 additions & 0 deletions src/tir/transforms/install_debug_spans.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file install_debug_spans.cc
* \brief Prints TIR code in memory and replaces all spans in the module with
the location to which the ops would be printed
*/

#include "install_debug_spans.h"

#include <tvm/tir/transform.h>

#include <string>
#include <utility>

#include "../../printer/tir_text_printer_debug.h"

namespace tvm {
namespace tir {

Stmt DebugInfoInstaller::InstallInfo(const std::string& name, const Stmt& stmt) {
DebugInfoInstaller installer(stmt, name + ".tir");
return installer.VisitStmt(stmt);
}

DebugInfoInstaller::DebugInfoInstaller(const Stmt& stmt, const std::string& filename) {
// Determine the line that each stmt/expr will be printed on
tvm::tir::TIRTextPrinterDebug printer(false);

// Fill in the stmts and exprs' line info
auto result = printer.Print(stmt).str();

// Create map of the stmt/expr -> its line number in the output to later
// create new spans for each stmt/expr
const auto& stmts = printer.GetStmtsByLine();
VLOG(0) << "Debug printer found " << stmts.size() << " stmts after printing";
for (const auto& line : stmts) {
stmt_lines_[std::get<0>(line)] = std::get<1>(line);
}

const auto& exprs = printer.GetExprsByLine();
VLOG(0) << "Debug printer found " << exprs.size() << " exprs after printing";
for (const auto& line : exprs) {
expr_lines_[std::get<0>(line)] = std::get<1>(line);
}

// Output the printed TIR to the specified file
VLOG(0) << "Outputting TIR to " << filename;
filename_ = std::move(filename);
std::ofstream out(filename_);
out << result;
out.close();
}

PrimExpr DebugInfoInstaller::VisitExpr(const PrimExpr& expr) {
PrimExpr result = expr;
result = StmtExprMutator::VisitExpr(result);
return result;
}

Stmt DebugInfoInstaller::VisitStmt(const Stmt& stmt) {
Stmt result = stmt;
result = StmtExprMutator::VisitStmt(result);
return result;
}

Span DebugInfoInstaller::MaybeSpan(const StmtNode* op) {
auto entry = stmt_lines_.find(op);
if (entry == stmt_lines_.end()) {
return Span();
} else {
size_t column = 0;
size_t line = entry->second;
return Span(SourceName::Get(filename_), line, line, column, column);
}
}

Span DebugInfoInstaller::MaybeSpan(const PrimExprNode* op) {
auto entry = expr_lines_.find(op);
if (entry == expr_lines_.end()) {
return Span();
} else {
size_t column = 0;
size_t line = entry->second;
return Span(SourceName::Get(filename_), line, line, column, column);
}
}

#define X(TypeName) \
PrimExpr DebugInfoInstaller::VisitExpr_(const TypeName##Node* op) { \
auto new_expr = StmtExprMutator::VisitExpr_(op); \
auto new_type = Downcast<TypeName>(new_expr); \
auto new_node = new_type.CopyOnWrite(); \
new_node->span = MaybeSpan(op); \
return new_type; \
}
TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS
#undef X

#define X(TypeName) \
Stmt DebugInfoInstaller::VisitStmt_(const TypeName##Node* op) { \
Stmt new_stmt = StmtExprMutator::VisitStmt_(op); \
auto new_type = Downcast<TypeName>(new_stmt); \
auto new_node = new_type.CopyOnWrite(); \
new_node->span = MaybeSpan(op); \
return new_type; \
}
TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS
#undef X

namespace transform {

Pass InstallDebugSpans() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
ICHECK(m->functions.size() == 1)
<< "Debug info can only be added to IRModules with a single function";
// There is known to be only 1 function in the module at this point
auto entry = m->functions.begin();
auto name = std::get<0>(*entry)->name_hint;
auto* n = f.CopyOnWrite();

n->body = DebugInfoInstaller::InstallInfo(std::move(name), std::move(f->body));

return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InstallDebugSpans", {});
}

TVM_REGISTER_GLOBAL("tir.transform.InstallDebugSpans").set_body_typed(InstallDebugSpans);

} // namespace transform
} // namespace tir
} // namespace tvm
132 changes: 132 additions & 0 deletions src/tir/transforms/install_debug_spans.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file install_debug_spans.h
* \brief Interface of the InstallDebugSpans pass
*/

#ifndef TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_H_
#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_H_

#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

#include <string>
#include <unordered_map>

#ifndef TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_OPS_H_
#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_OPS_H_

#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS \
X(Call) \
X(Add) \
X(Sub) \
X(Mul) \
X(Div) \
X(Mod) \
X(FloorDiv) \
X(FloorMod) \
X(Min) \
X(Max) \
X(EQ) \
X(NE) \
X(LT) \
X(LE) \
X(GT) \
X(GE) \
X(And) \
X(Or) \
X(Reduce) \
X(Cast) \
X(Not) \
X(Select) \
X(Ramp) \
X(Broadcast) \
X(Shuffle) \
X(IntImm) \
X(FloatImm) \
X(StringImm)

#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS \
X(AttrStmt) \
X(IfThenElse) \
X(LetStmt) \
X(For) \
X(While) \
X(Allocate) \
X(AllocateConst) \
X(DeclBuffer) \
X(Store) \
X(BufferStore) \
X(BufferRealize) \
X(AssertStmt) \
X(ProducerStore) \
X(ProducerRealize) \
X(Prefetch) \
X(SeqStmt) \
X(Evaluate) \
X(BlockRealize)

#endif // TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_OPS_H_

namespace tvm {
namespace tir {

/*!
* \brief This Pass prints out the provided 'stmt' through the TIR debug printer
while recording the statements and expressions printed on each line. Running
this pass uses the per-line information to change the Spans attached to each
statement and expression to the source location in the printed TIR. This pass
also writes to a file called '<name>.tir' so the line information used is
saved to disk.
*/
class DebugInfoInstaller : public StmtExprMutator {
public:
static Stmt InstallInfo(const std::string& name, const Stmt& stmt);

PrimExpr VisitExpr(const PrimExpr& expr) override;
Stmt VisitStmt(const Stmt& stmt) override;

protected:
DebugInfoInstaller(const Stmt& stmt, const std::string& filename);

#define X(TypeName) PrimExpr VisitExpr_(const TypeName##Node* op) override;
TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS
#undef X

#define X(TypeName) Stmt VisitStmt_(const TypeName##Node* op) override;
TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS
#undef X

private:
std::unordered_map<const StmtNode*, size_t> stmt_lines_;
std::unordered_map<const PrimExprNode*, size_t> expr_lines_;
std::string filename_;

Span MaybeSpan(const StmtNode* op);
Span MaybeSpan(const PrimExprNode* op);
};

} // namespace tir
} // namespace tvm

#endif // TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_H_
124 changes: 124 additions & 0 deletions tests/python/tir/test_debug_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test line-level debug info for TIR"""
import tvm
import tvm.testing
from tvm import tir
from tvm import relay
from tvm.script import tir as T

from typing import List, Dict
import re


def find_di_locations(source: str) -> Dict[int, int]:
"""
Parse out DILocation references in printed LLVM IR
"""
result = {}

for line in source.splitlines():
m = re.match(r"!(\d+) = !DILocation\(line: (\d+).*", line)
if m:
debug_id, line = m.groups()
result[debug_id] = line

return result


def _module():
@tvm.script.ir_module
class MyModule:
@T.prim_func
def main(a: T.handle, b: T.handle):
# We exchange data between function by handles, which are similar to pointer.
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# Create buffer from handles.
A = T.match_buffer(a, (8,), dtype="float32")
B = T.match_buffer(b, (8,), dtype="float32")
for i in range(8):
# A block is an abstraction for computation.
with T.block("B"):
# Define a spatial block iterator and bind it to value i.
vi = T.axis.spatial(8, i)
assert 1 == 0, "Some numbers"
B[vi] = A[vi] + 1.0

return MyModule


def test_tir_debug_info():
"""
Test that Spans are correctly replaced with debug spans that reference
the printed TIR
"""

def find_span(m):
func = next(m.functions.values())
return func.body.block.body.span

module_before = _module()
span_before = find_span(module_before)
assert span_before is None

module_after = tir.transform.InstallDebugSpans()(module_before)
span_after = find_span(module_after)

# Check that the module name has been added and a line number is present
assert span_after.source_name.name == "main.tir"
assert span_after.line == 4


def test_llvm_ir_debug_info():
"""
Check that the right amount of debug locations are present
"""
MyModule = _module()
with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": True}):
runtime_module = tvm.build(MyModule, target="llvm")

source = runtime_module.get_source()

locations = find_di_locations(source)
assert len(locations) == 34


def test_llvm_ir_debug_accuracy():
"""
Check that the debug location on an assert is correct
"""
MyModule = _module()
with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": True}):
runtime_module = tvm.build(MyModule, target="llvm")
source = runtime_module.get_source()
locations = find_di_locations(source)

# Find the 'assert' from MyModule
debug_dir_match = re.search(
r"tail call void %0\(i8\* getelementptr inbounds .* !dbg !(\d+)\n", source
)

# Extract out the debug directive line
directive_idx = debug_dir_match.groups()[0]

# Check that it matches the expected line number (in main.tir)
debug_line_no = int(locations[directive_idx])
assert debug_line_no == 42


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 391b659

Please sign in to comment.