Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Codegen][LLVM] Allow void return type from PackedFunc #14958

Merged
merged 2 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -879,8 +879,13 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
const DataType& r_type,
const int64_t begin, const int64_t end,
bool use_string_lookup) {
PackedCall pc;
std::string func_name = args[0].as<StringImmNode>()->value;
std::string func_name = [&]() {
auto ptr = args[0].as<StringImmNode>();
ICHECK(ptr) << "Expected first argument of tir::Call to be "
<< "a string containing the callee's name, "
<< "but instead contained " << args[0];
return ptr->value;
}();
// call the function
int64_t nargs = end - begin;
ICHECK_GE(nargs, 0);
Expand Down Expand Up @@ -936,27 +941,32 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&

llvm::BasicBlock* end_block = CheckCallSuccess(call);

// Load the return value and cast it to the designated type (r_type).
DataType r_api_type = tir::APIType(r_type);
llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type);
llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo());
PackedCall pc = {0};

if (!r_type.is_void()) {
// Load the return value and cast it to the designated type (r_type).
DataType r_api_type = tir::APIType(r_type);
llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type);
llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo());
#if TVM_LLVM_VERSION >= 110
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8));
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8));
#elif TVM_LLVM_VERSION >= 80
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8);
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8);
#else
llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8);
llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8);
#endif
pc.ret_value = CreateCast(r_api_type, r_type, rvalue);

// Load the return type code.
pc.ret_value = CreateCast(r_api_type, r_type, rvalue);

// Load the return type code.
#if TVM_LLVM_VERSION >= 110
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8));
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8));
#elif TVM_LLVM_VERSION >= 80
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8);
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8);
#else
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8);
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8);
#endif
}

pc.end_block = end_block;
return pc;
Expand Down
1 change: 1 addition & 0 deletions src/tir/transforms/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind
* \return The corresponding API type.
*/
inline DataType APIType(DataType t) {
ICHECK(!t.is_void()) << "Cannot pass void type through packed API.";
if (t.is_handle()) return t;
ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API.";
if (t.is_uint() || t.is_int()) return DataType::Int(64);
Expand Down
3 changes: 3 additions & 0 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace tir {

static constexpr const char* kDeviceContextVar = "device_api_context";

namespace {
class ReturnRewriter : public StmtMutator {
public:
explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
Expand Down Expand Up @@ -176,6 +177,8 @@ class SubroutineCallRewriter : public StmtExprMutator {
bool made_change_{false};
};

} // namespace

inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
}
Expand Down
4 changes: 4 additions & 0 deletions src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
namespace tvm {
namespace tir {

namespace {

class SubroutineCallRewriter : public StmtExprMutator {
public:
static Optional<Stmt> Apply(const std::unordered_set<const GlobalVarNode*>& external_methods,
Expand Down Expand Up @@ -84,6 +86,8 @@ class SubroutineCallRewriter : public StmtExprMutator {
bool made_change_{false};
};

} // namespace

PrimFunc MakeUnpackedAPI(PrimFunc func) {
// A function with an explicit calling convention has already been
// lowered, and should not be modified.
Expand Down
60 changes: 60 additions & 0 deletions tests/python/unittest/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,5 +1049,65 @@ def subroutine(A_data: T.handle("float32")):
assert arr.numpy()[0] == 42.0


@tvm.testing.requires_llvm
def test_call_packed_returning_void():
"""Allow codegen of PackedFunc calls returning void

The LLVM codegen uses the CallNode's dtype to cast the return type
of the PackedFunc into the appropriate LLVM output type. However,
there is no API type for `DataType::Void()`. When the return type
of a PackedFunc is void, the generated code should not attempt to
read the return value.

While `T.call_packed()` will produce a CallNode with an output
dtype of "int32", the use of other return types is valid in TIR.
This test case uses `T.Call` directly to allow an explicit dtype
for the packed function call.
"""

@T.prim_func
def func():
T.Call(
"void",
tvm.ir.Op.get("tir.tvm_call_packed"),
["dummy_function_name"],
)

# Error occurred during build, as part of
# CodeGenCPU::MakeCallPackedLowered.
built = tvm.build(func, target="llvm")


@tvm.testing.requires_llvm
def test_call_packed_without_string_arg():
"""The first argument to tvm_call_packed must be a string

Even if the invalid TIR is constructed, this should throw an
exception to exit cleanly. Previously, use of
`args[0].as<StringImmNode>()` without a null check resulted in
a segfault during codegen.
"""

@T.prim_func
def func(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "func"})
T.Call("int32", tvm.ir.Op.get("tir.tvm_call_packed"), [A.data])

with pytest.raises(tvm.TVMError):
built = tvm.build(func, target="llvm")


@tvm.testing.requires_llvm
def test_call_extern_returning_void():
"""Like test_call_packed_returning_void, but for call_extern"""

@T.prim_func
def func():
T.func_attr({"global_symbol": "func"})
T.Call("void", tvm.ir.Op.get("tir.call_extern"), ["dummy_function_name"])

built = tvm.build(func, target="llvm")


if __name__ == "__main__":
tvm.testing.main()