diff --git a/src/tir/transforms/install_debug_spans.cc b/src/tir/transforms/install_debug_spans.cc index c97070e1bf89..ea61378ccccc 100644 --- a/src/tir/transforms/install_debug_spans.cc +++ b/src/tir/transforms/install_debug_spans.cc @@ -31,6 +31,7 @@ #include #include "../../relay/printer/tir_text_printer_debug.h" +#include "ir_utils.h" namespace tvm { namespace tir { @@ -128,19 +129,30 @@ TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS namespace transform { Pass InstallDebugSpans() { - auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - ICHECK(m->functions.size() == 1) - << "Debug info can only be added to IRModules with a single function"; - // There is known to be only 1 function in the module at this point - auto entry = m->functions.begin(); - auto name = std::get<0>(*entry)->name_hint; - auto* n = f.CopyOnWrite(); - - n->body = DebugInfoInstaller::InstallInfo(std::move(name), std::move(f->body)); - - return f; + auto pass_func = [](IRModule mod, PassContext ctx) { + Map external_host_functions; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto opt = base_func.as()) { + auto prim_func = opt.value(); + if (IsHostFunc(prim_func).value_or(false) && + prim_func->GetAttr(tvm::attr::kGlobalSymbol)) { + external_host_functions.Set(gvar, prim_func); + } + } + } + + ICHECK_EQ(external_host_functions.size(), 1) + << "Debug info can only be added to IRModules with a single host function"; + + for (auto [gvar, prim_func] : external_host_functions) { + auto name = prim_func->GetAttr(tvm::attr::kGlobalSymbol).value(); + prim_func.CopyOnWrite()->body = DebugInfoInstaller::InstallInfo(name, prim_func->body); + mod.CopyOnWrite()->Update(gvar, prim_func); + } + + return mod; }; - return CreatePrimFuncPass(pass_func, 0, "tir.InstallDebugSpans", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tir.InstallDebugSpans", {}); } TVM_REGISTER_GLOBAL("tir.transform.InstallDebugSpans").set_body_typed(InstallDebugSpans); diff --git a/tests/python/tir/test_debug_info.py b/tests/python/tir/test_debug_info.py index 8ecabbd51a97..d333b43b28f5 100644 --- a/tests/python/tir/test_debug_info.py +++ b/tests/python/tir/test_debug_info.py @@ -46,7 +46,13 @@ class MyModule: @T.prim_func def main(a: T.handle, b: T.handle): # We exchange data between function by handles, which are similar to pointer. - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr( + { + "global_symbol": "main", + "tir.noalias": True, + "target": T.target("llvm"), + } + ) # Create buffer from handles. A = T.match_buffer(a, (8,), dtype="float32") B = T.match_buffer(b, (8,), dtype="float32") @@ -83,6 +89,48 @@ def find_span(m): assert span_after.line == 4 +def test_tir_debug_info_with_subroutine(): + """Like test_tir_debug_info, but with a TIR subroutine + + The current InstallDebugSpans applies to a single PrimFunc. This + test verifies that the existence of device-side subroutines + + """ + + def find_span(m): + func = next(m.functions.values()) + return func.body.block.body.span + + @tvm.script.ir_module + class module_before: + @T.prim_func + def main(a: T.handle, b: T.handle): + T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("llvm")}) + A = T.match_buffer(a, (8,), dtype="float32") + B = T.match_buffer(b, (8,), dtype="float32") + for i in range(8): + with T.block("B"): + vi = T.axis.spatial(8, i) + module_before.subroutine(T.address_of(A[vi]), T.address_of(B[vi])) + + @T.prim_func + def subroutine(a_ptr: T.handle("float32"), b_ptr: T.handle("float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.decl_buffer(1, "float32", data=a_ptr) + B = T.decl_buffer(1, "float32", data=b_ptr) + B[0] = A[1] + 1.0 + + span_before = find_span(module_before) + assert span_before is None + + module_after = tir.transform.InstallDebugSpans()(module_before) + span_after = find_span(module_after) + + # Check that the module name has been added and a line number is present + assert span_after.source_name.name == "main.tir" + assert span_after.line == 4 + + def test_llvm_ir_debug_info(): """ Check that the right amount of debug locations are present