Skip to content

Commit

Permalink
[SVE] Add get_active_lane_mask builtin
Browse files Browse the repository at this point in the history
Adds a `get_active_lane_mask` builtin and lowering to
`llvm.get.active.lane.mask` intrinsic. This will be used in subsequent
patches for expressing predicated buffer loads/stores in TIR. Further
information can be found in the [RFC](https://github.com/apache/tvm-rfcs/blob/main/rfcs/0104-scalable-vectors-in-tir.md#predication).

Co-authored-by: Elen Kalda <[email protected]>
Co-authored-by: Neil Hickey <[email protected]>

Change-Id: Id9d65f9f11503ad35dd0b3db4bfc81249a76f701
  • Loading branch information
lhutton1 committed May 2, 2024
1 parent a320b63 commit 61c44f9
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 1 deletion.
8 changes: 8 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,14 @@ TVM_DLL const Op& anylist_setitem_call_cpacked();
*/
TVM_DLL const Op& vscale();

/*!
* \brief Calculate a predicate mask given an upper bound (limit) and a current value (base).
*
* It will be lowered to the llvm.get.active.lane.mask intrinsic.
* (https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics)
*/
TVM_DLL const Op& get_active_lane_mask();

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1903,6 +1903,7 @@ def wrapped(*args, **kwargs):
vectorlow = _dtype_forward(_tir_op.vectorlow)
vectorhigh = _dtype_forward(_tir_op.vectorhigh)
vectorcombine = _dtype_forward(_tir_op.vectorcombine)
get_active_lane_mask = _dtype_forward(_tir_op.get_active_lane_mask)


broadcast = Broadcast
Expand Down Expand Up @@ -2219,4 +2220,5 @@ def wrapped(*args, **kwargs):
"CommReducer",
"Range",
"vscale",
"get_active_lane_mask",
]
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
from .op import start_profile_intrinsic, end_profile_intrinsic
from .op import vscale
from .op import vscale, get_active_lane_mask
from .generic import add, subtract, multiply

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
Expand Down
21 changes: 21 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3349,6 +3349,27 @@ def vscale():
return call_intrin("int32", "tir.vscale")


def get_active_lane_mask(dtype, base, limit):
"""
Calculate a predicate mask given an upper bound (limit) and a current value (base).
It will be lowered to the llvm.get.active.lane.mask intrinsic.
(https://llvm.org/docs/LangRef.html#llvm-get-active-lane-mask-intrinsics)
Parameters
----------
dtype : str
The data type of the result.
base : PrimExpr
An expression reprsenting the base.
limit : PrimExpr
An expression representing the limit.
"""
return call_intrin(dtype, "tir.get_active_lane_mask", base, limit)


# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore
Expand Down
5 changes: 5 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,11 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
llvm::Intrinsic::ID id = llvm::Intrinsic::vscale;
llvm::Function* f = GetIntrinsicDecl(id, builder_->getInt32Ty(), {});
return builder_->CreateCall(f);
} else if (op->op.same_as(builtin::get_active_lane_mask())) {
llvm::Intrinsic::ID id = llvm::Intrinsic::get_active_lane_mask;
llvm::Function* f = GetIntrinsicDecl(id, DTypeToLLVMType(op->dtype),
{builder_->getInt32Ty(), builder_->getInt32Ty()});
return builder_->CreateCall(f, {MakeValue(op->args[0]), MakeValue(op->args[1])});
#endif
} else {
LOG(FATAL) << "unknown intrinsic " << op->op;
Expand Down
7 changes: 7 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,13 @@ TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_cpacked)

TIR_DEFINE_BUILTIN_FUNC(vscale).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));

TIR_DEFINE_BUILTIN_FUNC(get_active_lane_mask)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
Integer(ScriptDtypePrintLocation::kFirst));

} // namespace builtin
} // namespace tir
} // namespace tvm
20 changes: 20 additions & 0 deletions tests/python/codegen/test_target_codegen_aarch64.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,5 +680,25 @@ def check_correct_assembly(dtype):
check_correct_assembly(dtype=dtype)


@pytest.mark.skipif(
llvm_version_major() < 11,
reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM",
)
def test_get_active_lane_mask():
target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"

@T.prim_func
def before(a: T.handle):
A = T.match_buffer(a, (30,), "int1")
for i in range(T.ceildiv(30, T.vscale() * 4)):
A[i : i + T.vscale() * 4] = T.get_active_lane_mask("int1xvscalex4", i, 30)

with tvm.target.Target(target):
out = tvm.build(before)

ll = out.get_source("ll")
assert "get.active.lane.mask" in ll


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 61c44f9

Please sign in to comment.