Change kMinor to kContig
Ognjen Plavsic committed Nov 18, 2024
1 parent 2d81cba commit c7dbd6d
Showing 7 changed files with 26 additions and 26 deletions.

include/triton/Dialect/TritonGPU/IR/Dialect.h: 4 changes (2 additions & 2 deletions)

@@ -134,11 +134,11 @@ unsigned getNumCTAs(Attribute layout);
 // len(shape) == rank.
 SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);
 
-// Return the order that represents that the dot operand is in kMinor
+// Return the order that represents that the dot operand is in kContig
 // (contiguous in the inner dimension) or it's contiguous on the outer
 // dimension.
 SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
-                                            bool kMinor);
+                                            bool kContig);
 
 bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

lib/Dialect/TritonGPU/IR/Dialect.cpp: 16 changes (8 additions & 8 deletions)

@@ -249,15 +249,15 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
 }
 
 SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
-                                            bool kMinor) {
-  // kMinor: if true, the matrix is fastest-running on k,
+                                            bool kContig) {
+  // kContig: if true, the matrix is fastest-running on k,
   //          otherwise it is on m (resp. n)
   // opIdx=0: [batch, m, k] if rank == 3 else [m, k]
   // opIdx=1: [batch, k, n] if rank == 3 else [k, n]
   // batch (if rank == 3) is always the slowest running dimension
   assert(rank == 2 || rank == 3);
   assert(opIdx == 0 || opIdx == 1);
-  auto rowMajor = bool(opIdx) != kMinor;
+  auto rowMajor = bool(opIdx) != kContig;
   return getMatrixOrder(rank, rowMajor);
 }

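For readers following the rename, here is a minimal standalone sketch of what this order computation yields. It assumes the usual Triton convention that order[0] is the fastest-varying dimension; matrixOrder and orderForDotOperand below are illustrative stand-ins using plain std::vector, not the actual SmallVector-based helpers.

#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Stand-in for getMatrixOrder: row-major puts the last dim first (fastest);
// column-major swaps the two minor dims.
std::vector<unsigned> matrixOrder(unsigned rank, bool rowMajor) {
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0); // {rank-1, ..., 1, 0}
  if (!rowMajor)
    std::swap(order[0], order[1]);
  return order;
}

// Stand-in mirroring getOrderForDotOperand from the hunk above.
std::vector<unsigned> orderForDotOperand(unsigned opIdx, unsigned rank,
                                         bool kContig) {
  bool rowMajor = bool(opIdx) != kContig;
  return matrixOrder(rank, rowMajor);
}

int main() {
  // Operand A is [m, k]: with kContig=true, k (dim 1) is fastest -> {1, 0}.
  assert((orderForDotOperand(0, 2, true) == std::vector<unsigned>{1, 0}));
  // Operand B is [k, n]: with kContig=true, k (dim 0) is fastest -> {0, 1}.
  assert((orderForDotOperand(1, 2, true) == std::vector<unsigned>{0, 1}));
  // With kContig=false the non-k dimension is fastest instead.
  assert((orderForDotOperand(0, 2, false) == std::vector<unsigned>{0, 1}));
  assert((orderForDotOperand(1, 2, false) == std::vector<unsigned>{1, 0}));
  return 0;
}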

@@ -290,7 +290,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
     auto rank = dotLayout.getWarpsPerCTA().size();
-    return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMinor*/ true);
+    return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kContig*/ true);
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
     SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());

@@ -1040,7 +1040,7 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
   return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
-                               /*kMinor*/ true);
+                               /*kContig*/ true);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
     ArrayRef<int64_t> tensorShape) const {

@@ -1665,7 +1665,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getRepOrder() const {
 SmallVector<unsigned>
 AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
   auto rank = getWarpsPerCTA().size();
-  return getOrderForDotOperand(opIdx, rank, /*kMinor*/ true);
+  return getOrderForDotOperand(opIdx, rank, /*kContig*/ true);
 }
 
 SmallVector<int64_t>

@@ -1759,7 +1759,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
 SmallVector<unsigned>
 AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
   auto rank = getWarpsPerCTA().size();
-  return getOrderForDotOperand(opIdx, rank, /*kMinor*/ true);
+  return getOrderForDotOperand(opIdx, rank, /*kContig*/ true);
 }
 
 SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {

@@ -1969,7 +1969,7 @@ NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
   auto rank = getWarpsPerCTA().size();
-  return getOrderForDotOperand(opIdx, rank, /*kMinor*/ true);
+  return getOrderForDotOperand(opIdx, rank, /*kContig*/ true);
 }
 
 SmallVector<int64_t>

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp: 18 changes (9 additions & 9 deletions)

@@ -875,16 +875,16 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
 
   MLIRContext *ctx = mma.getContext();
 
-  // The A and B operands are tiled in a kMinor fashion
-  auto kMinorOrder = dot.getRepOrder();
-  assert(kMinorOrder ==
-         getOrderForDotOperand(dot.getOpIdx(), rank, /*kMinor=*/true));
+  // The A and B operands are tiled in a kContig fashion
+  auto kContigOrder = dot.getRepOrder();
+  assert(kContigOrder ==
+         getOrderForDotOperand(dot.getOpIdx(), rank, /*kContig=*/true));
 
-  auto kMinorDims =
-      permuteDimNames(standardOutDimNames(ctx, rank), kMinorOrder);
+  auto kContigDims =
+      permuteDimNames(standardOutDimNames(ctx, rank), kContigOrder);
   // This agrees with the order of the elements, which means that we can share
   // the code below for both A and B without having to perform any swaps
-  assert(getOrder(dot) == kMinorOrder);
+  assert(getOrder(dot) == kContigOrder);
 
   std::vector<std::vector<int32_t>> registers;
   std::vector<std::vector<int32_t>> lanes;

@@ -911,7 +911,7 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
     registers.push_back({i, 0});
 
   LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}},
-                         ArrayRef(kMinorDims).take_front(2));
+                         ArrayRef(kContigDims).take_front(2));
 
   // Let warpsPerCTAMma = {2, 2}, then
   // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB

@@ -952,7 +952,7 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
     }
   }
 
-  ctaLayout *= LinearLayout({{S("warp"), warps}}, kMinorDims);
+  ctaLayout *= LinearLayout({{S("warp"), warps}}, kContigDims);
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
 }
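
To see why the hunk above can share the register/lane construction between A and B, here is an illustrative standalone sketch. The standardDims and permute helpers are hypothetical stand-ins for standardOutDimNames and permuteDimNames, under the assumption that permuteDimNames simply reindexes the dimension names by the order vector.

#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for standardOutDimNames: {"dim0", "dim1", ...}.
std::vector<std::string> standardDims(unsigned rank) {
  std::vector<std::string> names;
  for (unsigned i = 0; i < rank; ++i)
    names.push_back("dim" + std::to_string(i));
  return names;
}

// Hypothetical stand-in for permuteDimNames: reindex names by the order.
std::vector<std::string> permute(const std::vector<std::string> &names,
                                 const std::vector<unsigned> &order) {
  std::vector<std::string> out;
  for (unsigned d : order)
    out.push_back(names[d]);
  return out;
}

int main() {
  // Orders produced by getOrderForDotOperand(..., /*kContig=*/true) at rank 2.
  std::vector<unsigned> orderA = {1, 0}; // A is [m, k]; k == dim1 is fastest
  std::vector<unsigned> orderB = {0, 1}; // B is [k, n]; k == dim0 is fastest

  // The k dimension leads in both cases, which is why the register/lane bases
  // built over the first two kContigDims can be shared between A and B.
  assert((permute(standardDims(2), orderA) ==
          std::vector<std::string>{"dim1", "dim0"}));
  assert((permute(standardDims(2), orderB) ==
          std::vector<std::string>{"dim0", "dim1"}));
  return 0;
}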

(file path not shown)

@@ -76,7 +76,7 @@ Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc,
   return base;
 }
 
-bool isKMinor(llvm::ArrayRef<unsigned> order, int opIdx) {
+bool isKContig(llvm::ArrayRef<unsigned> order, int opIdx) {
   auto rank = order.size();
   int kdim = opIdx == 0 ? rank - 1 : rank - 2;
   return order[0] == kdim;

@@ -106,9 +106,9 @@ bool isSwizzlePatternFitsIntoBlock(const SharedEncodingAttr sharedLayout,
   const auto swizzleSlowDimSize =
       sharedLayout.getMaxPhase() * sharedLayout.getPerPhase();
   const auto swizzlePatternSizeK =
-      isKMinor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;
+      isKContig(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;
   const auto swizzlePatternSizeNonK =
-      !isKMinor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;
+      !isKContig(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;
 
   const auto blockSizeK = mfmaInstrK * reps[reps.size() - 1];
   const auto blockSizeNonK = mfmaInstrNonK * warpsPerBlockNonK;
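
As a quick illustration of the renamed predicate, a self-contained sketch that mirrors the logic shown above (plain std::vector instead of llvm::ArrayRef):

#include <cassert>
#include <vector>

// Sketch of isKContig: order[0] is the fastest dimension; k is the last dim
// of operand A ([m, k]) and the second-to-last dim of operand B ([k, n]).
bool isKContig(const std::vector<unsigned> &order, int opIdx) {
  int rank = static_cast<int>(order.size());
  int kdim = opIdx == 0 ? rank - 1 : rank - 2;
  return static_cast<int>(order[0]) == kdim;
}

int main() {
  // Order {1, 0}: dim 1 is the fastest dimension.
  assert(isKContig({1, 0}, /*opIdx=*/0));  // A: k is dim 1 -> k-contiguous
  assert(!isKContig({1, 0}, /*opIdx=*/1)); // B: k is dim 0 -> not k-contiguous
  // Order {0, 1}: dim 0 is the fastest dimension.
  assert(!isKContig({0, 1}, /*opIdx=*/0));
  assert(isKContig({0, 1}, /*opIdx=*/1));
  return 0;
}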

(file path not shown)

@@ -37,7 +37,7 @@ Value computeOffset(ConversionPatternRewriter &rewriter, Location loc,
 Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc,
                      const SharedMemoryObject &smemObj);
 
-bool isKMinor(llvm::ArrayRef<unsigned> order, int opIdx);
+bool isKContig(llvm::ArrayRef<unsigned> order, int opIdx);
 
 using computeTensorElemMappingInBlockT =
     std::function<llvm::SmallVector<llvm::SmallVector<Value>>(

(file path not shown)

@@ -279,7 +279,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
   SmallVector<Value> offsets;
   Value smemBase;
   bool isFastPath =
-      !AMD::isKMinor(order, opIdx) && !hasSwizzleEnabled(sharedLayout);
+      !AMD::isKContig(order, opIdx) && !hasSwizzleEnabled(sharedLayout);
   if (isFastPath) {
     // fast path handles tensors that are not k-major and have swizzling
     // disabled, in which case offsets computation can be simplified

(file path not shown)

@@ -495,13 +495,13 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
   // getValuesFromDotOperandLayoutStruct as both a and b are K-major
   assert(dotOpA.getRepOrder() == getOrderForDotOperand(dotOpA.getOpIdx(),
                                                        aShapePerCTA.size(),
-                                                       /*kMinor=*/true));
+                                                       /*kContig=*/true));
   auto ha = getValuesFromDotOperandLayoutStruct(
       typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy);
 
   assert(dotOpB.getRepOrder() == getOrderForDotOperand(dotOpB.getOpIdx(),
                                                        bShapePerCTA.size(),
-                                                       /*kMinor=*/true));
+                                                       /*kContig=*/true));
   auto hb = getValuesFromDotOperandLayoutStruct(
       typeConverter, loc, rewriter, loadedB, repBatch, repN, repK, bTensorTy);
 
