Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMD] rc/3.2.x cherry picks #5347

Merged
merged 11 commits into from
Dec 5, 2024
Merged
6 changes: 5 additions & 1 deletion include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,11 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

bool atomicNeedsSharedMemory(Value result);

bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// Return true if the src and dst layout match.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ namespace gpu {
SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
Type ouType);

SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter *typeConverter);

SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter *typeConverter);

Type getElementType(Value value);

class MultipleOperandsRange
Expand Down Expand Up @@ -179,8 +187,8 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
for (auto operand : adaptor.getOperands()) {
auto argTy = op->getOperand(0).getType();
auto subOperands = unpackLLElements(loc, operand, rewriter);
subOperands = unpackI32s(subOperands, argTy, rewriter, loc,
this->getTypeConverter());
subOperands = unpackI32(subOperands, argTy, rewriter, loc,
this->getTypeConverter());
allOperands.resize(subOperands.size());
for (auto v : llvm::enumerate(subOperands))
allOperands[v.index()].push_back(v.value());
Expand All @@ -207,7 +215,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
}
resultVals = maybeDeduplicate(op, resultVals);
resultVals =
packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);
Expand Down
61 changes: 0 additions & 61 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1388,67 +1388,6 @@ inline Value getStructFromSharedMemoryObject(Location loc,
return llvmStruct;
}

// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
// instructions to pack & unpack sub-word integers. A workaround is to
// store the results of tensors with dot operand encodings in i32 to
// facilitate instructions such as `ldmatrix`.
//
// TODO: Confirm if the problem is still there.
inline bool requiresI32Conversion(Type type) {
auto tensorTy = dyn_cast<RankedTensorType>(type);
if (!tensorTy)
return false;
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!dotOpEnc)
return false;
auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOpEnc.getParent());
if (!(parent && parent.getVersionMajor() < 3))
return false;
return true;
}

inline SmallVector<Value> packI32s(const SmallVector<Value> &inValues,
Type type, RewriterBase &rewriter,
Location loc,
const LLVMTypeConverter *typeConverter) {
if (!requiresI32Conversion(type))
return inValues;
Type eltTy =
typeConverter->convertType(cast<RankedTensorType>(type).getElementType());

SmallVector<Value> outValues;
int vecWidth = 32 / eltTy.getIntOrFloatBitWidth();
auto vecTy = vec_ty(eltTy, vecWidth);
for (int i = 0; i < inValues.size(); i += vecWidth) {
Value vec = undef(vecTy);
for (int j = 0; j < vecWidth; j++) {
vec = insert_element(vec, inValues[i + j], i32_val(j));
}
outValues.push_back(bitcast(vec, i32_ty));
}
return outValues;
}

inline SmallVector<Value> unpackI32s(const SmallVector<Value> &inValues,
Type type, RewriterBase &rewriter,
Location loc,
const LLVMTypeConverter *typeConverter) {
if (!requiresI32Conversion(type))
return inValues;
Type eltTy =
typeConverter->convertType(cast<RankedTensorType>(type).getElementType());

SmallVector<Value> outValues;
for (auto v : inValues) {
auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth());
auto vec = bitcast(v, vecTy);
for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) {
outValues.push_back(extract_element(vec, i32_val(i)));
}
}
return outValues;
}

inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
RewriterBase &rewriter) {
assert(bool(llvmStruct) && "can not unpack null values");
Expand Down
7 changes: 0 additions & 7 deletions include/triton/Tools/LinearLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -679,13 +679,6 @@ class LinearLayout {
// (i.e. every input bit affects the output).
llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks() const;

// Increase an input dimension without affecting the output dimension. The
// added free variables are mapped to 0, ensuring that the new input
// dimensions correspond directly to the existing output space. The function
// errors out if `newInDimSize` is less than the current size or the new size
// is not a power of 2.
LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const;

std::string toString() const;

friend bool operator==(LinearLayout lhs, LinearLayout rhs);
Expand Down
2 changes: 1 addition & 1 deletion lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

assert(cvtNeedsSharedMemory(srcTy, dstTy));
assert(!isMfmaToDotShortcut(srcTy, dstTy));

// FIXME This is NOT entirely correct
// This should be getElemOrder, but we don't have such a method
Expand Down
81 changes: 37 additions & 44 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ bool supportMMA(Value value, int version) {
(elemTy.isInteger(8) && version >= 2);
}

bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (blockedLayout == nullptr || dotOperandLayout == nullptr)
Expand Down Expand Up @@ -605,6 +605,22 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
return matrixDimsCompatible && bDimCompatible;
}

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
return false;
// TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
// improved. In addition, we can enable this shortcut for regular MFMA
// layout when opIdx == 1.
return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] &&
dotOperandLayout.getParent() == mfmaLayout &&
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}

// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy) {
Expand Down Expand Up @@ -639,46 +655,8 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
if (!(srcLayout.has_value() && dstLayout.has_value()))
return std::nullopt;
StringAttr kRegister = StringAttr::get(ctx, "register");
StringAttr kLane = StringAttr::get(ctx, "lane");
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");
auto numSrcRegs = srcLayout->getInDimSize(kRegister);
auto numDstRegs = dstLayout->getInDimSize(kRegister);
// The `invertAndCompose` function will generate a layout that is injective
// by assigning new output dimensions to free variables. For instance,
// consider a scenario where `srcLayout` has a free variable in the lane
// dimension, while `dstLayout` has two free variables in the lane
// dimension and also a larger number of registers.
// The injective form of `srcLayout` will add only a single additional row
// to the transformation matrix, whereas the injective form of `dstLayout`
// will add two additional rows. This discrepancy causes misleading results
// because the matrices end up with a different number of rows.
//
// Take `dstLayout ⋅ srcLayout^-1` as an example:
//
// - `injective(dstLayout)`: [n, m] → [n + 2, m]
// - `injective(srcLayout)`: [n, m] → [n + 1, m]
// - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
// - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
// 1] → [n + 2, n + 1]
//
// Here, the `(n + 1)`-th row added by `dstLayout` represents the free
// variable in registers, and the `(n + 2)`-th row represents the free
// variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
// represents the free variable in lanes. As a result, the `(n + 1)`-th row
// in two layouts do not correspond to the same free variable.
//
// To address this issue, we pad the free variables in `srcLayout` and
// `dstLayout` to ensure they have the same number of registers. This
// guarantees that the resulting matrices have the same number of rows,
// ensuring consistency in the composition process.
auto numRegs = std::max(numSrcRegs, numDstRegs);
auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
// comp describes the layout function to create dst from src.
LinearLayout comp =
dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
// We try to quotient by the largest subspace first
auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
for (auto dim : dims) {
Expand Down Expand Up @@ -715,14 +693,15 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
}

bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
// TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and
// `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
// checks.
// TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
// `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
// subsumed by the linear-layout checks.
// TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
// supported yet in Triton's backend.
return !cvtReordersRegisters(srcTy, dstTy) &&
!isBlockedToDotShortcut(srcTy, dstTy) &&
!matchMmaV3AndDotOperandLayout(srcTy, dstTy);
!isMmaToDotShortcut(srcTy, dstTy) &&
!isMfmaToDotShortcut(srcTy, dstTy);
}

bool atomicNeedsSharedMemory(Value value) {
Expand All @@ -732,6 +711,20 @@ bool atomicNeedsSharedMemory(Value value) {
return true;
}

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
return true;
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getParent() == mmaLayout &&
!srcTy.getElementType().isF32();
}

namespace {

/// A data structure similar to SetVector but maintains
Expand Down
Loading
Loading