[Codegen][Metal] Support metal warp-level primitive (#15401)
This PR introduces the warp-level shuffle primitives of the Metal
Shading Language and uses them in the lowering of allreduce.

The introduced primitives are:
* `simd_shuffle`,
* `simd_shuffle_up`,
* `simd_shuffle_down`.

See section 6.9.2 of https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
for details.
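For a concrete picture of their semantics, here is a minimal MSL sketch (the kernel name and buffer layout are illustrative, not part of this PR); `simd_shuffle` reads from an absolute lane, while the `_up`/`_down` variants read from the caller's lane minus/plus a delta:

#include <metal_stdlib>
using namespace metal;

// Hypothetical probe kernel; assumes a single 32-lane SIMD group.
kernel void shuffle_probe(device float* out [[buffer(0)]],
                          ushort lane [[thread_index_in_simdgroup]]) {
  float v = float(lane);
  out[lane * 3 + 0] = simd_shuffle(v, 0);       // broadcast lane 0's value
  out[lane * 3 + 1] = simd_shuffle_up(v, 1);    // value from lane - 1 (unspecified at lane 0)
  out[lane * 3 + 2] = simd_shuffle_down(v, 1);  // value from lane + 1 (unspecified at lane 31)
}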

Correctness is validated by `test_allreduce_cuda` with the backend
changed to Metal. Since we do not have Metal CI tests, this was
checked only locally.

The Metal shuffle primitives do not support (or need) masking: unlike
CUDA's `__shfl_*_sync` intrinsics, they take no lane-mask argument.
The LowerThreadAllreduce pass is therefore updated to support backends
without masks, and a unit test for Metal is added to ensure that no
mask is used.
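Concretely, where the CUDA path loads a live mask from a local buffer into each shuffle call, the maskless path emits a constant placeholder in the mask slot, e.g. `T.tvm_warp_shuffle_down(0, value, delta, 32, 32)` (operand names illustrative); the new unit test pins down exactly this form for the Metal target.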
MasterJH5574 authored Jul 26, 2023
1 parent 304aa1e commit 22ec541
Showing 3 changed files with 180 additions and 11 deletions.
53 changes: 53 additions & 0 deletions src/target/source/intrin_rule_metal.cc
@@ -30,6 +30,28 @@ namespace codegen {
 namespace intrin {
 using tir::FLowerIntrinsic;
 
+// Map the generic TVM warp shuffle builtins to their Metal counterparts.
+struct MetalWarpIntrinsic {
+  const Op operator()(DataType t, const Op& orig_op) const {
+    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
+      return Op::Get("tir.metal.simd_shuffle");
+    } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
+      return Op::Get("tir.metal.simd_shuffle_up");
+    } else {
+      ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
+      return Op::Get("tir.metal.simd_shuffle_down");
+    }
+  }
+};
+
+template <typename T>
+static PrimExpr DispatchMetalShuffle(const PrimExpr& e) {
+  const CallNode* call = e.as<CallNode>();
+  ICHECK(call != nullptr);
+  ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
+  // Metal shuffles take neither a mask nor a width; keep only value and lane/delta.
+  Array<PrimExpr> metal_args{{call->args[1], call->args[2]}};
+  return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), metal_args);
+}
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

@@ -95,6 +117,37 @@ TVM_REGISTER_OP("tir.cosh")
 
 TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchFastErf);
 
+TVM_REGISTER_OP("tir.tvm_warp_shuffle")
+    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);
+
+TVM_REGISTER_OP("tir.tvm_warp_shuffle_up")
+    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);
+
+TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
+    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);
+
+// Register low-level builtin ops.
+TVM_REGISTER_OP("tir.metal.simd_shuffle")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The value to shuffle.")
+    .add_argument("lane", "Expr", "The source thread id.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TVM_REGISTER_OP("tir.metal.simd_shuffle_up")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The value to shuffle.")
+    .add_argument("delta", "Expr", "The source lane id offset to be added.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle_up")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TVM_REGISTER_OP("tir.metal.simd_shuffle_down")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The value to shuffle.")
+    .add_argument("delta", "Expr", "The source lane id offset to be subtracted.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle_down")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 } // namespace intrin
 } // namespace codegen
 } // namespace tvm
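Since `TGlobalSymbol` supplies the extern function name, the source codegen prints these ops as plain `simd_shuffle*` calls. A hedged sketch of the shape the generated Metal code takes for a warp allreduce (identifiers and kernel signature are illustrative, not actual codegen output):

#include <metal_stdlib>
using namespace metal;

// Sketch: tree-sum across one full 32-lane SIMD group, then broadcast.
kernel void allreduce_sum(device const float* in [[buffer(0)]],
                          device float* out [[buffer(1)]],
                          ushort lane [[thread_index_in_simdgroup]]) {
  float red_buf = in[lane];
  // Halve the shuffle distance each step; lane 0 accumulates the total.
  for (ushort delta = 16; delta > 0; delta /= 2) {
    red_buf += simd_shuffle_down(red_buf, delta);
  }
  out[lane] = simd_shuffle(red_buf, 0);  // all lanes receive the reduced value
}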
35 changes: 24 additions & 11 deletions src/tir/transforms/lower_thread_allreduce.cc
@@ -476,12 +476,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       // The mask for this reducer, as this reducer may sit inside
       // a divergent control flow. Here it uses a variable to cache the current
       // active channels.
-      Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
-      {
-        seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices));
+      Optional<Buffer> mask_buffer;
+      if (need_warp_shuffle_mask_) {
+        mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
+        seq->emplace_back(BufferStore(mask_buffer.value(), mask, zero_indices));
         // Push the buffer description. Later this will have an
         // allocation built for it.
-        local_bufs.push_back(mask_buffer);
+        local_bufs.push_back(mask_buffer.value());
       }
 
       // Emit reductions within a warp.
@@ -698,9 +699,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   }
 
   // Emit warp shuffle calls.
-  PrimExpr WarpShuffle(const Op& op, Buffer mask_buffer, PrimExpr val, PrimExpr delta_or_lane) {
+  PrimExpr WarpShuffle(const Op& op, Optional<Buffer> mask_buffer, PrimExpr val,
+                       PrimExpr delta_or_lane) {
     Array<PrimExpr> indices = {0};
-    PrimExpr mask = BufferLoad(mask_buffer, indices);
+    PrimExpr mask;
+    if (mask_buffer.defined()) {
+      mask = BufferLoad(mask_buffer.value(), indices);
+    } else {
+      // Backends without shuffle masks (e.g. Metal) get a placeholder mask of 0.
+      mask = IntImm(DataType::Int(32), 0);
+    }
     PrimExpr width = IntImm(DataType::Int(32), warp_size_);
     Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
     return Call(val.dtype(), op, args);
@@ -709,11 +716,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   // Check if we can use warp level reduction.
   //
   // Note: The ROCm backend will only have warp reductions for now.
-  // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
+  // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda and metal).
   bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
-                       int contiguous_reduce_extent) const {
-    // Only cuda target supports warp reductions.
-    if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false;
+                       int contiguous_reduce_extent) {
+    if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
+        (target_->kind->name != "metal")) {
+      return false;
+    }
+
+    need_warp_shuffle_mask_ = target_->kind->name != "metal";
 
     // rocm only supports 32 bit operands for shuffling at the moment
     if ((target_->kind->name == "rocm") &&
@@ -745,7 +756,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // whether reduce_extent and group_extent are valid for warp reduction.
     if (target_->kind->name == "rocm") {
       return reduce_extent == warp_size_;
-    } else {  // target_->kind->name == "cuda"
+    } else {
       if (reduce_extent == 1) {
         return false;  // no need to warp reduce
       } else {
@@ -769,6 +780,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   int warp_size_{1};
   // The maximum number of threads of the device. "-1" denotes unknown.
   int max_num_threads_{-1};
+  // A boolean indicating if the target supports warp-level masking.
+  bool need_warp_shuffle_mask_;
 
   // surrounding scope of thread extent.
   std::vector<const AttrStmtNode*> thread_extents_;
103 changes: 103 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
@@ -702,5 +702,108 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
         B_1[threadIdx_y] = red_result_1[threadIdx_y]
 
 
+class TestMetalNoMask(BaseCompare):
+    @T.prim_func
+    def before(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
+        T.func_attr(
+            {
+                "target": T.target(
+                    {
+                        "kind": "metal",
+                        "max_threads_per_block": 1024,
+                        "thread_warp_size": 32,
+                        "host": "llvm",
+                    }
+                ),
+            }
+        )
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        cross_thread_B = T.allocate([1], "float32", "local")
+        threadIdx_z = T.launch_thread("threadIdx.z", 1)
+        threadIdx_y = T.launch_thread("threadIdx.y", 2)
+        threadIdx_x = T.launch_thread("threadIdx.x", 128)
+        cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            A_1 = T.Buffer((256,), data=A.data)
+            T.tvm_thread_allreduce(
+                T.uint32(1),
+                A_1[threadIdx_y * 128 + threadIdx_x],
+                T.bool(True),
+                cross_thread_B_1[0],
+                threadIdx_x,
+            )
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((2,), data=B.data)
+            B_1[threadIdx_y] = cross_thread_B_1[0]
+
+    @T.prim_func
+    def expected(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
+        T.func_attr(
+            {
+                "target": T.target(
+                    {
+                        "kind": "metal",
+                        "max_threads_per_block": 1024,
+                        "thread_warp_size": 32,
+                        "host": "llvm",
+                    }
+                ),
+            }
+        )
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        red_result = T.allocate([2], "float32", "shared")
+        T.attr(red_result, "volatile_scope", 1)
+        threadIdx_z = T.launch_thread("threadIdx.z", 1)
+        threadIdx_y = T.launch_thread("threadIdx.y", 2)
+        threadIdx_x = T.launch_thread("threadIdx.x", 128)
+        red_result_1 = T.Buffer((2,), data=red_result, scope="shared")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            red_buf0 = T.allocate([1], "float32", "local")
+            t0 = T.allocate([1], "float32", "local")
+            red_buf0_1 = T.allocate([1], "float32", "local")
+            t0_1 = T.allocate([1], "float32", "local")
+            red_buf_staging = T.allocate([8], "float32", "shared")
+            red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
+            A_1 = T.Buffer((256,), data=A.data)
+            red_buf0_2[0] = A_1[threadIdx_y * 128 + threadIdx_x]
+            t0_2 = T.Buffer((1,), data=t0_1, scope="local")
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 16, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 8, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 4, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 2, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 1, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            red_buf_staging_1 = T.Buffer((8,), data=red_buf_staging, scope="shared")
+            if threadIdx_x % 32 == 0:
+                red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0]
+            T.tvm_storage_sync("shared")
+            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
+            if threadIdx_x < 4:
+                red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 + threadIdx_x]
+                t0_3 = T.Buffer((1,), data=t0, scope="local")
+                t0_3[0] = T.tvm_warp_shuffle_down(0, red_buf0_3[0], 2, 32, 32)
+                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+                t0_3[0] = T.tvm_warp_shuffle_down(0, red_buf0_3[0], 1, 32, 32)
+                red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+                if threadIdx_x == 0:
+                    red_result_1[threadIdx_y] = red_buf0_3[0]
+            T.tvm_storage_sync("shared")
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((2,), data=B.data)
+            B_1[threadIdx_y] = red_result_1[threadIdx_y]
+
+
 if __name__ == "__main__":
     tvm.testing.main()
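The new test exercises only the TIR pass, so it runs on any host without a Metal device, for example: pytest tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py -k MetalNoMask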
