Skip to content

Commit

Permalink
[TIR] Restrict tir.transform.InstallDebugSpans to host functions (apa…
Browse files Browse the repository at this point in the history
…che#14943)

* [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI

PRs apache#14913 and
apache#14914 made analogous changes to
`MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls.
Both PRs introduced the same symbol,
`tvm::tir::SubroutineCallRewriter`, a local utility to update internal
calls to a modified function.  While each PR passed CI individually,
and was therefore able to merge, having both changes caused a
duplicate symbol.

This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place
their local utilities into anonymous namespaces, avoiding the
conflict.

* [Target] Added utility method TargetNode::HasKey()

This utility method makes it easier to determine if a target contains
a specific key.

* [TIR] Added utility method tvm::tir::IsHostFunc(const PrimFunc&)

For modules that contain both host and device functions, this utility
function checks whether a given PrimFunc is a host function, based on
the target annotation.

* [TIR] Restrict InstallDebugSpans to host functions

Previously, the `tir.InstallDebugSpans` pass required the module to
contain only a single PrimFunc.  This commit relaxes the requirement,
to require a single host-side PrimFunc, and to ignore any other
device-side functions.
  • Loading branch information
Lunderberg authored and mei-ye committed Jun 1, 2023
1 parent a5026bf commit e56c96b
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 13 deletions.
36 changes: 24 additions & 12 deletions src/tir/transforms/install_debug_spans.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <utility>

#include "../../relay/printer/tir_text_printer_debug.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -128,19 +129,30 @@ TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS
namespace transform {

Pass InstallDebugSpans() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
ICHECK(m->functions.size() == 1)
<< "Debug info can only be added to IRModules with a single function";
// There is known to be only 1 function in the module at this point
auto entry = m->functions.begin();
auto name = std::get<0>(*entry)->name_hint;
auto* n = f.CopyOnWrite();

n->body = DebugInfoInstaller::InstallInfo(std::move(name), std::move(f->body));

return f;
auto pass_func = [](IRModule mod, PassContext ctx) {
Map<GlobalVar, PrimFunc> external_host_functions;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
auto prim_func = opt.value();
if (IsHostFunc(prim_func).value_or(false) &&
prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
external_host_functions.Set(gvar, prim_func);
}
}
}

ICHECK_EQ(external_host_functions.size(), 1)
<< "Debug info can only be added to IRModules with a single host function";

for (auto [gvar, prim_func] : external_host_functions) {
auto name = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
prim_func.CopyOnWrite()->body = DebugInfoInstaller::InstallInfo(name, prim_func->body);
mod.CopyOnWrite()->Update(gvar, prim_func);
}

return mod;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InstallDebugSpans", {});
return tvm::transform::CreateModulePass(pass_func, 0, "tir.InstallDebugSpans", {});
}

TVM_REGISTER_GLOBAL("tir.transform.InstallDebugSpans").set_body_typed(InstallDebugSpans);
Expand Down
50 changes: 49 additions & 1 deletion tests/python/tir/test_debug_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ class MyModule:
@T.prim_func
def main(a: T.handle, b: T.handle):
# We exchange data between function by handles, which are similar to pointer.
T.func_attr({"global_symbol": "main", "tir.noalias": True})
T.func_attr(
{
"global_symbol": "main",
"tir.noalias": True,
"target": T.target("llvm"),
}
)
# Create buffer from handles.
A = T.match_buffer(a, (8,), dtype="float32")
B = T.match_buffer(b, (8,), dtype="float32")
Expand Down Expand Up @@ -83,6 +89,48 @@ def find_span(m):
assert span_after.line == 4


def test_tir_debug_info_with_subroutine():
"""Like test_tir_debug_info, but with a TIR subroutine
The current InstallDebugSpans applies to a single PrimFunc. This
test verifies that the existence of device-side subroutines
"""

def find_span(m):
func = next(m.functions.values())
return func.body.block.body.span

@tvm.script.ir_module
class module_before:
@T.prim_func
def main(a: T.handle, b: T.handle):
T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("llvm")})
A = T.match_buffer(a, (8,), dtype="float32")
B = T.match_buffer(b, (8,), dtype="float32")
for i in range(8):
with T.block("B"):
vi = T.axis.spatial(8, i)
module_before.subroutine(T.address_of(A[vi]), T.address_of(B[vi]))

@T.prim_func
def subroutine(a_ptr: T.handle("float32"), b_ptr: T.handle("float32")):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.decl_buffer(1, "float32", data=a_ptr)
B = T.decl_buffer(1, "float32", data=b_ptr)
B[0] = A[1] + 1.0

span_before = find_span(module_before)
assert span_before is None

module_after = tir.transform.InstallDebugSpans()(module_before)
span_after = find_span(module_after)

# Check that the module name has been added and a line number is present
assert span_after.source_name.name == "main.tir"
assert span_after.line == 4


def test_llvm_ir_debug_info():
"""
Check that the right amount of debug locations are present
Expand Down

0 comments on commit e56c96b

Please sign in to comment.