diff --git a/include/tvm/runtime/logging.h b/include/tvm/runtime/logging.h index 45c390df1ddc9..708f4bcaa9c49 100644 --- a/include/tvm/runtime/logging.h +++ b/include/tvm/runtime/logging.h @@ -353,7 +353,10 @@ class LogFatal { #pragma disagnostic push #pragma warning(disable : 4722) #endif - [[noreturn]] ~LogFatal() TVM_THROW_EXCEPTION { GetEntry().Finalize(); } + [[noreturn]] ~LogFatal() TVM_THROW_EXCEPTION { + GetEntry().Finalize(); + throw; + } #ifdef _MSC_VER #pragma disagnostic pop #endif @@ -366,7 +369,7 @@ class LogFatal { this->file_ = file; this->lineno_ = lineno; } - [[noreturn]] TVM_NO_INLINE dmlc::Error Finalize() { + [[noreturn]] TVM_NO_INLINE dmlc::Error Finalize() TVM_THROW_EXCEPTION { InternalError error(file_, lineno_, stream_.str()); #if DMLC_LOG_BEFORE_THROW std::cerr << error.what() << std::endl; @@ -560,15 +563,26 @@ std::unique_ptr LogCheckFormat(const X& x, const Y& y) { return LogCheck##name(x, y); \ } +#if defined(__GNUC__) || defined(__clang__) // GCC and Clang #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wsign-compare" +#elif defined(_MSC_VER) // MSVC +#pragma warning(push) +#pragma warning(disable : 4389) // '==' : signed/unsigned mismatch +#endif + TVM_CHECK_FUNC(_LT, <) TVM_CHECK_FUNC(_GT, >) TVM_CHECK_FUNC(_LE, <=) TVM_CHECK_FUNC(_GE, >=) TVM_CHECK_FUNC(_EQ, ==) TVM_CHECK_FUNC(_NE, !=) + +#if defined(__GNUC__) || defined(__clang__) // GCC and Clang #pragma GCC diagnostic pop +#elif defined(_MSC_VER) // MSVC +#pragma warning(pop) +#endif } // namespace detail diff --git a/src/support/pipe.h b/src/support/pipe.h index 557fe89e46705..4babc5b7c422a 100644 --- a/src/support/pipe.h +++ b/src/support/pipe.h @@ -77,10 +77,11 @@ class Pipe : public dmlc::Stream { size_t Read(void* ptr, size_t size) final { if (size == 0) return 0; #ifdef _WIN32 - auto fread = [&]() { + auto fread = [&]() -> ssize_t { DWORD nread; - if (!ReadFile(handle_, static_cast(ptr), size, &nread, nullptr)) return -1; - return nread; + if (!ReadFile(handle_, static_cast(ptr), size, &nread, nullptr)) + return static_cast(-1); + return static_cast(nread); }; DWORD nread = static_cast(RetryCallOnEINTR(fread, GetLastErrorCode)); ICHECK_EQ(static_cast(nread), size) << "Read Error: " << GetLastError(); @@ -99,10 +100,11 @@ class Pipe : public dmlc::Stream { void Write(const void* ptr, size_t size) final { if (size == 0) return; #ifdef _WIN32 - auto fwrite = [&]() { + auto fwrite = [&]() -> ssize_t { DWORD nwrite; - if (!WriteFile(handle_, static_cast(ptr), size, &nwrite, nullptr)) return -1; - return nwrite; + if (!WriteFile(handle_, static_cast(ptr), size, &nwrite, nullptr)) + return static_cast(-1); + return static_cast(nwrite); }; DWORD nwrite = static_cast(RetryCallOnEINTR(fwrite, GetLastErrorCode)); ICHECK_EQ(static_cast(nwrite), size) << "Write Error: " << GetLastError(); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 3d4d3def2411a..9701a299f1d12 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1476,6 +1476,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::assume())) { llvm::Value* cond = MakeValue(op->args[0]); return builder_->CreateAssumption(cond); + } else if (op->op.same_as(builtin::tvm_thread_invariant())) { + return MakeValue(op->args[0]); } else { LOG(FATAL) << "unknown intrinsic " << op->op; } diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index d25126f5d8289..0fbfade3354a5 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -89,8 +89,14 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { index = self + delta; index = Select((self & (width - 1)) + delta >= width, self, index); } + // reinterprete var as int32 + bool is_int32 = var.dtype().is_int() && var.dtype().bits() == 32; + PrimExpr source = is_int32 ? var : reinterpret(DataType::Int(32), var); PrimExpr res = Call(DataType::Int(32), builtin::call_pure_extern(), - {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var}); + {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, source}); + if (!is_int32) { + res = reinterpret(var.dtype(), res); + } return res; } @@ -114,73 +120,84 @@ TVM_REGISTER_OP("tir.tvm_warp_shuffle_down") .set_attr("rocm.FLowerIntrinsic", DispatchShuffle); TVM_REGISTER_OP("tir.floor") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); TVM_REGISTER_OP("tir.ceil") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); TVM_REGISTER_OP("tir.round") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); TVM_REGISTER_OP("tir.nearbyint") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); TVM_REGISTER_OP("tir.trunc") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); TVM_REGISTER_OP("tir.fabs") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); -TVM_REGISTER_OP("tir.exp").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.exp").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); TVM_REGISTER_OP("tir.exp2") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); -TVM_REGISTER_OP("tir.exp10") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.exp10") +// .set_attr("rocm.FLowerIntrinsic", +// DispatchLLVMPureIntrin<::llvm::Intrinsic::exp10, 1>); -TVM_REGISTER_OP("tir.erf").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.erf").set_attr("rocm.FLowerIntrinsic", +// DispatchPureExternOCML); TVM_REGISTER_OP("tir.fma").set_attr( "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); -TVM_REGISTER_OP("tir.log").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.log").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); TVM_REGISTER_OP("tir.log2") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); TVM_REGISTER_OP("tir.log10") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); TVM_REGISTER_OP("tir.sqrt") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); -TVM_REGISTER_OP("tir.pow").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.pow").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); -TVM_REGISTER_OP("tir.tanh") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.tanh") +// .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_OP("tir.tan").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.tan").set_attr("rocm.FLowerIntrinsic", +// DispatchPureExternOCML); -TVM_REGISTER_OP("tir.cos").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.cos").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); -TVM_REGISTER_OP("tir.cosh") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.cosh") +// .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_OP("tir.sin").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.sin").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); -TVM_REGISTER_OP("tir.sinh") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.sinh") +// .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_OP("tir.atan") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.atan") +// .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); } // namespace llvm } // namespace codegen diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 8f1aa8063bd04..aca504b94b98e 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -509,6 +509,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::SType ptr_type = builder_->GetPointerType(ele_stype, buffer_val.stype.storage_class); ICHECK(var_map_.count(buffer_node)); return builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index)); + } else if (op->op.same_as(builtin::tvm_thread_invariant())) { + return MakeValue(op->args[0]); } else { LOG(FATAL) << "Unresolved call " << op->op; } diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index ffef425c0e41a..e5f869de1718c 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -82,6 +82,9 @@ TVM_REGISTER_OP("tir.fabs") TVM_REGISTER_OP("tir.exp").set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tir.exp2") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); + TVM_REGISTER_OP("tir.sin").set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 1b2e8e9db04af..7094d6adaf3cb 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -730,7 +730,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // rocm only supports 32 bit operands for shuffling at the moment if ((target_->kind->name == "rocm") && (std::any_of(types.begin(), types.end(), [](DataType ty) { - if ((ty.is_vector()) || !ty.is_int()) return true; + if (ty.is_vector()) return ty.bits() * ty.lanes() != 32; return ty.bits() != 32; }))) { return false; diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc index c79b9c1f93996..1598d409c5d80 100644 --- a/src/tir/transforms/merge_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_shared_memory_allocations.cc @@ -662,6 +662,11 @@ namespace transform { Pass MergeSharedMemoryAllocations() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { bool merge_static_smem = ctx->GetConfig("tir.merge_static_smem", Bool(false)).value(); + // disable this pass for Vulkan + auto target = Target::Current(true); + if (target.defined() && target->kind->name == "vulkan") { + return f; + } auto* n = f.CopyOnWrite(); n->body = MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem); return f; diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 70f325e4a21ee..7f48a53708fe9 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -1705,8 +1705,13 @@ namespace transform { Pass StorageRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { bool merge_static_smem = ctx->GetConfig("tir.merge_static_smem", Bool(false)).value(); + // disable merge_static_smem for Vulkan + auto target = Target::Current(true); + if (target.defined() && target->kind->name == "vulkan") { + merge_static_smem = false; + } auto* n = f.CopyOnWrite(); - n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, !merge_static_smem); + n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, merge_static_smem); // Parameters may not be rewritten, but internal allocations may. // Vectorization of AllocateConst is currently disabled, as it has // indexing issues for types that include padding (e.g. int8x3 diff --git a/tests/python/codegen/test_target_codegen_rocm.py b/tests/python/codegen/test_target_codegen_rocm.py index 3e286f6ebff20..a0990c330f03c 100644 --- a/tests/python/codegen/test_target_codegen_rocm.py +++ b/tests/python/codegen/test_target_codegen_rocm.py @@ -19,6 +19,7 @@ from tvm import te import numpy as np import unittest +from tvm.script import tir as T tx = te.thread_axis("threadIdx.x") ty = te.thread_axis("threadIdx.y") @@ -130,9 +131,61 @@ def check_rocm(dtype, n, lanes): check_rocm("float16", 64, 2) +@tvm.testing.requires_rocm +def test_rocm_warp_shuffle(): + @T.prim_func + def func( + A_handle: T.handle, + ): + A = T.match_buffer(A_handle, (32,), dtype="float32") + + for bx in T.thread_binding(1, thread="blockIdx.x"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + with T.block("test"): + A_local = T.alloc_buffer((1,), "float32", scope="local") + mask = T.alloc_buffer((1,), "uint32", scope="local") + t0 = T.alloc_buffer((1,), "float32", scope="local") + + A_local[0] = A[tx] + A_local[0] = T.tvm_warp_shuffle(mask[0], A_local[0], 0, 32, 32) + A[tx] = A_local[0] + + mod = tvm.build(func, target="rocm") + dev = tvm.rocm(0) + a = tvm.nd.array(np.random.uniform(size=(32,)).astype("float32"), dev) + mod(a) + tvm.testing.assert_allclose(a.numpy(), np.ones((32,)) * a.numpy()[0]) + + +@tvm.testing.requires_rocm +def test_rocm_vectorized_exp(): + @T.prim_func + def func( + A_handle: T.handle, + B_handle: T.handle, + ): + A = T.match_buffer(A_handle, (4,), dtype="float32") + B = T.match_buffer(B_handle, (4,), dtype="float32") + + for bx in T.thread_binding(1, thread="blockIdx.x"): + for tx in T.thread_binding(1, thread="threadIdx.x"): + with T.block("test"): + for i in T.vectorized(0, 4): + B[i] = T.exp2(A[i]) + + mod = tvm.build(func, target="rocm") + dev = tvm.rocm(0) + a = tvm.nd.array(np.ones((4,)).astype("float32"), dev) + b = tvm.nd.array(np.zeros((4,)).astype("float32"), dev) + mod(a, b) + tvm.testing.assert_allclose(b.numpy(), np.exp2(a.numpy())) + + if __name__ == "__main__": test_rocm_cross_thread_reduction() test_rocm_inf_nan() test_rocm_reduction_binding() test_rocm_copy() test_rocm_vectorize_add() + test_rocm_warp_shuffle() + test_rocm_vectorized_exp()