[LAYOUTS] Add support for a generic transpose via LLs (#5403)
As per title.
lezcano authored Jan 19, 2025
1 parent 4571fd9 commit 1d01b72
Showing 9 changed files with 135 additions and 56 deletions.
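The key piece of the change is a new fallback in `TritonGPUInferLayoutInterface::inferTransOpEncoding` (see `lib/Dialect/TritonGPU/IR/Dialect.cpp` below): when the operand encoding is neither shared nor blocked, the encoding is converted to a linear layout via `toLinearLayout(shape, encoding)`, every basis vector is permuted by the transpose `order`, and the result is wrapped in a `LinearEncodingAttr`. A minimal, self-contained C++ sketch of that basis-permutation step — the `Bases` alias and `transposeBases` helper are illustrative stand-ins, not the actual Triton types:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Illustrative stand-in for a linear layout: each hardware dimension
// ("register", "lane", "warp", "block") maps to a list of basis vectors,
// one int32 coordinate per logical tensor dimension.
using Bases = std::map<std::string, std::vector<std::vector<int32_t>>>;

// Transpose a layout by permuting the coordinates of every basis vector
// with `order`, mirroring how the new fallback rewrites ll.getBases().
Bases transposeBases(const Bases &bases, const std::vector<int32_t> &order) {
  Bases result = bases;
  for (auto &entry : result) {
    for (auto &b : entry.second) {
      std::vector<int32_t> newB;
      newB.reserve(order.size());
      for (int32_t i : order)
        newB.push_back(b[i]);
      b = std::move(newB);
    }
  }
  return result;
}

int main() {
  // A toy 2-D layout; order = [1, 0] swaps the two logical dimensions.
  Bases ll = {{"register", {{0, 1}, {8, 0}}}, {"lane", {{0, 2}, {1, 0}}}};
  Bases transposed = transposeBases(ll, {1, 0});
  for (const auto &entry : transposed) {
    std::cout << entry.first << ":";
    for (const auto &b : entry.second)
      std::cout << " [" << b[0] << "," << b[1] << "]";
    std::cout << "\n";
  }
  // Prints: lane: [2,0] [0,1]
  //         register: [1,0] [0,8]
}
```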
2 changes: 1 addition & 1 deletion README.md
@@ -171,7 +171,7 @@ For detailed instructions on how to debug Triton's frontend, please refer to this
DEBUG_TYPE` throughout LLVM and Triton) in order to allow the debug output to
be less noisy. `TRITON_LLVM_DEBUG_ONLY` allows for one or more comma
separated values to be specified (eg
`TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions` or
`TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions"` or
`TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions,regalloc"`).
- `TRITON_ENABLE_ASAN=1` invokes the LLVM address sanitizer for
memory leak and out of bounds access detection. Currently only supported on the AMD
9 changes: 5 additions & 4 deletions include/triton/Dialect/Triton/IR/Dialect.h
@@ -34,7 +34,8 @@ class DialectInferLayoutInterface
DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}

virtual LogicalResult
inferTransOpEncoding(Attribute operandEncoding, ArrayRef<int32_t> order,
inferTransOpEncoding(Attribute operandEncoding, ArrayRef<int64_t> shape,
ArrayRef<int32_t> order,
Attribute &resultEncoding) const = 0;

virtual LogicalResult
@@ -65,9 +66,9 @@ class DialectInferLayoutInterface

// Check if two layouts are structurally the same, even if their names are
// different
virtual LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
Attribute expected, Attribute got,
Location loc) const = 0;
virtual LogicalResult
verifyLayoutsAreEqual(ArrayRef<int64_t> shape, Attribute expected,
Attribute got, std::optional<Location> loc) const = 0;

virtual LogicalResult
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
@@ -539,7 +539,7 @@ def TT_SplitOp : TT_Op<"split", [

def TT_TransOp : TT_Op<"trans", [Pure,
TransposeOpInterface,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
InferTypeOpAdaptorWithIsCompatible,
SameOperandsAndResultElementType]> {

let summary = "rearrange the dimensions of a tensor";
21 changes: 4 additions & 17 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -297,23 +297,10 @@ struct TransOpConversion : public ConvertOpToLLVMPattern<TransOp> {
LogicalResult
matchAndRewrite(TransOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = cast<RankedTensorType>(op.getType());
if (auto enc =
mlir::dyn_cast<BlockedEncodingAttr>(resultTy.getEncoding())) {
// If the dst encoding is blocked, then TransOp::inferReturnTypes
// ensures that:
// - the src encoding is also blocked, and
// - the translation from src to dst is just a "renaming" of the
// registers, i.e. each thread has exactly the same values.
// Thus the transpose op simply returns the same values it got.
auto vals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
Value ret = packLLElements(loc, this->getTypeConverter(), vals, rewriter,
resultTy);
rewriter.replaceOp(op, ret);
return success();
}
return emitOptionalError(loc, "unsupported encoding for MemDescTransOp");
// By construction, TransOp::inferReturnTypes ensures that the src encoding
// is the same as the dst encoding so that this op is a no-op.
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}
};

35 changes: 28 additions & 7 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -209,13 +209,14 @@ OpFoldResult TransOp::fold(FoldAdaptor adaptor) {
}

LogicalResult TransOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
MLIRContext *context, std::optional<Location> location,
TransOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {

// type is the same as the input
auto argTy = cast<RankedTensorType>(operands[0].getType());
auto order = properties.as<Properties *>()->order.asArrayRef();
SmallVector<int64_t> retShape = applyPermutation(argTy.getShape(), order);
auto argTy = cast<RankedTensorType>(adaptor.getSrc().getType());
auto shape = argTy.getShape();
auto order = adaptor.getOrder();
SmallVector<int64_t> retShape = applyPermutation(shape, order);

auto retEltTy = argTy.getElementType();
Attribute argEncoding = argTy.getEncoding();
@@ -224,7 +225,7 @@ LogicalResult TransOp::inferReturnTypes(
Dialect &dialect = argEncoding.getDialect();
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
if (inferLayoutInterface
->inferTransOpEncoding(argEncoding, order, retEncoding)
->inferTransOpEncoding(argEncoding, shape, order, retEncoding)
.failed()) {
return failure();
}
@@ -234,6 +235,26 @@ LogicalResult TransOp::inferReturnTypes(
return success();
}

bool TransOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
assert(lhs.size() == rhs.size());
assert(lhs.size() == 1);
auto lhsType = cast<RankedTensorType>(lhs[0]);
auto rhsType = cast<RankedTensorType>(rhs[0]);

if (lhsType.getShape() != rhsType.getShape())
return false;

auto lhsEnc = lhsType.getEncoding();
auto rhsEnc = rhsType.getEncoding();
// If there's no encoding or the encodings are the same
if (lhsEnc == rhsEnc)
return true;

return cast<DialectInferLayoutInterface>(&lhsEnc.getDialect())
->verifyLayoutsAreEqual(lhsType.getShape(), lhsEnc, rhsEnc, {})
.succeeded();
}

//-- DotOp --
LogicalResult
DotOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
42 changes: 28 additions & 14 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2525,12 +2525,14 @@ struct TritonGPUInferLayoutInterface
// = inverse(trans.order) * inputEnc.order.
//
LogicalResult inferTransOpEncoding(Attribute operandEncoding,
ArrayRef<int64_t> shape,
ArrayRef<int32_t> order, // trans order
Attribute &resultEncoding) const override {
// Note: inferFooOpEncoding should not crash if given invalid inputs, which
// happens when someone creates invalid IR. If we return failure() on
// error, then MLIR will generate a helpful error message.

auto *ctx = getDialect()->getContext();
auto invOrder = inversePermutation(order);
SmallVector<unsigned> invOrderUnsigned(invOrder.begin(), invOrder.end());

@@ -2544,8 +2546,7 @@
}

return CTALayoutAttr::get(
getDialect()->getContext(),
applyPermutation(layout.getCTAsPerCGA(), order),
ctx, applyPermutation(layout.getCTAsPerCGA(), order),
applyPermutation(layout.getCTASplitNum(), order),
applyPermutation(invOrderUnsigned, layout.getCTAOrder()));
};
@@ -2559,9 +2560,9 @@
return failure();
}
resultEncoding = SharedEncodingAttr::get(
getDialect()->getContext(), enc.getVec(), enc.getPerPhase(),
enc.getMaxPhase(), applyPermutation(invOrderUnsigned, enc.getOrder()),
*ctaLayout, enc.getHasLeadingOffset());
ctx, enc.getVec(), enc.getPerPhase(), enc.getMaxPhase(),
applyPermutation(invOrderUnsigned, enc.getOrder()), *ctaLayout,
enc.getHasLeadingOffset());
return success();
}

@@ -2577,15 +2578,27 @@
return failure();
}
resultEncoding = BlockedEncodingAttr::get(
getDialect()->getContext(),
applyPermutation(enc.getSizePerThread(), order),
ctx, applyPermutation(enc.getSizePerThread(), order),
applyPermutation(enc.getThreadsPerWarp(), order),
applyPermutation(enc.getWarpsPerCTA(), order),
applyPermutation(invOrderUnsigned, enc.getOrder()), *ctaLayout);
return success();
}

return failure(); // unhandled encoding
auto ll = toLinearLayout(shape, operandEncoding);
auto namedBases = ll.getBases();
for (auto &bases : llvm::make_second_range(namedBases)) {
for (auto &b : bases) {
std::vector<int32_t> newB;
for (auto i : order) {
newB.push_back(b[i]);
}
b = std::move(newB);
}
}
auto retLl = LinearLayout(std::move(namedBases),
llvm::to_vector(ll.getOutDimNames()));
resultEncoding = LinearEncodingAttr::get(ctx, std::move(retLl));
return success();
}

LogicalResult
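A worked check of the `resultEnc.order = inverse(trans.order) * inputEnc.order` rule from the comment above, matching the `#blocked` (order `[1, 0]`) to `#blocked_trans` (order `[0, 1]`) pair in the new `canonicalize.mlir` test further down. The local `applyPermutation` / `inversePermutation` helpers are stand-ins assumed to behave like the Triton utilities of the same name (`result[i] = vec[perm[i]]`):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Assumed semantics of Triton's applyPermutation: result[i] = vec[perm[i]].
std::vector<int32_t> applyPermutation(const std::vector<int32_t> &vec,
                                      const std::vector<int32_t> &perm) {
  std::vector<int32_t> result;
  for (int32_t i : perm)
    result.push_back(vec[i]);
  return result;
}

std::vector<int32_t> inversePermutation(const std::vector<int32_t> &perm) {
  std::vector<int32_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<int32_t>(i);
  return inv;
}

int main() {
  // 2-D case from the test: trans order [1, 0] maps encoding order [1, 0]
  // to [0, 1], i.e. #blocked -> #blocked_trans.
  assert((applyPermutation(inversePermutation({1, 0}), {1, 0}) ==
          std::vector<int32_t>{0, 1}));

  // A 3-D example of the same composition rule.
  std::vector<int32_t> transOrder = {2, 0, 1};
  std::vector<int32_t> inputOrder = {1, 0, 2}; // dim 1 is fastest-varying
  auto inv = inversePermutation(transOrder);   // {1, 2, 0}
  auto resultOrder = applyPermutation(inv, inputOrder);
  assert((resultOrder == std::vector<int32_t>{2, 1, 0}));
}
```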
@@ -2901,18 +2914,19 @@
return success();
}

LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
Attribute expected, Attribute got,
Location loc) const override {
LogicalResult
verifyLayoutsAreEqual(ArrayRef<int64_t> shape, Attribute expected,
Attribute got,
std::optional<Location> loc) const override {
if (expected == got) {
return success();
}
// Check whether the encodings are structurally the same.
auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
auto gotLL = triton::gpu::toLinearLayout(shape, got);
if (expectedLL != gotLL) {
return emitError(loc, "Expected result encoding ")
<< expected << " but was " << got;
return emitOptionalError(loc, "Expected result encoding ", expected,
" but was ", got);
}
return success();
}
52 changes: 43 additions & 9 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -25,6 +25,16 @@ bool hasDotOperandEncoding(Value value) {
return hasEncoding<triton::gpu::DotOperandEncodingAttr>(value);
}

bool isConvertTrivial(ConvertLayoutOp op) {
auto srcType = op.getSrc().getType();
auto dstType = op.getType();
auto srcEncoding = srcType.getEncoding();
auto dstEncoding = dstType.getEncoding();
return cast<DialectInferLayoutInterface>(&srcEncoding.getDialect())
->verifyLayoutsAreEqual(srcType.getShape(), srcEncoding, dstEncoding, {})
.succeeded();
}

} // namespace

//===----------------------------------------------------------------------===//
@@ -43,26 +53,48 @@ struct CanonicalizeConvertFromReshape
if (!convert)
return failure();
// If the layouts are structurally the same, the convert is trivial
auto srcType = convert.getSrc().getType();
auto dstType = convert.getType();
auto srcLL = toLinearLayout(srcType.getShape(), srcType.getEncoding());
auto dstLL = toLinearLayout(dstType.getShape(), dstType.getEncoding());
if (srcLL == dstLL) {
if (isConvertTrivial(convert)) {
rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
op, op.getType(), convert.getSrc(), op.getAllowReorder());
return mlir::success();
op, op.getType(), convert.getSrc(), op.getAllowReorder(),
op.getEfficientLayout());
return success();
}

if (isExpensiveView(convert.getSrc().getType(), op.getType()))
return failure();
if (!op.getAllowReorder() || op.getEfficientLayout())
return failure();

rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
op, op.getType(), convert.getSrc(), op.getAllowReorder());
op, op.getType(), convert.getSrc(), op.getAllowReorder(),
op.getEfficientLayout());
return mlir::success();
}
};

// TODO We should do this generically for op(cvt) -> op
// We have similar patterns for reshape and split...
// See https://github.com/triton-lang/triton/pull/5403#discussion_r1920091671

// trans(cvt) -> trans
struct CanonicalizeConvertFromTranspose
: public mlir::OpRewritePattern<triton::TransOp> {
using OpRewritePattern::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(triton::TransOp op,
PatternRewriter &rewriter) const override {
// If the layouts are structurally the same, the convert is trivial
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert || !isConvertTrivial(convert))
return failure();

rewriter.replaceOpWithNewOp<triton::TransOp>(
op, op.getType(), convert.getSrc(), op.getOrder());
return success();
}
};

// histogram(cvt) -> histogram
struct CanonicalizeConvertFromHistogram
: public mlir::OpRewritePattern<triton::HistogramOp> {
@@ -289,6 +321,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<CanonicalizeConvertFromConvert>(context);
patterns.add<CanonicalizeConvertFromReshape>(context);
patterns.add<CanonicalizeConvertFromTranspose>(context);
patterns.add<CanonicalizeConvertFromGatherSource>(context);
patterns.add<CanonicalizeConvertFromHistogram>(context);
patterns.add<CanonicalizeConvertFromAlloc>(context);
@@ -435,6 +468,7 @@ LogicalResult MemDescTransOp::inferReturnTypes(
SmallVectorImpl<Type> &inferredReturnTypes) {
// type is the same as the input
auto argTy = cast<MemDescType>(operands[0].getType());
auto argShape = argTy.getShape();
auto order = properties.as<Properties *>()->order.asArrayRef();
SmallVector<int64_t> retShape = applyPermutation(argTy.getShape(), order);

Expand All @@ -445,7 +479,7 @@ LogicalResult MemDescTransOp::inferReturnTypes(
Dialect &dialect = argEncoding.getDialect();
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
if (inferLayoutInterface
->inferTransOpEncoding(argEncoding, order, retEncoding)
->inferTransOpEncoding(argEncoding, argShape, order, retEncoding)
.failed()) {
return failure();
}
10 changes: 7 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -368,21 +368,24 @@ static Attribute inferSrcEncoding(GatherOp op, Attribute dstEnc) {
}

static Attribute inferTransOpDstEncoding(Attribute srcEnc,
ArrayRef<int64_t> shape,
ArrayRef<int32_t> order) {
// Simply forward to the existing inferTransOpEncoding function.
Attribute retEncoding;
if (succeeded(
srcEnc.getDialect()
.getRegisteredInterface<triton::DialectInferLayoutInterface>()
->inferTransOpEncoding(srcEnc, order, retEncoding))) {
->inferTransOpEncoding(srcEnc, shape, order, retEncoding))) {
return retEncoding;
}
return {};
}

static Attribute inferDstEncoding(triton::TransposeOpInterface op,
Attribute encoding) {
return inferTransOpDstEncoding(encoding, op.getOrder());
return inferTransOpDstEncoding(
encoding, cast<RankedTensorType>(op.getSrc().getType()).getShape(),
op.getOrder());
}

static Attribute inferSrcEncoding(triton::TransposeOpInterface op,
@@ -393,7 +396,8 @@ static Attribute inferSrcEncoding(triton::TransposeOpInterface op,
// transpose(transpose(x, order), inverse(order)) == x,
// we can see this is equivalent to
// transpose(dstEnc, inverse(order)) -> srcEnc.
return inferTransOpDstEncoding(encoding,
auto shape = cast<RankedTensorType>(op->getResult(0).getType()).getShape();
return inferTransOpDstEncoding(encoding, shape,
triton::inversePermutation(op.getOrder()));
}
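Since `inferTransOpEncoding` now needs a shape, this source-inference path feeds it the result tensor's shape, i.e. the operand shape already permuted by `order`. A small sketch of the round-trip identity the comment relies on — permuting by `order` and then by `inverse(order)` gets back the original — with an illustrative `permute` lambda:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int32_t> order = {2, 0, 1};
  std::vector<int32_t> inv(order.size());
  for (size_t i = 0; i < order.size(); ++i)
    inv[order[i]] = static_cast<int32_t>(i); // inverse(order) = {1, 2, 0}

  auto permute = [](const std::vector<int64_t> &v,
                    const std::vector<int32_t> &p) {
    std::vector<int64_t> r;
    for (int32_t i : p)
      r.push_back(v[i]);
    return r;
  };

  std::vector<int64_t> shape = {4, 8, 16};
  auto transposed = permute(shape, order); // {16, 4, 8}: the result shape
  auto roundTrip = permute(transposed, inv);
  assert(roundTrip == shape); // transpose(transpose(x, order), inverse(order)) == x
}
```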

18 changes: 18 additions & 0 deletions test/TritonGPU/canonicalize.mlir
@@ -160,3 +160,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
tt.return %1 : tensor<16x16xf16, #blocked>
}
}

// -----

#linear = #ttg.linear<{register = [[0, 1], [8, 0], [16, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[0, 8], [0, 16]], block = []}>
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [1, 0]}>
#blocked_trans = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [0, 1]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {

// CHECK-LABEL: @infer_trans
tt.func @infer_trans(%arg0: tensor<32x32xf32, #linear>) -> tensor<32x32xf32, #blocked_trans> {
// CHECK-NOT: ttg.convert_layout
%0 = ttg.convert_layout %arg0 : tensor<32x32xf32, #linear> -> tensor<32x32xf32, #blocked>
%1 = tt.trans %0 {order = array<i32: 1, 0>} : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked_trans>
tt.return %1 : tensor<32x32xf32, #blocked_trans>
}

}
