diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc
index dd924b925596..cc83eb1462c6 100644
--- a/src/target/source/intrin_rule_metal.cc
+++ b/src/target/source/intrin_rule_metal.cc
@@ -30,6 +30,28 @@ namespace codegen {
 namespace intrin {
 using tir::FLowerIntrinsic;
 
+struct MetalWarpIntrinsic {
+  const Op operator()(DataType t, const Op& orig_op) const {
+    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
+      return Op::Get("tir.metal.simd_shuffle");
+    } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
+      return Op::Get("tir.metal.simd_shuffle_up");
+    } else {
+      ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
+      return Op::Get("tir.metal.simd_shuffle_down");
+    }
+  }
+};
+
+template <typename T>
+static PrimExpr DispatchMetalShuffle(const PrimExpr& e) {
+  const CallNode* call = e.as<CallNode>();
+  ICHECK(call != nullptr);
+  ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
+  Array<PrimExpr> metal_args{{call->args[1], call->args[2]}};
+  return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), metal_args);
+}
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
 
@@ -95,6 +117,37 @@ TVM_REGISTER_OP("tir.cosh")
     .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
 
 TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchFastErf);
 
+TVM_REGISTER_OP("tir.tvm_warp_shuffle")
+    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);
+
+TVM_REGISTER_OP("tir.tvm_warp_shuffle_up")
+    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);
+
+TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
+    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);
+
+// Register low-level builtin ops.
+TVM_REGISTER_OP("tir.metal.simd_shuffle")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The variable to sync.")
+    .add_argument("lane", "Expr", "The source thread id.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TVM_REGISTER_OP("tir.metal.simd_shuffle_up")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The variable to sync.")
+    .add_argument("delta", "Expr", "The source lane id offset to be added.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle_up")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TVM_REGISTER_OP("tir.metal.simd_shuffle_down")
+    .set_num_inputs(2)
+    .add_argument("var", "Expr", "The variable to sync.")
+    .add_argument("delta", "Expr", "The source lane id offset to be subtracted.")
+    .set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle_down")
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 }  // namespace intrin
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 438dccff0bdb..fba62a0c18ac 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -476,12 +476,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // The mask for this reducer, as this reducer may sit inside
     // a divergent control flow. Here it uses a variable to cache the current
     // active channels.
-    Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
-    {
-      seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices));
+    Optional<Buffer> mask_buffer;
+    if (need_warp_shuffle_mask_) {
+      mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
+      seq->emplace_back(BufferStore(mask_buffer.value(), mask, zero_indices));
       // Push the buffer description. Later this will have an
       // allocation built for it.
-      local_bufs.push_back(mask_buffer);
+      local_bufs.push_back(mask_buffer.value());
     }
 
     // Emit reductions within a warp.
@@ -698,9 +699,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   }
 
   // Emit warp shuffle calls.
-  PrimExpr WarpShuffle(const Op& op, Buffer mask_buffer, PrimExpr val, PrimExpr delta_or_lane) {
+  PrimExpr WarpShuffle(const Op& op, Optional<Buffer> mask_buffer, PrimExpr val,
+                       PrimExpr delta_or_lane) {
     Array<PrimExpr> indices = {0};
-    PrimExpr mask = BufferLoad(mask_buffer, indices);
+    PrimExpr mask;
+    if (mask_buffer.defined()) {
+      mask = BufferLoad(mask_buffer.value(), indices);
+    } else {
+      mask = IntImm(DataType::Int(32), 0);
+    }
     PrimExpr width = IntImm(DataType::Int(32), warp_size_);
     Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
     return Call(val.dtype(), op, args);
@@ -709,11 +716,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
 
   // Check if we can use warp level reduction.
   //
   // Note: The ROCm backend will only have warp reductions for now.
-  // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
+  // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda and metal).
   bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
-                       int contiguous_reduce_extent) const {
-    // Only cuda target supports warp reductions.
-    if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false;
+                       int contiguous_reduce_extent) {
+    if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
+        (target_->kind->name != "metal")) {
+      return false;
+    }
+
+    need_warp_shuffle_mask_ = target_->kind->name != "metal";
 
     // rocm only supports 32 bit operands for shuffling at the moment
     if ((target_->kind->name == "rocm") &&
@@ -745,7 +756,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // whether reduce_extent and group_extent are valid for warp reduction.
     if (target_->kind->name == "rocm") {
       return reduce_extent == warp_size_;
-    } else {  // target_->kind->name == "cuda"
+    } else {
       if (reduce_extent == 1) {
         return false;  // no need to warp reduce
       } else {
@@ -769,6 +780,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   int warp_size_{1};
   // The maximum number of threads of the device. "-1" denotes unknown.
   int max_num_threads_{-1};
+  // A boolean indicating if the target supports warp-level masking.
+  bool need_warp_shuffle_mask_;
   // surrounding scope of thread extent.
   std::vector<const AttrStmtNode*> thread_extents_;
diff --git a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
index 9d53b1f9dfb5..f797d35d47ca 100644
--- a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
@@ -702,5 +702,108 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")):
             B_1[threadIdx_y] = red_result_1[threadIdx_y]
 
 
+class TestMetalNoMask(BaseCompare):
+    @T.prim_func
+    def before(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
+        T.func_attr(
+            {
+                "target": T.target(
+                    {
+                        "kind": "metal",
+                        "max_threads_per_block": 1024,
+                        "thread_warp_size": 32,
+                        "host": "llvm",
+                    }
+                ),
+            }
+        )
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        cross_thread_B = T.allocate([1], "float32", "local")
+        threadIdx_z = T.launch_thread("threadIdx.z", 1)
+        threadIdx_y = T.launch_thread("threadIdx.y", 2)
+        threadIdx_x = T.launch_thread("threadIdx.x", 128)
+        cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            A_1 = T.Buffer((256,), data=A.data)
+            T.tvm_thread_allreduce(
+                T.uint32(1),
+                A_1[threadIdx_y * 128 + threadIdx_x],
+                T.bool(True),
+                cross_thread_B_1[0],
+                threadIdx_x,
+            )
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((2,), data=B.data)
+            B_1[threadIdx_y] = cross_thread_B_1[0]
+
+    @T.prim_func
+    def expected(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")):
+        T.func_attr(
+            {
+                "target": T.target(
+                    {
+                        "kind": "metal",
+                        "max_threads_per_block": 1024,
+                        "thread_warp_size": 32,
+                        "host": "llvm",
+                    }
+                ),
+            }
+        )
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        red_result = T.allocate([2], "float32", "shared")
+        T.attr(red_result, "volatile_scope", 1)
+        threadIdx_z = T.launch_thread("threadIdx.z", 1)
+        threadIdx_y = T.launch_thread("threadIdx.y", 2)
+        threadIdx_x = T.launch_thread("threadIdx.x", 128)
+        red_result_1 = T.Buffer((2,), data=red_result, scope="shared")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            red_buf0 = T.allocate([1], "float32", "local")
+            t0 = T.allocate([1], "float32", "local")
+            red_buf0_1 = T.allocate([1], "float32", "local")
+            t0_1 = T.allocate([1], "float32", "local")
+            red_buf_staging = T.allocate([8], "float32", "shared")
+            red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local")
+            A_1 = T.Buffer((256,), data=A.data)
+            red_buf0_2[0] = A_1[threadIdx_y * 128 + threadIdx_x]
+            t0_2 = T.Buffer((1,), data=t0_1, scope="local")
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 16, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 8, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 4, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 2, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 1, 32, 32)
+            red_buf0_2[0] = red_buf0_2[0] + t0_2[0]
+            red_buf_staging_1 = T.Buffer((8,), data=red_buf_staging, scope="shared")
+            if threadIdx_x % 32 == 0:
+                red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0]
+            T.tvm_storage_sync("shared")
+            red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local")
+            if threadIdx_x < 4:
+                red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 + threadIdx_x]
+            t0_3 = T.Buffer((1,), data=t0, scope="local")
+            t0_3[0] = T.tvm_warp_shuffle_down(0, red_buf0_3[0], 2, 32, 32)
+            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+            t0_3[0] = T.tvm_warp_shuffle_down(0, red_buf0_3[0], 1, 32, 32)
+            red_buf0_3[0] = red_buf0_3[0] + t0_3[0]
+            if threadIdx_x == 0:
+                red_result_1[threadIdx_y] = red_buf0_3[0]
+        T.tvm_storage_sync("shared")
+        if threadIdx_x == 0:
+            B_1 = T.Buffer((2,), data=B.data)
+            B_1[threadIdx_y] = red_result_1[threadIdx_y]
+
+
 if __name__ == "__main__":
     tvm.testing.main()