From aaef20f1f66c7faf91b597455b0678041862a117 Mon Sep 17 00:00:00 2001
From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
Date: Fri, 17 Jan 2025 15:22:35 +0000
Subject: [PATCH] [NFC] Make toLinearLayout not return an Optional (#5636)

We now support LL conversions for all our layouts, and so does XPU
downstream. As such, we now make `toLinearLayout` support mandatory in
line with our progressive transition of our IR towards LLs.
---
 include/triton/Analysis/Utility.h             |  4 +-
 .../TritonGPU/IR/LinearLayoutConversions.h    |  7 +-
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td |  4 +-
 .../Dialect/TritonGPU/IR/TritonGPUDialect.td  |  5 +-
 .../Dialect/TritonGPU/Transforms/Utility.h    |  7 +-
 lib/Analysis/Utility.cpp                      | 53 +++++----------
 .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 18 ++---
 .../TritonGPUToLLVM/GatherOpToLLVM.cpp        |  4 +-
 lib/Conversion/TritonGPUToLLVM/Utility.cpp    | 47 ++++++-------
 lib/Dialect/TritonGPU/IR/Dialect.cpp          | 43 +++++-------
 .../TritonGPU/IR/LinearLayoutConversions.cpp  | 66 ++++++++-----------
 lib/Dialect/TritonGPU/IR/Ops.cpp              |  2 +-
 .../TritonGPU/Transforms/AccelerateMatmul.cpp |  2 +-
 .../Transforms/CoalesceAsyncCopy.cpp          |  2 +-
 lib/Dialect/TritonGPU/Transforms/Utility.cpp  | 20 +++---
 .../LoadStoreOpToLLVM.cpp                     |  3 +-
 unittest/Dialect/TritonGPU/DialectTest.cpp    |  8 +--
 17 files changed, 119 insertions(+), 176 deletions(-)

diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index 654b33ee0aba..7a73af885446 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -227,8 +227,8 @@ bool supportMMA(Value value, int version);
 // return nullopt). The output will be such that layout.getInDimNames() ==
 // layout.getOutDimNames() and the conversion will not include kBlock (resp.
 // kWarp or kLane) if it can be avoided
-std::optional
-minimalCvtLayout(RankedTensorType srcTy, RankedTensorType dstTy);
+triton::LinearLayout minimalCvtLayout(RankedTensorType srcTy,
+                                      RankedTensorType dstTy);
 
 // Conversion from `srcTy` to `dstTy` only involves reordering of registers.
 // There is no need for data exchange across threads, warps, or blocks.
diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
index b9fd3bdec19d..c8b1f164b9f6 100644
--- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
+++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -38,11 +38,8 @@ namespace mlir::triton::gpu {
 // shared layouts with hasLeadingOffset == true) but is otherwise unused.
 //
 // Returns std::nullopt if the given layout can't be converted to an LL.
-// TODO(jlebar): Remove the std::optional once all layouts are supported.
-//
-std::optional
-toLinearLayout(ArrayRef shape, Attribute layout,
-               std::optional elemBitWidth = std::nullopt);
+LinearLayout toLinearLayout(ArrayRef shape, Attribute layout,
+                            std::optional elemBitWidth = std::nullopt);
 
 // Given a linear layout where the input dimensions contain a "block" dimension,
 // this method sets the "block" dimension to 0 and removes the corresponding
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index 2d95d9a61e21..008103d81c21 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -506,7 +506,7 @@ We call each individual tile "rep".
"SmallVector", "getContigPerThread">, InterfaceMethod<"Convert to LinearLayout.", - "std::optional", + "LinearLayout", "toLinearLayout", (ins "ArrayRef":$shape)> ]; @@ -561,7 +561,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, SmallVector getSizePerThread() const; - std::optional toLinearLayout(ArrayRef shape) const; + LinearLayout toLinearLayout(ArrayRef shape) const; }]; } diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td index 95b6718b5395..e6bbcb56da54 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -44,9 +44,8 @@ def TritonGPU_Dialect : Dialect { return cast(threadsPerWarp).getInt(); } - std::optional - toLinearLayout(ArrayRef shape, Attribute layout, - std::optional elemBitWidth); + LinearLayout toLinearLayout(ArrayRef shape, Attribute layout, + std::optional elemBitWidth); private: LinearLayoutCache llCache; diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 10749a057493..18ed968ddae0 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -206,9 +206,10 @@ enum class MMALoadType { MMALoadType getMMALoadType(Operation *loadOp); // Returns composed LinearLayout for register to shared copy -std::optional -getRegToSharedLayout(MLIRContext *ctx, ArrayRef shape, - Attribute srcEnc, Attribute dstEnc, int elemBitWidth); +triton::LinearLayout getRegToSharedLayout(MLIRContext *ctx, + ArrayRef shape, + Attribute srcEnc, Attribute dstEnc, + int elemBitWidth); } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index c4cb6276ff4b..dc344bc18937 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -66,7 +66,7 @@ SmallVector ReduceOpHelper::getOrderWithAxisAtBeginning() { unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { auto srcLayout = getSrcLayout(); auto *ctx = srcLayout.getContext(); - auto linearLayout = *toLinearLayout(getSrcShape(), srcLayout); + auto linearLayout = toLinearLayout(getSrcShape(), srcLayout); auto axis = getAxis(); auto kLane = mlir::StringAttr::get(ctx, "lane"); const auto &bases = linearLayout.getBases(); @@ -158,7 +158,7 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() { auto axis = getAxis(); auto *ctx = getSrcLayout().getContext(); auto ll = LinearEncodingAttr::get( - ctx, *toLinearLayout(getSrcShape(), getSrcLayout())); + ctx, toLinearLayout(getSrcShape(), getSrcLayout())); return ll.getThreadsPerWarp()[axis] * ll.getWarpsPerCTA()[axis]; } @@ -320,8 +320,6 @@ std::optional getWarpLayoutConvertDecomposition(RankedTensorType srcTy, RankedTensorType dstTy) { auto conversion = minimalCvtLayout(srcTy, dstTy); - if (!conversion) - return {}; MLIRContext *ctx = srcTy.getContext(); auto kRegister = StringAttr::get(ctx, "register"); @@ -329,8 +327,7 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy, // We have already checked that data movement is only required within a warp, // thus we can discard the block and warp dimensions. - LinearLayout C = - conversion->sublayout({kLane, kRegister}, {kLane, kRegister}); + LinearLayout C = conversion.sublayout({kLane, kRegister}, {kLane, kRegister}); // `C` is map from `(dst_lane, dst_reg) -> (src_lane, src_reg)`. 
From the // perspetive of the destination lane, it tells us which register from which @@ -641,16 +638,11 @@ bool GatherLoweringHelper::isWarpLocal() { // source and index tensors, all the elements are owned by the same warp. RankedTensorType srcType = gatherOp.getSrc().getType(); RankedTensorType idxType = gatherOp.getIndices().getType(); - std::optional srcLayout = + LinearLayout srcLayout = toLinearLayout(srcType.getShape(), srcType.getEncoding()); - std::optional idxLayout = + LinearLayout idxLayout = toLinearLayout(idxType.getShape(), idxType.getEncoding()); - // FIXME: If an unsupported layout was encountered, assume the gather is not - // warp-local. - if (!srcLayout || !idxLayout) - return false; - Builder b(gatherOp.getContext()); StringAttr kBlock = b.getStringAttr("block"); StringAttr kWarp = b.getStringAttr("warp"); @@ -675,8 +667,8 @@ bool GatherLoweringHelper::isWarpLocal() { // // Which implies that changing the warp will not change the gather dimension. // And since there is no swizzling, this applies to all warps. - if (!srcLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim) || - !idxLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim)) + if (!srcLayout.sublayoutIsZero({kBlock, kWarp}, kGatherDim) || + !idxLayout.sublayoutIsZero({kBlock, kWarp}, kGatherDim)) return false; SmallVector otherDims; @@ -690,8 +682,8 @@ bool GatherLoweringHelper::isWarpLocal() { // mapping to all other dimensions must be the same for both layouts. If so, // then the warp that owns a particular index element also owns all the source // elements it could index into. - if (srcLayout->sublayout({kBlock, kWarp}, otherDims) != - idxLayout->sublayout({kBlock, kWarp}, otherDims)) + if (srcLayout.sublayout({kBlock, kWarp}, otherDims) != + idxLayout.sublayout({kBlock, kWarp}, otherDims)) return false; // The two constraints above ensure that data-movement to perform the gather @@ -702,8 +694,8 @@ bool GatherLoweringHelper::isWarpLocal() { // in the index and source tensors are the same. This means we don't need to // xor shuffle across threads before emitting index shuffles; we push warp // shuffling to layout conversions. - return srcLayout->sublayout(kLane, otherDims) == - idxLayout->sublayout(kLane, otherDims); + return srcLayout.sublayout(kLane, otherDims) == + idxLayout.sublayout(kLane, otherDims); } unsigned getNumScratchElements(ArrayRef shape) { @@ -884,21 +876,18 @@ bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy, // distributed shared memory. 
If it's also the identity on kWarp, we can // transfer via warp-shuffles, and if it's the identity on kLane just have to // reorder the registers -std::optional minimalCvtLayout(RankedTensorType srcTy, - RankedTensorType dstTy) { +LinearLayout minimalCvtLayout(RankedTensorType srcTy, RankedTensorType dstTy) { MLIRContext *ctx = srcTy.getContext(); - std::optional srcLayout = + LinearLayout srcLayout = toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - std::optional dstLayout = + LinearLayout dstLayout = toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); - if (!(srcLayout.has_value() && dstLayout.has_value())) - return std::nullopt; StringAttr kRegister = StringAttr::get(ctx, "register"); StringAttr kLane = StringAttr::get(ctx, "lane"); StringAttr kWarp = StringAttr::get(ctx, "warp"); StringAttr kBlock = StringAttr::get(ctx, "block"); - auto comp = dstLayout->invertAndCompose(*srcLayout); + auto comp = dstLayout.invertAndCompose(srcLayout); // We try to quotient by the largest subspace first auto dims = SmallVector{"block", "warp", "lane", "register"}; for (auto dim : dims) { @@ -914,24 +903,18 @@ std::optional minimalCvtLayout(RankedTensorType srcTy, bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) { auto layout = minimalCvtLayout(srcTy, dstTy); MLIRContext *ctx = srcTy.getContext(); - if (!layout.has_value()) { - return false; - } auto kRegister = StringAttr::get(ctx, "register"); - auto outDims = llvm::to_vector(layout->getOutDimNames()); + auto outDims = to_vector(layout.getOutDimNames()); return outDims.empty() || ArrayRef(outDims) == ArrayRef({kRegister}); } bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) { auto layout = minimalCvtLayout(srcTy, dstTy); MLIRContext *ctx = srcTy.getContext(); - if (!layout.has_value()) { - return false; - } auto kRegister = StringAttr::get(ctx, "register"); auto kLane = StringAttr::get(ctx, "lane"); - return llvm::to_vector(layout->getOutDimNames()) == - llvm::SmallVector{kRegister, kLane}; + return to_vector(layout.getOutDimNames()) == + SmallVector{kRegister, kLane}; } bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index b66aaaaa9804..f277561a0deb 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -277,24 +277,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); - auto conversion = minimalCvtLayout(srcTy, dstTy); - if (!conversion.has_value()) { - return rewriter.notifyMatchFailure( - op, "NYI. 
srcTy and/or dstTy don't implement LLs yet"); - } + LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); LinearLayout srcLayout = - *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); LinearLayout dstLayout = - *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); StringAttr kBlock = str_attr("block"); StringAttr kWarp = str_attr("warp"); StringAttr kLane = str_attr("lane"); StringAttr kRegister = str_attr("register"); - assert(to_vector(conversion->getInDimNames()) == - to_vector(conversion->getOutDimNames())); - auto dims = conversion->getInDimNames(); + assert(to_vector(conversion.getInDimNames()) == + to_vector(conversion.getOutDimNames())); + auto dims = conversion.getInDimNames(); if (llvm::is_contained(dims, kBlock)) { // Case 1: Transfer between values in different CTAs. // This requires moving values through distributed shared memory. @@ -320,7 +316,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } else if (llvm::is_contained(dims, kRegister)) { // Case 4. Transfer between values in the same thread, in which case we // simply reorder the elements of adaptor.getSrc(). - return transferWithinThread(op, *conversion, adaptor, rewriter); + return transferWithinThread(op, conversion, adaptor, rewriter); } else { // Cast 5. The two layouts are equivalent. We should probably remove // these in RemoveLayoutConversion. diff --git a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp index fd342751030a..15f38930795e 100644 --- a/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp @@ -207,9 +207,9 @@ void GatherOpConversion::emitWarpLocalGather( // Compute the src and idx layouts. LinearLayout srcLayout = - *toLinearLayout(srcType.getShape(), srcType.getEncoding()); + toLinearLayout(srcType.getShape(), srcType.getEncoding()); LinearLayout idxLayout = - *toLinearLayout(idxType.getShape(), idxType.getEncoding()); + toLinearLayout(idxType.getShape(), idxType.getEncoding()); // Let `ll_src` be the source layout and `ll_idx` be the index layout. // Let `src_col` be a tuple of dimensions except the gather dimension, diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 64f3698b9a31..6570764bc65d 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -138,9 +138,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, MLIRContext *ctx = rewriter.getContext(); auto shape = type.getShape(); - std::optional ll = triton::gpu::toLinearLayout(shape, layout); - if (!ll.has_value()) - llvm::report_fatal_error("Failed to convert layout to linear layout"); + LinearLayout ll = triton::gpu::toLinearLayout(shape, layout); // TODO(jlebar): We could add strong typing if we wanted; for now this is // "stringly typed". 
@@ -150,7 +148,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, StringAttr kBlock = str_attr("block"); auto [laneId, warpId, blockId] = emitHardwareTuple( - loc, rewriter, target, withCTAOffset, ll->getInDimSize(kLane)); + loc, rewriter, target, withCTAOffset, ll.getInDimSize(kLane)); unsigned rank = shape.size(); SmallVector> ret; // Linear layout function is split in two parts below: @@ -162,14 +160,14 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, // // This approach produces code with lower register pressure and // less computations, compared to fused L(r,t,w,b) method. - auto idxsBase = applyLinearLayout(loc, rewriter, *ll, + auto idxsBase = applyLinearLayout(loc, rewriter, ll, {{kRegister, i32_val(0)}, {kLane, laneId}, {kWarp, warpId}, {kBlock, blockId}}); - for (unsigned reg = 0; reg < ll->getInDimSize(str_attr("register")); reg++) { + for (unsigned reg = 0; reg < ll.getInDimSize(str_attr("register")); reg++) { auto idxsReg = - ll->apply({{kRegister, reg}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + ll.apply({{kRegister, reg}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); SmallVector> idxs; for (auto [idxBase, idxReg] : llvm::zip(idxsBase, idxsReg)) { auto dimName = idxBase.first; @@ -284,16 +282,14 @@ Value getSmemVecAddr(RankedTensorType registerTy, // This approach ensures that "absolute" tensor offsets can be // mapped to the correct shared memory addresses using // `invertAllocSharedLayout`. - std::optional regLayout = + LinearLayout regLayout = triton::gpu::toLinearLayout(shape, registerTy.getEncoding()); - auto allocSharedLayout = triton::gpu::toLinearLayout( + LinearLayout allocSharedLayout = triton::gpu::toLinearLayout( allocShape.take_back(rank), sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth()); - assert(allocSharedLayout.has_value() && - "Failed to convert layout to linear layout"); - auto invertAllocSharedLayout = allocSharedLayout->invert(); + LinearLayout invertAllocSharedLayout = allocSharedLayout.invert(); auto multiDimTensorOffsets = - llvm::to_vector(applyLinearLayout(loc, rewriter, *regLayout, + llvm::to_vector(applyLinearLayout(loc, rewriter, regLayout, {{kRegister, regId}, {kLane, laneId}, {kWarp, warpId}, @@ -332,17 +328,15 @@ bool emitTransferBetweenRegistersAndShared( StringAttr kLane = str_attr("lane"); StringAttr kWarp = str_attr("warp"); - auto regToSharedLayout = getRegToSharedLayout( + LinearLayout regToSharedLayout = getRegToSharedLayout( ctx, shape, registerTy.getEncoding(), sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth()); - if (!regToSharedLayout.has_value()) - return false; // TODO(jlebar): We don't currently support loading from shared memory in a // different CTA. We'd need to emit `mapa.shared::cluster` instructions. - for (int inBlock = 1; inBlock < regToSharedLayout->getInDimSize(kBlock); + for (int inBlock = 1; inBlock < regToSharedLayout.getInDimSize(kBlock); inBlock *= 2) { - auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout->apply( + auto idx = llvm::to_vector(llvm::make_second_range(regToSharedLayout.apply( {{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, inBlock}}))); // offsetX1, ..., offsetXN must all be 0. if (!llvm::all_of(ArrayRef(idx).drop_back(1), @@ -368,20 +362,20 @@ bool emitTransferBetweenRegistersAndShared( // which have known strides. This would allow us to vectorize across multiple // shmem out dimensions where possible. 
const int vecElems = - std::min(regToSharedLayout->getNumConsecutiveInOut(), + std::min(regToSharedLayout.getNumConsecutiveInOut(), maxVecElems.value_or(std::numeric_limits::max())); auto [laneId, warpId, blockId] = emitHardwareTuple(loc, rewriter, target, /*withCTAOffset=*/false, - regToSharedLayout->getInDimSize(kLane)); + regToSharedLayout.getInDimSize(kLane)); - int numElems = regToSharedLayout->getInDimSize(kRegister); + int numElems = regToSharedLayout.getInDimSize(kRegister); auto vecTy = vec_ty(elemLlvmTy, vecElems); Value zero = i32_val(0); SmallVector ret; for (int i = 0; i < numElems / vecElems; i++) { auto vecAddr = getSmemVecAddr( - registerTy, sharedTy, elemLlvmTy, loc, rewriter, *regToSharedLayout, + registerTy, sharedTy, elemLlvmTy, loc, rewriter, regToSharedLayout, i32_val(i * vecElems), laneId, warpId, smemObj); perVectorCallback(vecTy, vecAddr); @@ -450,8 +444,6 @@ SmallVector> emitOffsetForLayout(Attribute layout, unsigned rank = shape.size(); auto ll = triton::gpu::toLinearLayout(shape, layout); - if (!ll.has_value()) - llvm::report_fatal_error("Unsupported layout"); StringAttr kRegister = str_attr("register"); StringAttr kLane = str_attr("lane"); @@ -459,9 +451,8 @@ SmallVector> emitOffsetForLayout(Attribute layout, StringAttr kBlock = str_attr("block"); SmallVector> offsets; - for (int i = 0; i < ll->getInDimSize(str_attr("register")); i++) { - auto idxs = - ll->apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + for (int i = 0; i < ll.getInDimSize(str_attr("register")); i++) { + auto idxs = ll.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); assert(idxs.size() == rank); for (unsigned k = 0; k < rank; ++k) { assert(idxs[k].first == str_attr("dim" + std::to_string(k))); @@ -632,7 +623,7 @@ std::tuple, Value> delinearize(RewriterBase &rewriter, Location loc, triton::gpu::DistributedEncodingTrait layout, ArrayRef shape, StringAttr dimName, Value linear) { - auto ll = *triton::gpu::toLinearLayout(shape, layout); + auto ll = triton::gpu::toLinearLayout(shape, layout); auto linearLayout = triton::gpu::LinearEncodingAttr::get(rewriter.getContext(), ll); assert(ll.hasInDim(dimName)); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 81561ead2d0a..b9606df7ea85 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -92,7 +92,7 @@ unsigned getWarpSize(Attribute layout) { SmallVector getThreadsPerWarpWithUniqueData(Attribute layout, ArrayRef tensorShape) { - auto linearLayout = *toLinearLayout(tensorShape, layout); + auto linearLayout = toLinearLayout(tensorShape, layout); auto llAttr = LinearEncodingAttr::get(layout.getContext(), linearLayout); return llAttr.getThreadsPerWarp(); } @@ -109,7 +109,7 @@ SmallVector getWarpsPerCTA(Attribute layout) { SmallVector getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape) { - auto linearLayout = *toLinearLayout(tensorShape, layout); + auto linearLayout = toLinearLayout(tensorShape, layout); auto llAttr = LinearEncodingAttr::get(layout.getContext(), linearLayout); return llAttr.getWarpsPerCTA(); } @@ -151,7 +151,7 @@ SmallVector getUniqueContigPerThread(Attribute layout, // with shape [128, 128] and size=[4, 1], that is tiled in the second // dimension, then the default path will return [4, 1], but this path will // return [4, 128]! 
- auto linearLayout = *toLinearLayout(shape, layout); + auto linearLayout = toLinearLayout(shape, layout); auto llAttr = LinearEncodingAttr::get(layout.getContext(), linearLayout); return llAttr.getContigPerThread(); } @@ -1559,8 +1559,7 @@ SmallVector LinearEncodingAttr::getOrder() const { return orderPerDim(StringAttr::get(getContext(), "register"), order); } -std::optional -LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { +LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { auto ll = getLinearLayout(); auto canonicalDims = llvm::to_vector(ll.getOutDimNames()); llvm::SmallDenseMap namedShape; @@ -1583,7 +1582,7 @@ LinearEncodingAttr::getElemsPerThread(ArrayRef shape, Type) const { // We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep // the invariant that the shape of the LL is that of the tensor // We choose the former for BC - auto scaledLayout = get(getContext(), *toLinearLayout(shape)); + auto scaledLayout = get(getContext(), toLinearLayout(shape)); auto kRegister = StringAttr::get(getContext(), "register"); return scaledLayout.basesPerDim(kRegister, /*skipBroadcast=*/false); } @@ -2932,11 +2931,7 @@ struct TritonGPUInferLayoutInterface // Once LinearLayouts are more widely used, we can remove // inferReshapeOpLegacyEncoding and simply use LLs. auto *ctx = getContext(); - auto src = triton::gpu::toLinearLayout(srcShape, srcEnc); - if (!src) { - return emitOptionalError(loc, - "src encoding does not support linear layout"); - } + auto src = toLinearLayout(srcShape, srcEnc); if (product(srcShape) != product(dstShape)) { return emitOptionalError(loc, "numel of dst shape does not match " @@ -2949,12 +2944,12 @@ struct TritonGPUInferLayoutInterface llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) { newOutDims.emplace_back(dim, size); } - auto srcOutDims = llvm::to_vector(src->getOutDimNames()); + auto srcOutDims = to_vector(src.getOutDimNames()); // reshapeOp assumes minor-to-major, so we need to transpose the out dims // before the reshape std::reverse(srcOutDims.begin(), srcOutDims.end()); std::reverse(newOutDims.begin(), newOutDims.end()); - auto dst = src->transposeOuts(srcOutDims) + auto dst = src.transposeOuts(srcOutDims) .reshapeOuts(newOutDims) .transposeOuts(standardOutDimNames(ctx, newRank)); dstEnc = LinearEncodingAttr::get(ctx, dst); @@ -3142,10 +3137,7 @@ std::string getSharedLayoutStr(RankedTensorType tensorType, if (!layout) return ""; - std::optional ll = - triton::gpu::toLinearLayout(tensorType.getShape(), layout); - if (!ll.has_value()) - llvm::report_fatal_error("Failed to convert layout to linear layout"); + LinearLayout ll = triton::gpu::toLinearLayout(tensorType.getShape(), layout); StringAttr kOffset = StringAttr::get(tensorType.getContext(), "offset"); StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block"); @@ -3171,7 +3163,7 @@ std::string getSharedLayoutStr(RankedTensorType tensorType, {kOffset, offset}, }; - SmallVector> outputs = ll->apply(inputs); + SmallVector> outputs = ll.apply(inputs); std::string sharedInfo = "("; std::string &value = elementMapping[idx]; @@ -3263,17 +3255,14 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType, StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp"); StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block"); - std::optional ll = - triton::gpu::toLinearLayout(tensorType.getShape(), layout); - if (!ll.has_value()) - llvm::report_fatal_error("Failed to convert layout to linear layout"); + LinearLayout ll = 
triton::gpu::toLinearLayout(tensorType.getShape(), layout); int64_t tensorSize = product(tensorType.getShape()); std::vector elementMapping(tensorSize); std::vector threadMapping; - unsigned threadsPerWarp = ll->getInDimSize(kLane); - unsigned numWarpsPerCTA = ll->getInDimSize(kWarp); - unsigned numBlocks = ll->getInDimSize(kBlock); - int numElementsPerThreads = ll->getInDimSize(kRegister); + unsigned threadsPerWarp = ll.getInDimSize(kLane); + unsigned numWarpsPerCTA = ll.getInDimSize(kWarp); + unsigned numBlocks = ll.getInDimSize(kBlock); + int numElementsPerThreads = ll.getInDimSize(kRegister); for (int blockId = 0; blockId < numBlocks; ++blockId) { for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) { for (int tid = 0; tid < threadsPerWarp; ++tid) { @@ -3284,7 +3273,7 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType, {kLane, tid}, {kRegister, idx}}; SmallVector> outputs = - ll->apply(inputs); + ll.apply(inputs); int32_t linearizedIdx = 0; int stride = 1; for (int i = outputs.size() - 1; i >= 0; i--) { diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index aeaf80780eb0..2d2c7cd61e0f 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -279,7 +279,7 @@ LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx, } // anonymous namespace -std::optional +LinearLayout AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { int rank = shape.size(); assert(rank == getWarpsPerCTA().size()); @@ -367,9 +367,8 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); } -std::optional -mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, - ArrayRef shape) { +LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, + ArrayRef shape) { // Current linear layout conversion for dot operand is only necessary to // enable LDS bypass for operand B in the MFMA dot path. 
To achieve @@ -479,7 +478,7 @@ mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape); } -std::optional +LinearLayout AMDWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { int rank = shape.size(); assert(rank == getWarpsPerCTA().size()); @@ -570,9 +569,8 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); } -std::optional -wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout, - ArrayRef shape) { +LinearLayout wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout, + ArrayRef shape) { auto wmmaLayout = llvm::cast(dotWmmaLayout.getParent()); auto rank = shape.size(); bool hasBatchDim = rank == 3; @@ -639,7 +637,7 @@ wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout, return combineCtaCgaWithShape(ctaLayout, wmmaLayout.getCTALayout(), shape); } -std::optional +LinearLayout BlockedEncodingAttr::toLinearLayout(ArrayRef shape) const { assert(shape.size() == getOrder().size()); MLIRContext *ctx = getContext(); @@ -653,9 +651,8 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef shape) const { return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); } -std::optional -fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout, - ArrayRef shape) { +LinearLayout fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout, + ArrayRef shape) { int rank = shape.size(); auto blocked = cast(operandLayout.getParent()); MLIRContext *ctx = operandLayout.getContext(); @@ -736,7 +733,7 @@ LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef tileShape, return ctaLayout; } -std::optional +LinearLayout NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef shape) const { auto ctx = getContext(); int rank = shape.size(); @@ -792,7 +789,7 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef shape, return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape); } -std::optional +LinearLayout DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { auto parent = getParent(); if (auto blockedLayout = mlir::dyn_cast(parent)) { @@ -801,24 +798,19 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const { return mfmaDotToLinearLayout(*this, shape); } else if (auto wmmaLayout = mlir::dyn_cast(parent)) { return wmmaDotOperandToLinearLayout(*this, shape); - } else if (auto mma = mlir::dyn_cast(parent)) { + } else { + auto mma = mlir::cast(parent); return nvidiaDotToLinearLayout(shape, *this); } - return std::nullopt; } -std::optional -SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { +LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { MLIRContext *ctx = getContext(); // First compute the linear layout for this layout's parent. SmallVector parentShape(shape); parentShape.insert(parentShape.begin() + getDim(), 1); - std::optional parentLL = - triton::gpu::toLinearLayout(parentShape, getParent()); - if (!parentLL.has_value()) { - return std::nullopt; - } + LinearLayout parentLL = triton::gpu::toLinearLayout(parentShape, getParent()); // Remove dimension getDim() from the parent layout. // @@ -829,19 +821,19 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { // 3. Fix up duplicate registers introduced by slicing. 
auto outDimNames = standardOutDimNames(ctx, shape.size() + 1); LinearLayout transform = LinearLayout::empty(); - for (auto [idx, outDim] : llvm::enumerate(parentLL->getOutDimNames())) { + for (auto [idx, outDim] : llvm::enumerate(parentLL.getOutDimNames())) { if (idx == getDim()) { // Because we're multiplying by all zeros, we could replace outDimNames[0] // with any other valid out-dim; the layout will be the same. - transform *= LinearLayout::zeros1D(parentLL->getOutDimSize(outDim), - outDim, outDimNames[0]); + transform *= LinearLayout::zeros1D(parentLL.getOutDimSize(outDim), outDim, + outDimNames[0]); } else { transform *= - LinearLayout::identity1D(parentLL->getOutDimSize(outDim), outDim, + LinearLayout::identity1D(parentLL.getOutDimSize(outDim), outDim, outDimNames[idx - (idx < getDim() ? 0 : 1)]); } } - LinearLayout sliceLL = parentLL->compose(transform); + LinearLayout sliceLL = parentLL.compose(transform); // Step 3: Along the "register" dim, remove any all-zero bases. auto bases = sliceLL.getBases(); @@ -874,20 +866,22 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { return ret; } -std::optional +LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef shape, Attribute layout, std::optional elemBitWidth) { CacheKey key{std::vector(shape.begin(), shape.end()), layout, elemBitWidth}; - auto result = llCache.get(key); - if (result.has_value()) { - return result; + if (auto result = llCache.get(key)) { + return *result; } // Layouts are distributed or shared in triton core + // To add a new layout add an else-if clause + LinearLayout result = LinearLayout::empty(); if (auto distributed = dyn_cast(layout)) { result = distributed.toLinearLayout(shape); - } else if (auto shared = dyn_cast(layout)) { + } else { + auto shared = dyn_cast(layout); if (shared.getHasLeadingOffset()) { assert(elemBitWidth.has_value()); result = sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); @@ -896,13 +890,11 @@ TritonGPUDialect::toLinearLayout(ArrayRef shape, Attribute layout, } } - if (result.has_value()) { - llCache.set(std::move(key), *result); - } + llCache.set(std::move(key), result); return result; } -std::optional +LinearLayout toLinearLayout(ArrayRef shape, Attribute layout, std::optional elemBitWidth /*= std::nullopt*/) { auto *ctx = layout.getContext(); diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 989b5d38dc88..814f948b2953 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -47,7 +47,7 @@ struct CanonicalizeConvertFromReshape auto dstType = convert.getType(); auto srcLL = toLinearLayout(srcType.getShape(), srcType.getEncoding()); auto dstLL = toLinearLayout(dstType.getShape(), dstType.getEncoding()); - if (srcLL && dstLL && *srcLL == *dstLL) { + if (srcLL == dstLL) { rewriter.replaceOpWithNewOp( op, op.getType(), convert.getSrc(), op.getAllowReorder()); return mlir::success(); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 084224db124d..03a5db3889d6 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -472,7 +472,7 @@ class DecomposeScaledBlocked // Extract warp layout from dotAEncoding // In the future we'll have some nice division utils, but until then... 
- auto dotLL = *newAEncoding.toLinearLayout(a.getType().getShape()); + auto dotLL = newAEncoding.toLinearLayout(a.getType().getShape()); LinearLayout::BasesT scaleBases = dotLL.getBases(); auto kWarp = StringAttr::get(ctx, "warp"); auto &warpBases = scaleBases[kWarp]; diff --git a/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp b/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp index 2d634fc6fa7b..b08407d5e233 100644 --- a/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp +++ b/lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp @@ -56,7 +56,7 @@ struct ClipAsyncCopySizePerThread auto regToSharedLayout = getRegToSharedLayout(rewriter.getContext(), srcTy.getShape(), blockEnc, sharedEnc, elemBitWidth); - auto copyContigSize = regToSharedLayout->getNumConsecutiveInOut(); + auto copyContigSize = regToSharedLayout.getNumConsecutiveInOut(); // obtain block sizePerThread along contig dim auto sizePerThread = blockEnc.getSizePerThread(); diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 27cb71638f5f..59ebc445ecd1 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -1189,19 +1189,15 @@ void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } -std::optional -getRegToSharedLayout(MLIRContext *ctx, ArrayRef shape, - Attribute srcEnc, Attribute dstEnc, int elemBitWidth) { +LinearLayout getRegToSharedLayout(MLIRContext *ctx, ArrayRef shape, + Attribute srcEnc, Attribute dstEnc, + int elemBitWidth) { StringAttr kBlock = StringAttr::get(ctx, ("block")); int rank = shape.size(); - std::optional regLayout = - triton::gpu::toLinearLayout(shape, srcEnc); - std::optional sharedLayout = + LinearLayout regLayout = triton::gpu::toLinearLayout(shape, srcEnc); + LinearLayout sharedLayout = triton::gpu::toLinearLayout(shape, dstEnc, elemBitWidth); - if (!regLayout.has_value() || !sharedLayout.has_value()) { - return std::nullopt; - } auto sharedOrder = triton::gpu::getOrder(dstEnc); // sharedLayout's in-dims are currently (offset, block). Reshape to @@ -1217,12 +1213,12 @@ getRegToSharedLayout(MLIRContext *ctx, ArrayRef shape, multiDimSharedSize.push_back( {StringAttr::get(ctx, ("offset" + std::to_string(dim))), size}); } - multiDimSharedSize.push_back({kBlock, sharedLayout->getInDimSize(kBlock)}); - sharedLayout = sharedLayout->reshapeIns(multiDimSharedSize); + multiDimSharedSize.push_back({kBlock, sharedLayout.getInDimSize(kBlock)}); + sharedLayout = sharedLayout.reshapeIns(multiDimSharedSize); // regToSharedLayout maps from (register, lane, warp, block) to (offsetX1, // ..., offsetXN, block), where the offsetX's are in minor-to-major order. 
- return regLayout->invertAndCompose(*sharedLayout); + return regLayout.invertAndCompose(sharedLayout); } } // namespace mlir diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index a52a3d205afa..d4abcfbebe2f 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -49,8 +49,7 @@ llvm::MapVector getFreeVariableMasks(Type type) { } auto ll = ttg::toLinearLayout(tensorTy.getShape(), tensorTy.getEncoding()); - assert(ll && "failed to convert to linear layout"); - return ll->getFreeVariableMasks(); + return ll.getFreeVariableMasks(); } Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) { diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index a3cc65605fd2..d2307ddca826 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -146,9 +146,9 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, << "Inverse encoding inference (" << triton::join(dstTy.getShape(), "x") << " " << stringifyLLVMType(inferredEnc) << " -> " << triton::join(srcTy.getShape(), "x") - << " gave the wrong result. Expected " << srcLinear->toString() + << " gave the wrong result. Expected " << srcLinear.toString() << " but " - << "got " << inferredSrcLinear->toString() << ".\n"; + << "got " << inferredSrcLinear.toString() << ".\n"; } // The funtional characterisation of resize is that, if we have a srcLayout @@ -156,7 +156,7 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, // when considered as C-contiguous. auto makeFlattenedCContig = [](ArrayRef shape, Attribute layout) { auto ctx = layout.getContext(); - auto linear = *toLinearLayout(shape, layout); + auto linear = toLinearLayout(shape, layout); auto dims = standardOutDimNames(ctx, shape.size()); std::reverse(dims.begin(), dims.end()); return linear.transposeOuts(dims).reshapeOuts( @@ -515,7 +515,7 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) { } // Create LinearEncodingAttr from the LinearLayout - auto linearLayout = *distributedEncoding.toLinearLayout(shape); + auto linearLayout = distributedEncoding.toLinearLayout(shape); auto linearEncoding = triton::gpu::LinearEncodingAttr::get(&ctx, linearLayout);
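
Every call-site hunk above follows the same mechanical migration: drop the
has_value() check (or the report_fatal_error fallback), drop the dereference,
and take the result by value. For reference, a condensed before/after sketch of
that pattern is given below. The helper getNumLanesUsed is hypothetical and not
part of this patch; the calls it makes (toLinearLayout, getInDimSize,
StringAttr::get, llvm::report_fatal_error) are the ones exercised in the hunks
above, and the usual Triton GPU dialect includes are assumed.

    // Before this patch: toLinearLayout returned std::optional<LinearLayout>,
    // so every caller had to guard against layouts without an LL conversion.
    static unsigned getNumLanesUsed(RankedTensorType ty) {
      std::optional<LinearLayout> ll =
          triton::gpu::toLinearLayout(ty.getShape(), ty.getEncoding());
      if (!ll.has_value())
        llvm::report_fatal_error("Failed to convert layout to linear layout");
      StringAttr kLane = StringAttr::get(ty.getContext(), "lane");
      return ll->getInDimSize(kLane);
    }

    // After this patch: the conversion is mandatory, so the optional wrapper,
    // the bail-outs, and the fatal-error fallbacks all disappear and the
    // LinearLayout is used directly.
    static unsigned getNumLanesUsed(RankedTensorType ty) {
      LinearLayout ll =
          triton::gpu::toLinearLayout(ty.getShape(), ty.getEncoding());
      StringAttr kLane = StringAttr::get(ty.getContext(), "lane");
      return ll.getInDimSize(kLane);
    }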