From 8cd3f09defad7ba25d971fe64a13929786437e01 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Mon, 15 Jan 2024 17:16:46 -0500 Subject: [PATCH] [ROCm] Some fixes of ROCm codegen - Handle tvm_thread_invariant as no op. - `llvm.amdgcn.ds.bpermute` requires i32 as its input, but it can handle all 32 bit types - ocml intrinsics lead to incorrect codegen when used with vectorization, remove it and use llvm intrinsics instead --- src/target/llvm/codegen_llvm.cc | 2 + src/target/llvm/intrin_rule_rocm.cc | 87 +++++++++++-------- src/tir/transforms/lower_thread_allreduce.cc | 2 +- .../codegen/test_target_codegen_rocm.py | 53 +++++++++++ 4 files changed, 108 insertions(+), 36 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 3d4d3def2411..9701a299f1d1 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1476,6 +1476,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::assume())) { llvm::Value* cond = MakeValue(op->args[0]); return builder_->CreateAssumption(cond); + } else if (op->op.same_as(builtin::tvm_thread_invariant())) { + return MakeValue(op->args[0]); } else { LOG(FATAL) << "unknown intrinsic " << op->op; } diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index d25126f5d828..0fbfade3354a 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -89,8 +89,14 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { index = self + delta; index = Select((self & (width - 1)) + delta >= width, self, index); } + // reinterprete var as int32 + bool is_int32 = var.dtype().is_int() && var.dtype().bits() == 32; + PrimExpr source = is_int32 ? var : reinterpret(DataType::Int(32), var); PrimExpr res = Call(DataType::Int(32), builtin::call_pure_extern(), - {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var}); + {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, source}); + if (!is_int32) { + res = reinterpret(var.dtype(), res); + } return res; } @@ -114,73 +120,84 @@ TVM_REGISTER_OP("tir.tvm_warp_shuffle_down") .set_attr("rocm.FLowerIntrinsic", DispatchShuffle); TVM_REGISTER_OP("tir.floor") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); TVM_REGISTER_OP("tir.ceil") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); TVM_REGISTER_OP("tir.round") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); TVM_REGISTER_OP("tir.nearbyint") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); TVM_REGISTER_OP("tir.trunc") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); TVM_REGISTER_OP("tir.fabs") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); -TVM_REGISTER_OP("tir.exp").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.exp").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); TVM_REGISTER_OP("tir.exp2") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); -TVM_REGISTER_OP("tir.exp10") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.exp10") +// .set_attr("rocm.FLowerIntrinsic", +// DispatchLLVMPureIntrin<::llvm::Intrinsic::exp10, 1>); -TVM_REGISTER_OP("tir.erf").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.erf").set_attr("rocm.FLowerIntrinsic", +// DispatchPureExternOCML); TVM_REGISTER_OP("tir.fma").set_attr( "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); -TVM_REGISTER_OP("tir.log").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.log").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); TVM_REGISTER_OP("tir.log2") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); TVM_REGISTER_OP("tir.log10") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); TVM_REGISTER_OP("tir.sqrt") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); -TVM_REGISTER_OP("tir.pow").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.pow").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); -TVM_REGISTER_OP("tir.tanh") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.tanh") +// .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_OP("tir.tan").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.tan").set_attr("rocm.FLowerIntrinsic", +// DispatchPureExternOCML); -TVM_REGISTER_OP("tir.cos").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.cos").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); -TVM_REGISTER_OP("tir.cosh") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.cosh") +// .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_OP("tir.sin").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.sin").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); -TVM_REGISTER_OP("tir.sinh") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.sinh") +// .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_OP("tir.atan") - .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); +// TVM_REGISTER_OP("tir.atan") +// .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); } // namespace llvm } // namespace codegen diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 1b2e8e9db04a..7094d6adaf3c 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -730,7 +730,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // rocm only supports 32 bit operands for shuffling at the moment if ((target_->kind->name == "rocm") && (std::any_of(types.begin(), types.end(), [](DataType ty) { - if ((ty.is_vector()) || !ty.is_int()) return true; + if (ty.is_vector()) return ty.bits() * ty.lanes() != 32; return ty.bits() != 32; }))) { return false; diff --git a/tests/python/codegen/test_target_codegen_rocm.py b/tests/python/codegen/test_target_codegen_rocm.py index 3e286f6ebff2..a0990c330f03 100644 --- a/tests/python/codegen/test_target_codegen_rocm.py +++ b/tests/python/codegen/test_target_codegen_rocm.py @@ -19,6 +19,7 @@ from tvm import te import numpy as np import unittest +from tvm.script import tir as T tx = te.thread_axis("threadIdx.x") ty = te.thread_axis("threadIdx.y") @@ -130,9 +131,61 @@ def check_rocm(dtype, n, lanes): check_rocm("float16", 64, 2) +@tvm.testing.requires_rocm +def test_rocm_warp_shuffle(): + @T.prim_func + def func( + A_handle: T.handle, + ): + A = T.match_buffer(A_handle, (32,), dtype="float32") + + for bx in T.thread_binding(1, thread="blockIdx.x"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + with T.block("test"): + A_local = T.alloc_buffer((1,), "float32", scope="local") + mask = T.alloc_buffer((1,), "uint32", scope="local") + t0 = T.alloc_buffer((1,), "float32", scope="local") + + A_local[0] = A[tx] + A_local[0] = T.tvm_warp_shuffle(mask[0], A_local[0], 0, 32, 32) + A[tx] = A_local[0] + + mod = tvm.build(func, target="rocm") + dev = tvm.rocm(0) + a = tvm.nd.array(np.random.uniform(size=(32,)).astype("float32"), dev) + mod(a) + tvm.testing.assert_allclose(a.numpy(), np.ones((32,)) * a.numpy()[0]) + + +@tvm.testing.requires_rocm +def test_rocm_vectorized_exp(): + @T.prim_func + def func( + A_handle: T.handle, + B_handle: T.handle, + ): + A = T.match_buffer(A_handle, (4,), dtype="float32") + B = T.match_buffer(B_handle, (4,), dtype="float32") + + for bx in T.thread_binding(1, thread="blockIdx.x"): + for tx in T.thread_binding(1, thread="threadIdx.x"): + with T.block("test"): + for i in T.vectorized(0, 4): + B[i] = T.exp2(A[i]) + + mod = tvm.build(func, target="rocm") + dev = tvm.rocm(0) + a = tvm.nd.array(np.ones((4,)).astype("float32"), dev) + b = tvm.nd.array(np.zeros((4,)).astype("float32"), dev) + mod(a, b) + tvm.testing.assert_allclose(b.numpy(), np.exp2(a.numpy())) + + if __name__ == "__main__": test_rocm_cross_thread_reduction() test_rocm_inf_nan() test_rocm_reduction_binding() test_rocm_copy() test_rocm_vectorize_add() + test_rocm_warp_shuffle() + test_rocm_vectorized_exp()