From 65ac46a66cd3f350db4dae04d0e8c93431266ed4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 25 May 2023 09:36:51 -0500 Subject: [PATCH 1/2] [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI PRs https://github.com/apache/tvm/pull/14913 and https://github.com/apache/tvm/pull/14914 made analogous changes to `MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls. Both PRs introduced the same symbol, `tvm::tir::SubroutineCallRewriter`, a local utility to update internal calls to a modified function. While each PR passed CI individually, and was therefore able to merge, having both changes caused a duplicate symbol. This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place their local utilities into anonymous namespaces, avoiding the conflict. --- src/tir/transforms/make_packed_api.cc | 3 +++ src/tir/transforms/make_unpacked_api.cc | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index dd9d471c5066..825a8da45b27 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -42,6 +42,7 @@ namespace tir { static constexpr const char* kDeviceContextVar = "device_api_context"; +namespace { class ReturnRewriter : public StmtMutator { public: explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {} @@ -176,6 +177,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 82685411f592..bdb3a953e99c 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -40,6 +40,8 @@ namespace tvm { namespace tir { +namespace { + class SubroutineCallRewriter : public StmtExprMutator { public: static Optional Apply(const std::unordered_set& external_methods, @@ -84,6 +86,8 @@ class SubroutineCallRewriter : public StmtExprMutator { bool made_change_{false}; }; +} // namespace + PrimFunc MakeUnpackedAPI(PrimFunc func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. From 4641884f373b01af0509eb0bd418f0af3f608d5e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 25 May 2023 15:12:42 -0500 Subject: [PATCH 2/2] [Codegen][LLVM] Allow void return type from PackedFunc Previously, calling a packed func that returns void would result in an error being raised from `tir::APIType`, as there is no runtime representation of a void type. This commit updates `CodeGenCPU::MakeCallPackedLowered` to only read the return value and type fo a `PackedFunc` when the TIR return type is non-void. --- src/target/llvm/codegen_cpu.cc | 38 +++++++----- src/tir/transforms/ir_utils.h | 1 + .../unittest/test_target_codegen_llvm.py | 60 +++++++++++++++++++ 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index f129511e5a17..81af89e42f94 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -879,8 +879,13 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& const DataType& r_type, const int64_t begin, const int64_t end, bool use_string_lookup) { - PackedCall pc; - std::string func_name = args[0].as()->value; + std::string func_name = [&]() { + auto ptr = args[0].as(); + ICHECK(ptr) << "Expected first argument of tir::Call to be " + << "a string containing the callee's name, " + << "but instead contained " << args[0]; + return ptr->value; + }(); // call the function int64_t nargs = end - begin; ICHECK_GE(nargs, 0); @@ -936,27 +941,32 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& llvm::BasicBlock* end_block = CheckCallSuccess(call); - // Load the return value and cast it to the designated type (r_type). - DataType r_api_type = tir::APIType(r_type); - llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type); - llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo()); + PackedCall pc = {0}; + + if (!r_type.is_void()) { + // Load the return value and cast it to the designated type (r_type). + DataType r_api_type = tir::APIType(r_type); + llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type); + llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo()); #if TVM_LLVM_VERSION >= 110 - llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8)); + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8)); #elif TVM_LLVM_VERSION >= 80 - llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8); + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8); #else - llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8); + llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8); #endif - pc.ret_value = CreateCast(r_api_type, r_type, rvalue); - // Load the return type code. + pc.ret_value = CreateCast(r_api_type, r_type, rvalue); + + // Load the return type code. #if TVM_LLVM_VERSION >= 110 - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8)); + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8)); #elif TVM_LLVM_VERSION >= 80 - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8); + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8); #else - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8); + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8); #endif + } pc.end_block = end_block; return pc; diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 59dc95dcd6a0..5ed4b2b12eb8 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -151,6 +151,7 @@ inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind * \return The corresponding API type. */ inline DataType APIType(DataType t) { + ICHECK(!t.is_void()) << "Cannot pass void type through packed API."; if (t.is_handle()) return t; ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API."; if (t.is_uint() || t.is_int()) return DataType::Int(64); diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index f4326e6fc53d..6a2f5573b274 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -1049,5 +1049,65 @@ def subroutine(A_data: T.handle("float32")): assert arr.numpy()[0] == 42.0 +@tvm.testing.requires_llvm +def test_call_packed_returning_void(): + """Allow codegen of PackedFunc calls returning void + + The LLVM codegen uses the CallNode's dtype to cast the return type + of the PackedFunc into the appropriate LLVM output type. However, + there is no API type for `DataType::Void()`. When the return type + of a PackedFunc is void, the generated code should not attempt to + read the return value. + + While `T.call_packed()` will produce a CallNode with an output + dtype of "int32", the use of other return types is valid in TIR. + This test case uses `T.Call` directly to allow an explicit dtype + for the packed function call. + """ + + @T.prim_func + def func(): + T.Call( + "void", + tvm.ir.Op.get("tir.tvm_call_packed"), + ["dummy_function_name"], + ) + + # Error occurred during build, as part of + # CodeGenCPU::MakeCallPackedLowered. + built = tvm.build(func, target="llvm") + + +@tvm.testing.requires_llvm +def test_call_packed_without_string_arg(): + """The first argument to tvm_call_packed must be a string + + Even if the invalid TIR is constructed, this should throw an + exception to exit cleanly. Previously, use of + `args[0].as()` without a null check resulted in + a segfault during codegen. + """ + + @T.prim_func + def func(A: T.Buffer(1, "float32")): + T.func_attr({"global_symbol": "func"}) + T.Call("int32", tvm.ir.Op.get("tir.tvm_call_packed"), [A.data]) + + with pytest.raises(tvm.TVMError): + built = tvm.build(func, target="llvm") + + +@tvm.testing.requires_llvm +def test_call_extern_returning_void(): + """Like test_call_packed_returning_void, but for call_extern""" + + @T.prim_func + def func(): + T.func_attr({"global_symbol": "func"}) + T.Call("void", tvm.ir.Op.get("tir.call_extern"), ["dummy_function_name"]) + + built = tvm.build(func, target="llvm") + + if __name__ == "__main__": tvm.testing.main()