Skip to content

Commit

Permalink
[ROCm] Some fixes of ROCm codegen (#16404)
Browse files Browse the repository at this point in the history
- Handle tvm_thread_invariant as no op.
- `llvm.amdgcn.ds.bpermute` requires i32 as its input, but it can handle all 32 bit types
- ocml intrinsics lead to incorrect codegen when used with vectorization, remove it and use llvm intrinsics instead
  • Loading branch information
spectrometerHBH authored Jan 16, 2024
1 parent a7dd32c commit 68be158
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 36 deletions.
2 changes: 2 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1476,6 +1476,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
} else if (op->op.same_as(builtin::assume())) {
llvm::Value* cond = MakeValue(op->args[0]);
return builder_->CreateAssumption(cond);
} else if (op->op.same_as(builtin::tvm_thread_invariant())) {
return MakeValue(op->args[0]);
} else {
LOG(FATAL) << "unknown intrinsic " << op->op;
}
Expand Down
87 changes: 52 additions & 35 deletions src/target/llvm/intrin_rule_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,14 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) {
index = self + delta;
index = Select((self & (width - 1)) + delta >= width, self, index);
}
// reinterprete var as int32
bool is_int32 = var.dtype().is_int() && var.dtype().bits() == 32;
PrimExpr source = is_int32 ? var : reinterpret(DataType::Int(32), var);
PrimExpr res = Call(DataType::Int(32), builtin::call_pure_extern(),
{StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var});
{StringImm("llvm.amdgcn.ds.bpermute"), index << 2, source});
if (!is_int32) {
res = reinterpret(var.dtype(), res);
}
return res;
}

Expand All @@ -114,73 +120,84 @@ TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchShuffle);

TVM_REGISTER_OP("tir.floor")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);

TVM_REGISTER_OP("tir.ceil")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);

TVM_REGISTER_OP("tir.round")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);

TVM_REGISTER_OP("tir.nearbyint")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);

TVM_REGISTER_OP("tir.trunc")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);

TVM_REGISTER_OP("tir.fabs")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);

TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);

TVM_REGISTER_OP("tir.exp2")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);

TVM_REGISTER_OP("tir.exp10")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.exp10")
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
// DispatchLLVMPureIntrin<::llvm::Intrinsic::exp10, 1>);

TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
// DispatchPureExternOCML);

TVM_REGISTER_OP("tir.fma").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);

TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);

TVM_REGISTER_OP("tir.log2")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>);

TVM_REGISTER_OP("tir.log10")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>);

TVM_REGISTER_OP("tir.sqrt")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);

TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);

TVM_REGISTER_OP("tir.tanh")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.tanh")
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);

TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
// DispatchPureExternOCML);

TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);

TVM_REGISTER_OP("tir.cosh")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.cosh")
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);

TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);

TVM_REGISTER_OP("tir.sinh")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.sinh")
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);

TVM_REGISTER_OP("tir.atan")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.atan")
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);

} // namespace llvm
} // namespace codegen
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
(std::any_of(types.begin(), types.end(), [](DataType ty) {
if ((ty.is_vector()) || !ty.is_int()) return true;
if (ty.is_vector()) return ty.bits() * ty.lanes() != 32;
return ty.bits() != 32;
}))) {
return false;
Expand Down
53 changes: 53 additions & 0 deletions tests/python/codegen/test_target_codegen_rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tvm import te
import numpy as np
import unittest
from tvm.script import tir as T

tx = te.thread_axis("threadIdx.x")
ty = te.thread_axis("threadIdx.y")
Expand Down Expand Up @@ -130,9 +131,61 @@ def check_rocm(dtype, n, lanes):
check_rocm("float16", 64, 2)


@tvm.testing.requires_rocm
def test_rocm_warp_shuffle():
@T.prim_func
def func(
A_handle: T.handle,
):
A = T.match_buffer(A_handle, (32,), dtype="float32")

for bx in T.thread_binding(1, thread="blockIdx.x"):
for tx in T.thread_binding(32, thread="threadIdx.x"):
with T.block("test"):
A_local = T.alloc_buffer((1,), "float32", scope="local")
mask = T.alloc_buffer((1,), "uint32", scope="local")
t0 = T.alloc_buffer((1,), "float32", scope="local")

A_local[0] = A[tx]
A_local[0] = T.tvm_warp_shuffle(mask[0], A_local[0], 0, 32, 32)
A[tx] = A_local[0]

mod = tvm.build(func, target="rocm")
dev = tvm.rocm(0)
a = tvm.nd.array(np.random.uniform(size=(32,)).astype("float32"), dev)
mod(a)
tvm.testing.assert_allclose(a.numpy(), np.ones((32,)) * a.numpy()[0])


@tvm.testing.requires_rocm
def test_rocm_vectorized_exp():
@T.prim_func
def func(
A_handle: T.handle,
B_handle: T.handle,
):
A = T.match_buffer(A_handle, (4,), dtype="float32")
B = T.match_buffer(B_handle, (4,), dtype="float32")

for bx in T.thread_binding(1, thread="blockIdx.x"):
for tx in T.thread_binding(1, thread="threadIdx.x"):
with T.block("test"):
for i in T.vectorized(0, 4):
B[i] = T.exp2(A[i])

mod = tvm.build(func, target="rocm")
dev = tvm.rocm(0)
a = tvm.nd.array(np.ones((4,)).astype("float32"), dev)
b = tvm.nd.array(np.zeros((4,)).astype("float32"), dev)
mod(a, b)
tvm.testing.assert_allclose(b.numpy(), np.exp2(a.numpy()))


if __name__ == "__main__":
test_rocm_cross_thread_reduction()
test_rocm_inf_nan()
test_rocm_reduction_binding()
test_rocm_copy()
test_rocm_vectorize_add()
test_rocm_warp_shuffle()
test_rocm_vectorized_exp()

0 comments on commit 68be158

Please sign in to comment.