Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Restrict tir.transform.InstallDebugSpans to host functions #14943

Merged
merged 4 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ class TargetNode : public Object {
/*! \return The device type for this target */
TVM_DLL int GetTargetDeviceType() const;

/*!
* \brief Check if the target contains a key
*
* \param query_key The string name of the key to be checked
*
* \return True if the target's `TargetNode::keys` contains the
* specified key, False otherwise.
*/
TVM_DLL bool HasKey(const std::string& query_key) const;

/*!
* \brief Returns a human readable representation of \p Target which includes all fields,
* especially the host. Useful for diagnostic messages and debugging.
Expand Down
5 changes: 5 additions & 0 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,11 @@ int TargetNode::GetTargetDeviceType() const {
return kind->default_device_type;
}

bool TargetNode::HasKey(const std::string& query_key) const {
return std::any_of(keys.begin(), keys.end(),
[&query_key](const auto& key) { return key == query_key; });
}

String TargetNode::ToDebugString() const {
std::ostringstream os;
os << "Target(";
Expand Down
36 changes: 24 additions & 12 deletions src/tir/transforms/install_debug_spans.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <utility>

#include "../../relay/printer/tir_text_printer_debug.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -128,19 +129,30 @@ TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS
namespace transform {

Pass InstallDebugSpans() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
ICHECK(m->functions.size() == 1)
<< "Debug info can only be added to IRModules with a single function";
// There is known to be only 1 function in the module at this point
auto entry = m->functions.begin();
auto name = std::get<0>(*entry)->name_hint;
auto* n = f.CopyOnWrite();

n->body = DebugInfoInstaller::InstallInfo(std::move(name), std::move(f->body));

return f;
auto pass_func = [](IRModule mod, PassContext ctx) {
Map<GlobalVar, PrimFunc> external_host_functions;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
auto prim_func = opt.value();
if (IsHostFunc(prim_func).value_or(false) &&
prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
external_host_functions.Set(gvar, prim_func);
}
}
}

ICHECK_EQ(external_host_functions.size(), 1)
<< "Debug info can only be added to IRModules with a single host function";

for (auto [gvar, prim_func] : external_host_functions) {
auto name = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).value();
prim_func.CopyOnWrite()->body = DebugInfoInstaller::InstallInfo(name, prim_func->body);
mod.CopyOnWrite()->Update(gvar, prim_func);
}

return mod;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InstallDebugSpans", {});
return tvm::transform::CreateModulePass(pass_func, 0, "tir.InstallDebugSpans", {});
}

TVM_REGISTER_GLOBAL("tir.transform.InstallDebugSpans").set_body_typed(InstallDebugSpans);
Expand Down
10 changes: 10 additions & 0 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,16 @@ std::pair<int32_t, int32_t> GetWmmaFragmentDimSize(const std::string& shape_str,
return std::pair<int32_t, int32_t>(0, 0);
}

std::optional<bool> IsHostFunc(const PrimFunc& func) {
if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) {
return true;
} else if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
return target.value()->HasKey("cpu");
} else {
return std::nullopt;
}
}

namespace transform {
Pass ConvertSSA() {
auto pass_func = [](IRModule mod, PassContext ctx) {
Expand Down
12 changes: 12 additions & 0 deletions src/tir/transforms/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <tvm/tir/op.h>

#include <limits>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
Expand Down Expand Up @@ -351,6 +352,17 @@ CollectStorageAlignAnnotation(const Stmt& body);
std::pair<int32_t, int32_t> GetWmmaFragmentDimSize(const std::string& shape_str,
const std::string& scope);

/*! \brief Check if a PrimFunc is a host function
*
* \param func The function to be inspected
*
* \return True if the function is known to run on the host, false if
* the function is known to run on the device. If it cannot be
* determined (e.g. a function without a tvm::attr::kTarget
* attribute), returns std::nullopt.
*/
std::optional<bool> IsHostFunc(const PrimFunc& func);

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_
3 changes: 3 additions & 0 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace tir {

static constexpr const char* kDeviceContextVar = "device_api_context";

namespace {
class ReturnRewriter : public StmtMutator {
public:
explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
Expand Down Expand Up @@ -176,6 +177,8 @@ class SubroutineCallRewriter : public StmtExprMutator {
bool made_change_{false};
};

} // namespace

inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
}
Expand Down
4 changes: 4 additions & 0 deletions src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
namespace tvm {
namespace tir {

namespace {

class SubroutineCallRewriter : public StmtExprMutator {
public:
static Optional<Stmt> Apply(const std::unordered_set<const GlobalVarNode*>& external_methods,
Expand Down Expand Up @@ -84,6 +86,8 @@ class SubroutineCallRewriter : public StmtExprMutator {
bool made_change_{false};
};

} // namespace

PrimFunc MakeUnpackedAPI(PrimFunc func) {
// A function with an explicit calling convention has already been
// lowered, and should not be modified.
Expand Down
50 changes: 49 additions & 1 deletion tests/python/tir/test_debug_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ class MyModule:
@T.prim_func
def main(a: T.handle, b: T.handle):
# We exchange data between function by handles, which are similar to pointer.
T.func_attr({"global_symbol": "main", "tir.noalias": True})
T.func_attr(
{
"global_symbol": "main",
"tir.noalias": True,
"target": T.target("llvm"),
}
)
# Create buffer from handles.
A = T.match_buffer(a, (8,), dtype="float32")
B = T.match_buffer(b, (8,), dtype="float32")
Expand Down Expand Up @@ -83,6 +89,48 @@ def find_span(m):
assert span_after.line == 4


def test_tir_debug_info_with_subroutine():
"""Like test_tir_debug_info, but with a TIR subroutine

The current InstallDebugSpans applies to a single PrimFunc. This
test verifies that the existence of device-side subroutines

"""

def find_span(m):
func = next(m.functions.values())
return func.body.block.body.span

@tvm.script.ir_module
class module_before:
@T.prim_func
def main(a: T.handle, b: T.handle):
T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("llvm")})
A = T.match_buffer(a, (8,), dtype="float32")
B = T.match_buffer(b, (8,), dtype="float32")
for i in range(8):
with T.block("B"):
vi = T.axis.spatial(8, i)
module_before.subroutine(T.address_of(A[vi]), T.address_of(B[vi]))

@T.prim_func
def subroutine(a_ptr: T.handle("float32"), b_ptr: T.handle("float32")):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.decl_buffer(1, "float32", data=a_ptr)
B = T.decl_buffer(1, "float32", data=b_ptr)
B[0] = A[1] + 1.0

span_before = find_span(module_before)
assert span_before is None

module_after = tir.transform.InstallDebugSpans()(module_before)
span_after = find_span(module_after)

# Check that the module name has been added and a line number is present
assert span_after.source_name.name == "main.tir"
assert span_after.line == 4


def test_llvm_ir_debug_info():
"""
Check that the right amount of debug locations are present
Expand Down