[NFC] Make toLinearLayout not return an Optional (#5636)
We now support LL (LinearLayout) conversions for all our layouts, and so does the XPU backend downstream. As such, we make `toLinearLayout` support mandatory, in line with the progressive transition of our IR towards LLs.
lezcano authored Jan 17, 2025
1 parent 6556ec6 commit aaef20f
Showing 17 changed files with 119 additions and 176 deletions.
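As a rough sketch of what this means for call sites (hypothetical caller code, not part of this diff; `shape`, `layout`, and `use` are placeholder names), the change looks roughly like this:

    // Before: the conversion could fail, so callers unwrapped an Optional.
    //   std::optional<LinearLayout> maybeLL = toLinearLayout(shape, layout);
    //   if (!maybeLL)
    //     return failure();
    //   use(*maybeLL);
    //
    // After: every layout supports the conversion, so the result is used directly.
    LinearLayout ll = toLinearLayout(shape, layout);
    use(ll);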
4 changes: 2 additions & 2 deletions include/triton/Analysis/Utility.h
@@ -227,8 +227,8 @@ bool supportMMA(Value value, int version);
// return nullopt). The output will be such that layout.getInDimNames() ==
// layout.getOutDimNames() and the conversion will not include kBlock (resp.
// kWarp or kLane) if it can be avoided
std::optional<mlir::triton::LinearLayout>
minimalCvtLayout(RankedTensorType srcTy, RankedTensorType dstTy);
triton::LinearLayout minimalCvtLayout(RankedTensorType srcTy,
RankedTensorType dstTy);

// Conversion from `srcTy` to `dstTy` only involves reordering of registers.
// There is no need for data exchange across threads, warps, or blocks.
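A minimal sketch of consuming the new `minimalCvtLayout` result (assumed caller, with the relevant MLIR/Triton headers and namespaces in scope; it mirrors the dimension checks in `lib/Analysis/Utility.cpp` further down this diff):

    // The returned layout only mentions the hardware dimensions that actually
    // need data movement, so checking its input dims tells us the transfer kind.
    bool cvtStaysWithinWarp(RankedTensorType srcTy, RankedTensorType dstTy) {
      MLIRContext *ctx = srcTy.getContext();
      LinearLayout cvt = minimalCvtLayout(srcTy, dstTy);  // no Optional to unwrap
      StringAttr kWarp = StringAttr::get(ctx, "warp");
      StringAttr kBlock = StringAttr::get(ctx, "block");
      auto dims = llvm::to_vector(cvt.getInDimNames());
      return !llvm::is_contained(dims, kWarp) && !llvm::is_contained(dims, kBlock);
    }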
7 changes: 2 additions & 5 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -38,11 +38,8 @@ namespace mlir::triton::gpu {
// shared layouts with hasLeadingOffset == true) but is otherwise unused.
//
// Returns std::nullopt if the given layout can't be converted to an LL.
// TODO(jlebar): Remove the std::optional once all layouts are supported.
//
std::optional<LinearLayout>
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth = std::nullopt);
LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth = std::nullopt);

// Given a linear layout where the input dimensions contain a "block" dimension,
// this method sets the "block" dimension to 0 and removes the corresponding
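For illustration (assumed variables such as `shape`, `blockedEnc`, and `sharedEnc`), both forms of the call now return a `LinearLayout` by value:

    // Distributed layouts: the element bit width is irrelevant and can be omitted.
    LinearLayout regLL = toLinearLayout(shape, blockedEnc);
    // Shared layouts with hasLeadingOffset == true: pass the element bit width.
    LinearLayout smemLL = toLinearLayout(shape, sharedEnc, /*elemBitWidth=*/16);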
4 changes: 2 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -506,7 +506,7 @@ We call each individual tile "rep".
"SmallVector<unsigned>",
"getContigPerThread">,
InterfaceMethod<"Convert to LinearLayout.",
"std::optional<LinearLayout>",
"LinearLayout",
"toLinearLayout",
(ins "ArrayRef<int64_t>":$shape)>
];
@@ -561,7 +561,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},

SmallVector<unsigned> getSizePerThread() const;

std::optional<LinearLayout> toLinearLayout(ArrayRef<int64_t> shape) const;
LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
}];
}

5 changes: 2 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
@@ -44,9 +44,8 @@ def TritonGPU_Dialect : Dialect {
return cast<IntegerAttr>(threadsPerWarp).getInt();
}

std::optional<LinearLayout>
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth);
LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth);

private:
LinearLayoutCache llCache;
7 changes: 4 additions & 3 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -206,9 +206,10 @@ enum class MMALoadType {
MMALoadType getMMALoadType(Operation *loadOp);

// Returns composed LinearLayout for register to shared copy
std::optional<triton::LinearLayout>
getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
Attribute srcEnc, Attribute dstEnc, int elemBitWidth);
triton::LinearLayout getRegToSharedLayout(MLIRContext *ctx,
ArrayRef<int64_t> shape,
Attribute srcEnc, Attribute dstEnc,
int elemBitWidth);
} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
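A small sketch of consuming the composed register-to-shared layout (assumed arguments; the vector-width use mirrors `emitTransferBetweenRegistersAndShared` later in this diff):

    LinearLayout regToShared =
        getRegToSharedLayout(ctx, shape, srcEnc, dstEnc, /*elemBitWidth=*/32);
    // e.g. pick a vector width for the shared-memory transfer.
    int vecElems = regToShared.getNumConsecutiveInOut();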
53 changes: 18 additions & 35 deletions lib/Analysis/Utility.cpp
@@ -66,7 +66,7 @@ SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
auto srcLayout = getSrcLayout();
auto *ctx = srcLayout.getContext();
auto linearLayout = *toLinearLayout(getSrcShape(), srcLayout);
auto linearLayout = toLinearLayout(getSrcShape(), srcLayout);
auto axis = getAxis();
auto kLane = mlir::StringAttr::get(ctx, "lane");
const auto &bases = linearLayout.getBases();
@@ -158,7 +158,7 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() {
auto axis = getAxis();
auto *ctx = getSrcLayout().getContext();
auto ll = LinearEncodingAttr::get(
ctx, *toLinearLayout(getSrcShape(), getSrcLayout()));
ctx, toLinearLayout(getSrcShape(), getSrcLayout()));
return ll.getThreadsPerWarp()[axis] * ll.getWarpsPerCTA()[axis];
}

@@ -320,17 +320,14 @@ std::optional<DecomposedWarpConversion>
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
RankedTensorType dstTy) {
auto conversion = minimalCvtLayout(srcTy, dstTy);
if (!conversion)
return {};

MLIRContext *ctx = srcTy.getContext();
auto kRegister = StringAttr::get(ctx, "register");
auto kLane = StringAttr::get(ctx, "lane");

// We have already checked that data movement is only required within a warp,
// thus we can discard the block and warp dimensions.
LinearLayout C =
conversion->sublayout({kLane, kRegister}, {kLane, kRegister});
LinearLayout C = conversion.sublayout({kLane, kRegister}, {kLane, kRegister});

// `C` is a map from `(dst_lane, dst_reg) -> (src_lane, src_reg)`. From the
// perspective of the destination lane, it tells us which register from which
@@ -641,16 +638,11 @@ bool GatherLoweringHelper::isWarpLocal() {
// source and index tensors, all the elements are owned by the same warp.
RankedTensorType srcType = gatherOp.getSrc().getType();
RankedTensorType idxType = gatherOp.getIndices().getType();
std::optional<LinearLayout> srcLayout =
LinearLayout srcLayout =
toLinearLayout(srcType.getShape(), srcType.getEncoding());
std::optional<LinearLayout> idxLayout =
LinearLayout idxLayout =
toLinearLayout(idxType.getShape(), idxType.getEncoding());

// FIXME: If an unsupported layout was encountered, assume the gather is not
// warp-local.
if (!srcLayout || !idxLayout)
return false;

Builder b(gatherOp.getContext());
StringAttr kBlock = b.getStringAttr("block");
StringAttr kWarp = b.getStringAttr("warp");
@@ -675,8 +667,8 @@ bool GatherLoweringHelper::isWarpLocal() {
//
// Which implies that changing the warp will not change the gather dimension.
// And since there is no swizzling, this applies to all warps.
if (!srcLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim) ||
!idxLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim))
if (!srcLayout.sublayoutIsZero({kBlock, kWarp}, kGatherDim) ||
!idxLayout.sublayoutIsZero({kBlock, kWarp}, kGatherDim))
return false;

SmallVector<StringAttr> otherDims;
@@ -690,8 +682,8 @@ bool GatherLoweringHelper::isWarpLocal() {
// mapping to all other dimensions must be the same for both layouts. If so,
// then the warp that owns a particular index element also owns all the source
// elements it could index into.
if (srcLayout->sublayout({kBlock, kWarp}, otherDims) !=
idxLayout->sublayout({kBlock, kWarp}, otherDims))
if (srcLayout.sublayout({kBlock, kWarp}, otherDims) !=
idxLayout.sublayout({kBlock, kWarp}, otherDims))
return false;

// The two constraints above ensure that data-movement to perform the gather
@@ -702,8 +694,8 @@ bool GatherLoweringHelper::isWarpLocal() {
// in the index and source tensors are the same. This means we don't need to
// xor shuffle across threads before emitting index shuffles; we push warp
// shuffling to layout conversions.
return srcLayout->sublayout(kLane, otherDims) ==
idxLayout->sublayout(kLane, otherDims);
return srcLayout.sublayout(kLane, otherDims) ==
idxLayout.sublayout(kLane, otherDims);
}

unsigned getNumScratchElements(ArrayRef<unsigned> shape) {
@@ -884,21 +876,18 @@ bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
// distributed shared memory. If it's also the identity on kWarp, we can
// transfer via warp-shuffles, and if it's the identity on kLane just have to
// reorder the registers
std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
RankedTensorType dstTy) {
LinearLayout minimalCvtLayout(RankedTensorType srcTy, RankedTensorType dstTy) {
MLIRContext *ctx = srcTy.getContext();
std::optional<LinearLayout> srcLayout =
LinearLayout srcLayout =
toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
std::optional<LinearLayout> dstLayout =
LinearLayout dstLayout =
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
if (!(srcLayout.has_value() && dstLayout.has_value()))
return std::nullopt;
StringAttr kRegister = StringAttr::get(ctx, "register");
StringAttr kLane = StringAttr::get(ctx, "lane");
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");

auto comp = dstLayout->invertAndCompose(*srcLayout);
auto comp = dstLayout.invertAndCompose(srcLayout);
// We try to quotient by the largest subspace first
auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
for (auto dim : dims) {
@@ -914,24 +903,18 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) {
auto layout = minimalCvtLayout(srcTy, dstTy);
MLIRContext *ctx = srcTy.getContext();
if (!layout.has_value()) {
return false;
}
auto kRegister = StringAttr::get(ctx, "register");
auto outDims = llvm::to_vector(layout->getOutDimNames());
auto outDims = to_vector(layout.getOutDimNames());
return outDims.empty() || ArrayRef(outDims) == ArrayRef({kRegister});
}

bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
auto layout = minimalCvtLayout(srcTy, dstTy);
MLIRContext *ctx = srcTy.getContext();
if (!layout.has_value()) {
return false;
}
auto kRegister = StringAttr::get(ctx, "register");
auto kLane = StringAttr::get(ctx, "lane");
return llvm::to_vector(layout->getOutDimNames()) ==
llvm::SmallVector<StringAttr, 2>{kRegister, kLane};
return to_vector(layout.getOutDimNames()) ==
SmallVector<StringAttr, 2>{kRegister, kLane};
}

bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
18 changes: 7 additions & 11 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -277,24 +277,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();

auto conversion = minimalCvtLayout(srcTy, dstTy);
if (!conversion.has_value()) {
return rewriter.notifyMatchFailure(
op, "NYI. srcTy and/or dstTy don't implement LLs yet");
}
LinearLayout conversion = minimalCvtLayout(srcTy, dstTy);
LinearLayout srcLayout =
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
LinearLayout dstLayout =
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());

StringAttr kBlock = str_attr("block");
StringAttr kWarp = str_attr("warp");
StringAttr kLane = str_attr("lane");
StringAttr kRegister = str_attr("register");

assert(to_vector(conversion->getInDimNames()) ==
to_vector(conversion->getOutDimNames()));
auto dims = conversion->getInDimNames();
assert(to_vector(conversion.getInDimNames()) ==
to_vector(conversion.getOutDimNames()));
auto dims = conversion.getInDimNames();
if (llvm::is_contained(dims, kBlock)) {
// Case 1: Transfer between values in different CTAs.
// This requires moving values through distributed shared memory.
Expand All @@ -320,7 +316,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
} else if (llvm::is_contained(dims, kRegister)) {
// Case 4. Transfer between values in the same thread, in which case we
// simply reorder the elements of adaptor.getSrc().
return transferWithinThread(op, *conversion, adaptor, rewriter);
return transferWithinThread(op, conversion, adaptor, rewriter);
} else {
// Case 5. The two layouts are equivalent. We should probably remove
// these in RemoveLayoutConversion.
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp
@@ -207,9 +207,9 @@ void GatherOpConversion::emitWarpLocalGather(

// Compute the src and idx layouts.
LinearLayout srcLayout =
*toLinearLayout(srcType.getShape(), srcType.getEncoding());
toLinearLayout(srcType.getShape(), srcType.getEncoding());
LinearLayout idxLayout =
*toLinearLayout(idxType.getShape(), idxType.getEncoding());
toLinearLayout(idxType.getShape(), idxType.getEncoding());

// Let `ll_src` be the source layout and `ll_idx` be the index layout.
// Let `src_col` be a tuple of dimensions except the gather dimension,
47 changes: 19 additions & 28 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -138,9 +138,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
MLIRContext *ctx = rewriter.getContext();
auto shape = type.getShape();

std::optional<LinearLayout> ll = triton::gpu::toLinearLayout(shape, layout);
if (!ll.has_value())
llvm::report_fatal_error("Failed to convert layout to linear layout");
LinearLayout ll = triton::gpu::toLinearLayout(shape, layout);

// TODO(jlebar): We could add strong typing if we wanted; for now this is
// "stringly typed".
@@ -150,7 +148,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
StringAttr kBlock = str_attr("block");

auto [laneId, warpId, blockId] = emitHardwareTuple(
loc, rewriter, target, withCTAOffset, ll->getInDimSize(kLane));
loc, rewriter, target, withCTAOffset, ll.getInDimSize(kLane));
unsigned rank = shape.size();
SmallVector<SmallVector<Value>> ret;
// Linear layout function is split in two parts below:
@@ -162,14 +160,14 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
//
// This approach produces code with lower register pressure and
// less computations, compared to fused L(r,t,w,b) method.
auto idxsBase = applyLinearLayout(loc, rewriter, *ll,
auto idxsBase = applyLinearLayout(loc, rewriter, ll,
{{kRegister, i32_val(0)},
{kLane, laneId},
{kWarp, warpId},
{kBlock, blockId}});
for (unsigned reg = 0; reg < ll->getInDimSize(str_attr("register")); reg++) {
for (unsigned reg = 0; reg < ll.getInDimSize(str_attr("register")); reg++) {
auto idxsReg =
ll->apply({{kRegister, reg}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}});
ll.apply({{kRegister, reg}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}});
SmallVector<std::pair<StringAttr, Value>> idxs;
for (auto [idxBase, idxReg] : llvm::zip(idxsBase, idxsReg)) {
auto dimName = idxBase.first;
@@ -284,16 +282,14 @@ Value getSmemVecAddr(RankedTensorType registerTy,
// This approach ensures that "absolute" tensor offsets can be
// mapped to the correct shared memory addresses using
// `invertAllocSharedLayout`.
std::optional<LinearLayout> regLayout =
LinearLayout regLayout =
triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
auto allocSharedLayout = triton::gpu::toLinearLayout(
LinearLayout allocSharedLayout = triton::gpu::toLinearLayout(
allocShape.take_back(rank), sharedTy.getEncoding(),
elemLlvmTy.getIntOrFloatBitWidth());
assert(allocSharedLayout.has_value() &&
"Failed to convert layout to linear layout");
auto invertAllocSharedLayout = allocSharedLayout->invert();
LinearLayout invertAllocSharedLayout = allocSharedLayout.invert();
auto multiDimTensorOffsets =
llvm::to_vector(applyLinearLayout(loc, rewriter, *regLayout,
llvm::to_vector(applyLinearLayout(loc, rewriter, regLayout,
{{kRegister, regId},
{kLane, laneId},
{kWarp, warpId},
@@ -332,17 +328,15 @@ bool emitTransferBetweenRegistersAndShared(
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");

auto regToSharedLayout = getRegToSharedLayout(
LinearLayout regToSharedLayout = getRegToSharedLayout(
ctx, shape, registerTy.getEncoding(), sharedTy.getEncoding(),
elemLlvmTy.getIntOrFloatBitWidth());
if (!regToSharedLayout.has_value())
return false;

// TODO(jlebar): We don't currently support loading from shared memory in a
// different CTA. We'd need to emit `mapa.shared::cluster` instructions.
for (int inBlock = 1; inBlock < regToSharedLayout->getInDimSize(kBlock);
for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock);
inBlock *= 2) {
auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout->apply(
auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout.apply(
{{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}})));
// offsetX1, ..., offsetXN must all be 0.
if (!llvm::all_of(ArrayRef(idx).drop_back(1),
@@ -368,20 +362,20 @@ bool emitTransferBetweenRegistersAndShared(
// which have known strides. This would allow us to vectorize across multiple
// shmem out dimensions where possible.
const int vecElems =
std::min(regToSharedLayout->getNumConsecutiveInOut(),
std::min(regToSharedLayout.getNumConsecutiveInOut(),
maxVecElems.value_or(std::numeric_limits<int>::max()));

auto [laneId, warpId, blockId] =
emitHardwareTuple(loc, rewriter, target, /*withCTAOffset=*/false,
regToSharedLayout->getInDimSize(kLane));
regToSharedLayout.getInDimSize(kLane));

int numElems = regToSharedLayout->getInDimSize(kRegister);
int numElems = regToSharedLayout.getInDimSize(kRegister);
auto vecTy = vec_ty(elemLlvmTy, vecElems);
Value zero = i32_val(0);
SmallVector<Value> ret;
for (int i = 0; i < numElems / vecElems; i++) {
auto vecAddr = getSmemVecAddr(
registerTy, sharedTy, elemLlvmTy, loc, rewriter, *regToSharedLayout,
registerTy, sharedTy, elemLlvmTy, loc, rewriter, regToSharedLayout,
i32_val(i * vecElems), laneId, warpId, smemObj);

perVectorCallback(vecTy, vecAddr);
@@ -450,18 +444,15 @@ SmallVector<SmallVector<unsigned>> emitOffsetForLayout(Attribute layout,
unsigned rank = shape.size();

auto ll = triton::gpu::toLinearLayout(shape, layout);
if (!ll.has_value())
llvm::report_fatal_error("Unsupported layout");

StringAttr kRegister = str_attr("register");
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");
StringAttr kBlock = str_attr("block");

SmallVector<SmallVector<unsigned>> offsets;
for (int i = 0; i < ll->getInDimSize(str_attr("register")); i++) {
auto idxs =
ll->apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}});
for (int i = 0; i < ll.getInDimSize(str_attr("register")); i++) {
auto idxs = ll.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}});
assert(idxs.size() == rank);
for (unsigned k = 0; k < rank; ++k) {
assert(idxs[k].first == str_attr("dim" + std::to_string(k)));
@@ -632,7 +623,7 @@ std::tuple<SmallVector<Value>, Value>
delinearize(RewriterBase &rewriter, Location loc,
triton::gpu::DistributedEncodingTrait layout,
ArrayRef<int64_t> shape, StringAttr dimName, Value linear) {
auto ll = *triton::gpu::toLinearLayout(shape, layout);
auto ll = triton::gpu::toLinearLayout(shape, layout);
auto linearLayout =
triton::gpu::LinearEncodingAttr::get(rewriter.getContext(), ll);
assert(ll.hasInDim(dimName));