Skip to content

Commit

Permalink
[AMD] rc/3.2.x cherry picks (#5347)
Browse files Browse the repository at this point in the history
Reverts #5191 due to some mlir errors in pytorch unit tests

Smaller set of cherry picks:
- #5308 (and previous LLVM upgrades)
- #5281 
- #4925 
- #5053 
- #5019 
- #4998

---------

Co-authored-by: Jungwook Park <[email protected]>
Co-authored-by: peterbell10 <[email protected]>
Co-authored-by: Hongtao Yu <[email protected]>
Co-authored-by: Lei Zhang <[email protected]>
Co-authored-by: Ilya V <[email protected]>
Co-authored-by: Kyle Wang <[email protected]>
  • Loading branch information
7 people authored Dec 5, 2024
1 parent 2d8093c commit 7e401df
Show file tree
Hide file tree
Showing 26 changed files with 739 additions and 882 deletions.
6 changes: 5 additions & 1 deletion include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,11 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

bool atomicNeedsSharedMemory(Value result);

bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

// Return true if the src and dst layout match.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ namespace gpu {
SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
Type ouType);

SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter *typeConverter);

SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter *typeConverter);

Type getElementType(Value value);

class MultipleOperandsRange
Expand Down Expand Up @@ -179,8 +187,8 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
for (auto operand : adaptor.getOperands()) {
auto argTy = op->getOperand(0).getType();
auto subOperands = unpackLLElements(loc, operand, rewriter);
subOperands = unpackI32s(subOperands, argTy, rewriter, loc,
this->getTypeConverter());
subOperands = unpackI32(subOperands, argTy, rewriter, loc,
this->getTypeConverter());
allOperands.resize(subOperands.size());
for (auto v : llvm::enumerate(subOperands))
allOperands[v.index()].push_back(v.value());
Expand All @@ -207,7 +215,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
}
resultVals = maybeDeduplicate(op, resultVals);
resultVals =
packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);
Expand Down
61 changes: 0 additions & 61 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1388,67 +1388,6 @@ inline Value getStructFromSharedMemoryObject(Location loc,
return llvmStruct;
}

// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
// instructions to pack & unpack sub-word integers. A workaround is to
// store the results of tensors with dot operand encodings in i32 to
// facilitate instructions such as `ldmatrix`.
//
// TODO: Confirm if the problem is still there.
inline bool requiresI32Conversion(Type type) {
auto tensorTy = dyn_cast<RankedTensorType>(type);
if (!tensorTy)
return false;
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!dotOpEnc)
return false;
auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOpEnc.getParent());
if (!(parent && parent.getVersionMajor() < 3))
return false;
return true;
}

inline SmallVector<Value> packI32s(const SmallVector<Value> &inValues,
Type type, RewriterBase &rewriter,
Location loc,
const LLVMTypeConverter *typeConverter) {
if (!requiresI32Conversion(type))
return inValues;
Type eltTy =
typeConverter->convertType(cast<RankedTensorType>(type).getElementType());

SmallVector<Value> outValues;
int vecWidth = 32 / eltTy.getIntOrFloatBitWidth();
auto vecTy = vec_ty(eltTy, vecWidth);
for (int i = 0; i < inValues.size(); i += vecWidth) {
Value vec = undef(vecTy);
for (int j = 0; j < vecWidth; j++) {
vec = insert_element(vec, inValues[i + j], i32_val(j));
}
outValues.push_back(bitcast(vec, i32_ty));
}
return outValues;
}

inline SmallVector<Value> unpackI32s(const SmallVector<Value> &inValues,
Type type, RewriterBase &rewriter,
Location loc,
const LLVMTypeConverter *typeConverter) {
if (!requiresI32Conversion(type))
return inValues;
Type eltTy =
typeConverter->convertType(cast<RankedTensorType>(type).getElementType());

SmallVector<Value> outValues;
for (auto v : inValues) {
auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth());
auto vec = bitcast(v, vecTy);
for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) {
outValues.push_back(extract_element(vec, i32_val(i)));
}
}
return outValues;
}

inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
RewriterBase &rewriter) {
assert(bool(llvmStruct) && "can not unpack null values");
Expand Down
7 changes: 0 additions & 7 deletions include/triton/Tools/LinearLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -679,13 +679,6 @@ class LinearLayout {
// (i.e. every input bit affects the output).
llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks() const;

// Increase an input dimension without affecting the output dimension. The
// added free variables are mapped to 0, ensuring that the new input
// dimensions correspond directly to the existing output space. The function
// errors out if `newInDimSize` is less than the current size or the new size
// is not a power of 2.
LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const;

std::string toString() const;

friend bool operator==(LinearLayout lhs, LinearLayout rhs);
Expand Down
2 changes: 1 addition & 1 deletion lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

assert(cvtNeedsSharedMemory(srcTy, dstTy));
assert(!isMfmaToDotShortcut(srcTy, dstTy));

// FIXME This is NOT entirely correct
// This should be getElemOrder, but we don't have such a method
Expand Down
81 changes: 37 additions & 44 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ bool supportMMA(Value value, int version) {
(elemTy.isInteger(8) && version >= 2);
}

bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (blockedLayout == nullptr || dotOperandLayout == nullptr)
Expand Down Expand Up @@ -605,6 +605,22 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
return matrixDimsCompatible && bDimCompatible;
}

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (mfmaLayout == nullptr || dotOperandLayout == nullptr)
return false;
// TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
// improved. In addition, we can enable this shortcut for regular MFMA
// layout when opIdx == 1.
return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
dotOperandLayout.getKWidth() == getContigPerThread(mfmaLayout)[1] &&
dotOperandLayout.getParent() == mfmaLayout &&
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}

// For MMAV3 dotOperand layout matches mma operand for f16 and bf16 cases.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy) {
Expand Down Expand Up @@ -639,46 +655,8 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
if (!(srcLayout.has_value() && dstLayout.has_value()))
return std::nullopt;
StringAttr kRegister = StringAttr::get(ctx, "register");
StringAttr kLane = StringAttr::get(ctx, "lane");
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");
auto numSrcRegs = srcLayout->getInDimSize(kRegister);
auto numDstRegs = dstLayout->getInDimSize(kRegister);
// The `invertAndCompose` function will generate a layout that is injective
// by assigning new output dimensions to free variables. For instance,
// consider a scenario where `srcLayout` has a free variable in the lane
// dimension, while `dstLayout` has two free variables in the lane
// dimension and also a larger number of registers.
// The injective form of `srcLayout` will add only a single additional row
// to the transformation matrix, whereas the injective form of `dstLayout`
// will add two additional rows. This discrepancy causes misleading results
// because the matrices end up with a different number of rows.
//
// Take `dstLayout ⋅ srcLayout^-1` as an example:
//
// - `injective(dstLayout)`: [n, m] → [n + 2, m]
// - `injective(srcLayout)`: [n, m] → [n + 1, m]
// - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
// - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
// 1] → [n + 2, n + 1]
//
// Here, the `(n + 1)`-th row added by `dstLayout` represents the free
// variable in registers, and the `(n + 2)`-th row represents the free
// variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
// represents the free variable in lanes. As a result, the `(n + 1)`-th row
// in two layouts do not correspond to the same free variable.
//
// To address this issue, we pad the free variables in `srcLayout` and
// `dstLayout` to ensure they have the same number of registers. This
// guarantees that the resulting matrices have the same number of rows,
// ensuring consistency in the composition process.
auto numRegs = std::max(numSrcRegs, numDstRegs);
auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
// comp describes the layout function to create dst from src.
LinearLayout comp =
dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
// We try to quotient by the largest subspace first
auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
for (auto dim : dims) {
Expand Down Expand Up @@ -715,14 +693,15 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
}

bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
// TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and
// `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
// checks.
// TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
// `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
// subsumed by the linear-layout checks.
// TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
// supported yet in Triton's backend.
return !cvtReordersRegisters(srcTy, dstTy) &&
!isBlockedToDotShortcut(srcTy, dstTy) &&
!matchMmaV3AndDotOperandLayout(srcTy, dstTy);
!isMmaToDotShortcut(srcTy, dstTy) &&
!isMfmaToDotShortcut(srcTy, dstTy);
}

bool atomicNeedsSharedMemory(Value value) {
Expand All @@ -732,6 +711,20 @@ bool atomicNeedsSharedMemory(Value value) {
return true;
}

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
return true;
// dot_op<opIdx=0, parent=#mma> = #mma
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 &&
mmaLayout.getWarpsPerCTA()[1] == 1 &&
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getParent() == mmaLayout &&
!srcTy.getElementType().isF32();
}

namespace {

/// A data structure similar to SetVector but maintains
Expand Down
Loading

0 comments on commit 7e401df

Please sign in to comment.