Merge OpenAI Triton commit f436c9e #3124

Merged: 5 commits, Jan 9, 2025
18 changes: 12 additions & 6 deletions include/triton/Dialect/Triton/IR/Dialect.h
@@ -55,13 +55,19 @@ class DialectInferLayoutInterface

// Tries to compute the encoding for the result of a reshape operation that
// makes the reshape a "nop", i.e. the same GPU threads contain the same
// elements as before the reshape. Note that this is not always possible (in
// which case you'd need to choose a different layout for the input to the
// reshape).
// elements as before the reshape using legacy layouts. This is not always
// possible (in which case we fall back to using LinearLayouts).
// In the future we'll always use LinearLayouts.
virtual LogicalResult
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const = 0;
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const = 0;

// Check if two layouts are structurally the same, even if their names are
// different.
virtual LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
Attribute expected, Attribute got,
Location loc) const = 0;

virtual LogicalResult
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
58 changes: 51 additions & 7 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -11,14 +11,60 @@
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Types.h"

// LinearLayoutCache Utils
using CacheKey =
std::tuple<std::vector<int64_t>, mlir::Attribute, std::optional<int32_t>>;

namespace llvm {
template <typename T> size_t hash_value(const std::vector<T> &vec) {
return hash_combine_range(vec.begin(), vec.end());
}
} // namespace llvm

namespace std {
template <> struct hash<CacheKey> {
size_t operator()(const CacheKey &key) const noexcept {
using llvm::hash_value;
size_t seed = 0;
std::apply(
[&seed](const auto &...elems) {
((seed = llvm::hash_combine(seed, hash_value(elems))), ...);
},
key);
return seed;
}
};
} // namespace std

namespace mlir::triton::gpu {

class LinearLayoutCache {
public:
std::optional<LinearLayout> get(const CacheKey &key) {
std::shared_lock lock(mutex);
auto it = cache.find(key);
if (it != cache.end()) {
return it->second;
}
return std::nullopt;
}

void set(CacheKey key, LinearLayout result) {
std::scoped_lock lock(mutex);
cache.emplace(std::move(key), std::move(result));
}

private:
std::unordered_map<CacheKey, LinearLayout> cache;
llvm::sys::SmartRWMutex<true> mutex;
};
} // namespace mlir::triton::gpu

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"

namespace mlir {
namespace triton {
namespace gpu {

namespace mlir::triton::gpu {
struct SharedMemory : public SideEffects::Resource::Base<SharedMemory> {
StringRef getName() final { return "<SharedMemory>"; }
};
@@ -240,8 +286,6 @@ llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);
llvm::SmallVector<unsigned>
expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);

} // namespace gpu
} // namespace triton
} // namespace mlir
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
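The std::hash<CacheKey> specialization above folds each tuple element into a running seed using std::apply and a C++17 fold expression. Below is a minimal, self-contained illustration of that same pattern using only standard-library types; the hashTuple helper and the mixing constant are illustrative additions, not part of this PR.

```cpp
// Illustrative sketch (not from the PR): hash every element of a tuple into
// one seed via std::apply + a fold expression, as the CacheKey hasher does.
#include <cstddef>
#include <functional>
#include <string>
#include <tuple>
#include <type_traits>

template <typename... Ts>
std::size_t hashTuple(const std::tuple<Ts...> &key) {
  std::size_t seed = 0;
  std::apply(
      [&seed](const auto &...elems) {
        // Fold expression: combine each element's std::hash into the seed.
        ((seed ^= std::hash<std::decay_t<decltype(elems)>>{}(elems) +
                  0x9e3779b9 + (seed << 6) + (seed >> 2)),
         ...);
      },
      key);
  return seed;
}

int main() {
  std::tuple<int, std::string> key{64, "blocked"};
  // Identical keys hash identically; that is all an unordered_map requires.
  return hashTuple(key) == hashTuple(key) ? 0 : 1;
}
```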
7 changes: 7 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
@@ -43,6 +43,13 @@ def TritonGPU_Dialect : Dialect {
}
return cast<IntegerAttr>(threadsPerWarp).getInt();
}

std::optional<LinearLayout>
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
std::optional<int32_t> elemBitWidth);

private:
LinearLayoutCache llCache;
}];

let useDefaultTypePrinterParser = 1;
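Putting the two pieces together: the new private llCache member gives TritonGPUDialect::toLinearLayout a shared-lock fast path for repeated (shape, encoding, elemBitWidth) queries. The sketch below is a hedged illustration of how the method might consult the cache; computeLinearLayoutImpl is a hypothetical placeholder for the uncached per-encoding conversion, and the PR's real definition lives in Dialect.cpp rather than in this hunk.

```cpp
// Hedged sketch only, assuming the headers modified in this PR; the actual
// implementation in lib/Dialect/TritonGPU/IR/Dialect.cpp may differ.
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace mlir::triton::gpu {

std::optional<LinearLayout>
TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                                 std::optional<int32_t> elemBitWidth) {
  CacheKey key{std::vector<int64_t>(shape.begin(), shape.end()), layout,
               elemBitWidth};
  // Fast path: LinearLayoutCache::get only takes the shared (reader) lock.
  if (std::optional<LinearLayout> cached = llCache.get(key))
    return cached;

  // Slow path: compute outside any lock, then publish under the writer lock.
  // computeLinearLayoutImpl is a hypothetical stand-in for the per-encoding
  // conversion (e.g. LinearEncodingAttr::toLinearLayout).
  std::optional<LinearLayout> result =
      computeLinearLayoutImpl(shape, layout, elemBitWidth);
  if (result)
    llCache.set(std::move(key), *result);
  return result;
}

} // namespace mlir::triton::gpu
```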
7 changes: 3 additions & 4 deletions lib/Analysis/Utility.cpp
@@ -225,15 +225,14 @@ bool ReduceOpHelper::isSupportedLayout() {
}

auto srcLayout = getSrcLayout();
if (isa<BlockedEncodingAttr>(srcLayout)) {
if (isa<BlockedEncodingAttr, LinearEncodingAttr, SliceEncodingAttr>(
srcLayout)) {
return true;
}

if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(srcLayout)) {
return mmaLayout.supportReduction();
}
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(srcLayout)) {
return true;
}
return false;
}

30 changes: 14 additions & 16 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -8,6 +8,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Tools/LinearLayout.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir {
@@ -701,24 +702,21 @@ LogicalResult ReshapeOp::verify() {
"encodings, or (b) neither does.");
}

if (srcEnc && !getAllowReorder()) {
Attribute inferredDstEnc;
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc,
dstTy.getShape(), inferredDstEnc,
getLoc())
.failed()) {
return emitError("This reshape is impossible without reordering, but "
"reordering is not allowed. Try choosing a different "
"encoding for the input tensor (or allow reordering).");
}
if (inferredDstEnc != dstEnc) {
return emitError("Expected result encoding ")
<< inferredDstEnc << " but was " << dstEnc;
}
if (!srcEnc || getAllowReorder()) {
return success();
}

return success();
// Check that we can infer the dst encoding from the src encoding
// and that the inferred dst encoding is the same as the given dst encoding
Attribute inferredDstEnc;
auto result =
cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(),
inferredDstEnc, getLoc());
assert(succeeded(result));
return cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc,
getLoc());
}

//-- FpToFpOp --
125 changes: 83 additions & 42 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1470,7 +1470,7 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,

SmallVector<unsigned> ret(rank, 1);
auto nonZero = [](auto val) { return val != 0; };
int nonZeroIdx = -1;
int nonZeroIdx = 0;
for (const auto &basis : bases) {
auto it = std::find_if(basis.begin(), basis.end(), nonZero);
// Bases can have one or zero non-zero elements
@@ -1482,7 +1482,6 @@
} else if (!skipBroadcast) {
// If we've seen a non-zero basis, we double the size of the previous dim
// This is just needed to count the CTAsPerCGA
assert(nonZeroIdx != -1);
ret[nonZeroIdx] *= 2;
}
}
@@ -1627,12 +1626,14 @@ LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

SmallVector<unsigned>
LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
// We can relax this assert by calling toLinearLayout rather than
// getLinearLayout
SmallVector<int32_t> shapeVec(shape.begin(), shape.end());
assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes()));
auto ll = getLinearLayout();
return basesPerDim(ll, StringAttr::get(getContext(), "register"));
// When broadcasting the layout, the shape changes; otherwise the shape is
// the same as the shape of the tensor.
// We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep
// the invariant that the shape of the LL is that of the tensor.
// We choose the former for BC.
auto ll = *toLinearLayout(shape);
return basesPerDim(ll, StringAttr::get(getContext(), "register"),
/*skipBroadcast=*/false);
}

@@ -2705,8 +2706,8 @@ struct TritonGPUInferLayoutInterface
// contains elements [a,b,c,d] before the reshape, it contains those same
// elements after the reshape, they're just "renamed".
//
// A dst encoding that satisfies this property does not exist for all inputs.
// Here are some positive and negative examples.
// Using legacy layouts, a dst encoding that satisfies this property may not
// exist. Here are some positive and negative examples.
//
// - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so
// dim 1 is the fastest-changing in the dst, but the src has the opposite
@@ -2720,17 +2721,19 @@
// - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will
// contain the same elements as before.
//
// With linear layouts, we can always find a dst encoding that satisfies
// this property. See inferReshapeOpEncoding.
//
// Users of this function require that it is symmetrical: if
// (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) =>
// srcEnc.
LogicalResult
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const override {
LogicalResult inferReshapeOpLegacyEncoding(ArrayRef<int64_t> srcShape,
Attribute srcEnc,
ArrayRef<int64_t> dstShape,
Attribute &dstEnc) const {
auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
if (!src) {
return emitOptionalError(
loc, "Non-reordering reshape only supports BlockedEncoding");
return failure();
}

// Nop reshape; we can always infer an encoding.
@@ -2763,9 +2766,7 @@
// to handle CTASplitNum.
if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
!all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
return emitOptionalError(
loc, "Non-reordering reshape does not currently support multi-CTA "
"layouts other than the default layout.");
return failure();
}

// Cowardly refuse to handle encodings where shape[dim] is not divisible by
@@ -2775,12 +2776,7 @@
for (int dim = 0; dim < srcShape.size(); dim++) {
if (srcShape[dim] >= subblock[dim] &&
srcShape[dim] % subblock[dim] != 0) {
return emitOptionalError(loc,
"Can't do a non-reordering reshape because "
"the size of dimension ",
dim, " (", srcShape[dim], ")",
" is not divisible by ", name, "[", dim, "]",
" = ", subblock[dim]);
return failure();
}
}
return success();
Expand All @@ -2805,11 +2801,7 @@ struct TritonGPUInferLayoutInterface
// physical order, with `a` being the most major.
for (const auto &[srcDims, dstDims] : decomp) {
if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) {
return emitOptionalError(loc,
"Cannot do a non-reordering reshape given "
"this src encoding order. Dimensions [",
join(srcDims),
"] must be physically consecutive.");
return failure();
}
}

@@ -2856,11 +2848,7 @@
// Check that more-minor dims all have 1 in shapeRemaining.
for (int j = i + 1; j < srcDims.size(); j++) {
if (shapeRemaining[j] != 1) {
return emitOptionalError(
loc,
"Invalid src encoding for non-reordering reshape. Must use "
"up sizePerThread / threadsPerWarp / warpsPerCTA for "
"more-minor dimensions before more major-dims can use them.");
return failure();
}
}

@@ -2875,13 +2863,7 @@
// only if we're the most-major dimension of the chunk and in all
// future chunks, only this most-major dim has a non-1 size.
if (shapeRemaining[i] == 0 && i != 0) {
return emitOptionalError(
loc,
"Invalid src encoding for non-reordering reshape. Block "
"size in dimension ",
dim,
" is larger than the shape that dimension, but this is only "
"allowed for the most-major dimension of a reshape chunk");
return failure();
}
}
return success();
@@ -2971,6 +2953,65 @@
return success();
}

LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
Attribute expected, Attribute got,
Location loc) const override {
if (expected == got) {
return success();
}
// Check whether the encodings are structurally the same.
auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
auto gotLL = triton::gpu::toLinearLayout(shape, got);
if (expectedLL != gotLL) {
return emitError(loc, "Expected result encoding ")
<< expected << " but was " << got;
}
return success();
}

LogicalResult
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const override {
auto result =
inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
if (succeeded(result)) {
return result;
}

// If the legacy encoding failed, use LinearLayouts.
// Once LinearLayouts are more widely used, we can remove
// inferReshapeOpLegacyEncoding and simply use LLs.
auto *ctx = getContext();
auto src = triton::gpu::toLinearLayout(srcShape, srcEnc);
if (!src) {
return emitOptionalError(loc,
"src encoding does not support linear layout");
}

if (product(srcShape) != product(dstShape)) {
return emitOptionalError(loc, "numel of dst shape does not match "
"numel of src shape");
}

auto newRank = dstShape.size();
SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
for (auto [dim, size] :
llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
newOutDims.emplace_back(dim, size);
}
auto srcOutDims = llvm::to_vector(src->getOutDimNames());
// reshapeOp assumes minor-to-major, so we need to transpose the out dims
// before the reshape
std::reverse(srcOutDims.begin(), srcOutDims.end());
std::reverse(newOutDims.begin(), newOutDims.end());
auto dst = src->transposeOuts(srcOutDims)
.reshapeOuts(newOutDims)
.transposeOuts(standardOutDimNames(ctx, newRank));
dstEnc = LinearEncodingAttr::get(ctx, dst);
return success();
}

LogicalResult
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
std::optional<Location> loc) const override {
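A worked illustration of the minor-to-major handling in inferReshapeOpEncoding (the M x N example is added here for exposition and is not part of the PR): reshaping a row-major M x N tensor into a 1-D tensor of M*N elements must keep each element's linearized position,

    flat(i, j) = i * N + j

where j, the last and most-minor dimension, varies fastest. The standard out-dim names are ordered major-to-minor, so the code reverses srcOutDims and newOutDims into minor-to-major order before reshapeOuts merges them, then transposes the result back to the standard order; this is what the "reshapeOp assumes minor-to-major" comment in the hunk above refers to.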