From 61c44f98fc8516405c7a965253a49cc0f27c7977 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Thu, 2 May 2024 16:01:13 +0000 Subject: [PATCH] [SVE] Add get_active_lane_mask builtin Adds a `get_active_lane_mask` builtin and lowering to `llvm.get.active.lane.mask` intrinsic. This will be used in subsequent patches for expressing predicated buffer loads/stores in TIR. Further information can be found in the [RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0104-scalable-vectors-in-tir.md#predication). Co-authored-by: Elen Kalda Co-authored-by: Neil Hickey Change-Id: Id9d65f9f11503ad35dd0b3db4bfc81249a76f701 --- include/tvm/tir/builtin.h | 8 +++++++ python/tvm/script/ir_builder/tir/ir.py | 2 ++ python/tvm/tir/__init__.py | 2 +- python/tvm/tir/op.py | 21 +++++++++++++++++++ src/target/llvm/codegen_llvm.cc | 5 +++++ src/tir/op/builtin.cc | 7 +++++++ .../codegen/test_target_codegen_aarch64.py | 20 ++++++++++++++++++ 7 files changed, 64 insertions(+), 1 deletion(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 10e5b462d1d1..5836eb8ea93a 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -915,6 +915,14 @@ TVM_DLL const Op& anylist_setitem_call_cpacked(); */ TVM_DLL const Op& vscale(); +/*! + * \brief Calculate a predicate mask given an upper bound (limit) and a current value (base). + * + * It will be lowered to the llvm.get.active.lane.mask intrinsic. + * (https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics) + */ +TVM_DLL const Op& get_active_lane_mask(); + /*! 
\brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { // array head address diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index c04ac780c9e6..5a0a564a2ab5 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1903,6 +1903,7 @@ def wrapped(*args, **kwargs): vectorlow = _dtype_forward(_tir_op.vectorlow) vectorhigh = _dtype_forward(_tir_op.vectorhigh) vectorcombine = _dtype_forward(_tir_op.vectorcombine) +get_active_lane_mask = _dtype_forward(_tir_op.get_active_lane_mask) broadcast = Broadcast @@ -2219,4 +2220,5 @@ def wrapped(*args, **kwargs): "CommReducer", "Range", "vscale", + "get_active_lane_mask", ] diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 1723804388b9..24ba4ccd2e58 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -88,7 +88,7 @@ from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace from .op import start_profile_intrinsic, end_profile_intrinsic -from .op import vscale +from .op import vscale, get_active_lane_mask from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 6b72e63f2990..db52bec598b1 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3349,6 +3349,27 @@ def vscale(): return call_intrin("int32", "tir.vscale") +def get_active_lane_mask(dtype, base, limit): + """ + Calculate a predicate mask given an upper bound (limit) and a current value (base). + + It will be lowered to the llvm.get.active.lane.mask intrinsic. + (https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics) + + Parameters + ---------- + dtype : str + The data type of the result. + + base : PrimExpr + An expression representing the base. 
+ + limit : PrimExpr + An expression representing the limit. + """ + return call_intrin(dtype, "tir.get_active_lane_mask", base, limit) + + # pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum") min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 95512a00a77c..6566bb4291d8 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1478,6 +1478,11 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { llvm::Intrinsic::ID id = llvm::Intrinsic::vscale; llvm::Function* f = GetIntrinsicDecl(id, builder_->getInt32Ty(), {}); return builder_->CreateCall(f); + } else if (op->op.same_as(builtin::get_active_lane_mask())) { + llvm::Intrinsic::ID id = llvm::Intrinsic::get_active_lane_mask; + llvm::Function* f = GetIntrinsicDecl(id, DTypeToLLVMType(op->dtype), + {builder_->getInt32Ty(), builder_->getInt32Ty()}); + return builder_->CreateCall(f, {MakeValue(op->args[0]), MakeValue(op->args[1])}); #endif } else { LOG(FATAL) << "unknown intrinsic " << op->op; diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index fbe31c890dad..cf82eb07edf2 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -397,6 +397,13 @@ TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_cpacked) TIR_DEFINE_BUILTIN_FUNC(vscale).set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); + +TIR_DEFINE_BUILTIN_FUNC(get_active_lane_mask) + .set_num_inputs(2) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); + } // namespace builtin } // namespace tir } // namespace tvm diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 8f22ba5b73ed..452638beda0a 100644 --- 
a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -680,5 +680,25 @@ def check_correct_assembly(dtype): check_correct_assembly(dtype=dtype) +@pytest.mark.skipif( + llvm_version_major() < 11, + reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM", +) +def test_get_active_lane_mask(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + + @T.prim_func + def before(a: T.handle): + A = T.match_buffer(a, (30,), "int1") + for i in range(T.ceildiv(30, T.vscale() * 4)): + A[i : i + T.vscale() * 4] = T.get_active_lane_mask("int1xvscalex4", i, 30) + + with tvm.target.Target(target): + out = tvm.build(before) + + ll = out.get_source("ll") + assert "get.active.lane.mask" in ll + + if __name__ == "__main__": tvm.testing.main()