Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add generalization for DotLike operations #20

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
86b69c31642e98f8357df62c09d118ad1da4e16a
1f20eee6dc367bd202895e3eedb03974a628ef16
24 changes: 20 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,27 @@ def TTG_MemDescType : TTG_TypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> {
"Type":$elementType,
"Attribute":$encoding,
"Attribute":$memorySpace,
"bool":$mutable_memory
"bool":$mutableMemory,
ArrayRefParameter<"int64_t">:$allocShape
);

let extraClassDeclaration = [{
MemDescType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory());
return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory(), getAllocShape());
}

bool hasRank() const { return true; }
}];

let builders = [
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
"Type":$elementType,
"Attribute":$encoding,
"Attribute":$memorySpace
), [{
return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false);
return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false, /*allocShape=*/shape);
}]>,
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
Expand All @@ -75,10 +78,23 @@ def TTG_MemDescType : TTG_TypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> {
"Attribute":$memorySpace,
"bool":$mutableMemory
), [{
return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory);
return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, /*allocShape=*/shape);
}]>,
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<int64_t>":$shape,
"Type":$elementType,
"Attribute":$encoding,
"Attribute":$memorySpace,
"bool":$mutableMemory,
"llvm::ArrayRef<int64_t>":$allocShape
), [{
return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, allocShape);
}]>

];

let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
}


Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Triton/Transforms/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def CombineDotAddIPattern : Pat<
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath),
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath, $denorm),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
Expand All @@ -29,7 +29,7 @@ def CombineDotAddIRevPattern : Pat<
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath),
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath, $denorm),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
Expand Down
6 changes: 6 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2459,6 +2459,7 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;

AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
// Encoding attributes
if (auto mmaAttr = mlir::dyn_cast<MmaEncodingTrait>(attr)) {
os << "mma";
return AliasResult::FinalAlias;
Expand All @@ -2475,6 +2476,11 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
os << "slice";
return AliasResult::FinalAlias;
} */
// Memory space attributes
if (auto smem = mlir::dyn_cast<SharedMemorySpaceAttr>(attr)) {
os << "smem";
return AliasResult::FinalAlias;
}
return OpAsmDialectInterface::getAlias(attr, os);
}
};
Expand Down
72 changes: 48 additions & 24 deletions lib/Dialect/TritonGPU/IR/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,47 +30,54 @@ void TokenType::print(AsmPrinter &printer) const {
static constexpr llvm::StringRef kMutableMemory = "mutable";

Type MemDescType::parse(AsmParser &parser) {
if (parser.parseLess())
if (failed(parser.parseLess()))
return Type();

SmallVector<int64_t> dimensions;
if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false))
SmallVector<int64_t> dimensions; // required
if (failed(parser.parseDimensionList(dimensions, /*allowDynamic=*/false)))
return Type();

// Parse the element type.
Type elementType;
if (parser.parseType(elementType))
Type elementType; // required
if (failed(parser.parseType(elementType)))
return Type();

Attribute encoding;
if (succeeded(parser.parseOptionalComma())) {
if (parser.parseAttribute(encoding))
return Type();
}
bool mutableMemory = false;
Attribute memorySpace;
Attribute encoding; // required
if (failed(parser.parseComma()) || failed(parser.parseAttribute(encoding)))
return Type();

Attribute memorySpace; // required
if (failed(parser.parseComma()) || failed(parser.parseAttribute(memorySpace)))
return Type();

bool mutableMemory = false; // optional
SmallVector<int64_t> allocShape; // optional
if (succeeded(parser.parseOptionalComma())) {
if (failed(parser.parseOptionalKeyword(kMutableMemory))) {
if (parser.parseAttribute(memorySpace))
return Type();
} else {
if (succeeded(parser.parseOptionalKeyword(kMutableMemory))) {
mutableMemory = true;
}
}
if (mutableMemory == false && succeeded(parser.parseOptionalComma())) {
if (parser.parseOptionalKeyword(kMutableMemory))
if (succeeded(parser.parseOptionalComma())) {
if (failed(parser.parseDimensionList(allocShape, /*allowDynamic=*/false,
/*withTrailingX=*/false))) {
return Type();
}
}
} else if (failed(parser.parseDimensionList(allocShape,
/*allowDynamic=*/false,
/*withTrailingX=*/false))) {
return Type();
mutableMemory = true;
}
}

if (parser.parseGreater())
return Type();

return MemDescType::get(parser.getContext(), dimensions, elementType,
encoding, memorySpace, mutableMemory);
encoding, memorySpace, mutableMemory, dimensions);
}

void MemDescType::print(AsmPrinter &printer) const {
printer << "<";
for (auto dim : getShape())
auto shape = getShape();
for (auto dim : shape)
printer << dim << "x";
printer << getElementType();
if (getEncoding())
Expand All @@ -79,9 +86,26 @@ void MemDescType::print(AsmPrinter &printer) const {
printer << ", " << getMemorySpace();
if (getMutableMemory())
printer << ", " << kMutableMemory;
auto allocShape = getAllocShape();
if (allocShape != shape) {
printer << ", " << allocShape[0];
for (auto dim : allocShape.drop_front(1)) {
printer << "x" << dim;
}
}
printer << ">";
}

/// Verifies a MemDescType's invariants at construction time.
///
/// \param emitError   callback producing an in-flight diagnostic at the use
///                    site (standard MLIR genVerifyDecl hook).
/// \param shape       the logical (view) shape of the descriptor.
/// \param allocShape  the underlying allocation's shape; it may carry extra
///                    leading dimensions relative to the view (e.g. a
///                    multi-buffering/pipeline-stage dimension), but it can
///                    never have fewer dimensions than the view.
/// \return failure() with a diagnostic when the invariant is violated,
///         success() otherwise.
LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
                                  Attribute encoding, Attribute memorySpace,
                                  bool mutableMemory,
                                  ArrayRef<int64_t> allocShape) {
  // BUG FIX: the original emitted the diagnostic but still returned
  // success(), so invalid types were accepted. Returning the
  // InFlightDiagnostic converts to failure() per MLIR convention.
  if (allocShape.size() < shape.size())
    return emitError()
           << "alloc shape must have at least as many dimensions as shape";
  return success();
}

//===----------------------------------------------------------------------===//
// Triton Dialect
//===----------------------------------------------------------------------===//
Expand Down
17 changes: 8 additions & 9 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
return 0;
}

SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
int numWarps) {
SmallVector<unsigned>
warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
auto rank = shape.size();
// Early exit for batched matmul
if (rank == 3)
Expand All @@ -58,9 +58,8 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
auto slices = multiRootGetSlice(dotOp, {filter}, {filter});
bool hasChainedDot = false;
for (Operation *op : slices) {
if (isa<DotOp>(op) && (op != dotOp)) {
auto chainedDot = cast<DotOp>(op);
auto resTy = chainedDot.getResult().getType();
if (op->hasTrait<OpTrait::DotLike>() && op != dotOp) {
auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
if (resTy.getRank() != rank) {
continue;
}
Expand Down Expand Up @@ -109,14 +108,14 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
}

SmallVector<unsigned, 2>
warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
const SmallVector<unsigned, 3> &instrShape) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
mlir::getForwardSlice(dotOp->getResult(0), &slices);
// Contains a chained dot. We prefer to assign warps to one axis
// to facilitate use cases like flash attention, allowing reductions within
// the same warp.
if (llvm::find_if(slices, [](Operation *op) {
if (llvm::find_if(slices, [&](Operation *op) {
return op->hasTrait<OpTrait::DotLike>();
}) != slices.end())
return {(unsigned)numWarps, 1};
Expand Down Expand Up @@ -171,7 +170,7 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
}

SmallVector<unsigned, 3>
getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
int numWarps, const SmallVector<unsigned, 3> &instrShape) {
switch (version) {
case 2:
Expand Down
35 changes: 19 additions & 16 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext());
ttg::MemDescType subviewTy = ttg::MemDescType::get(
allocTy.getShape().drop_front(), allocTy.getElementType(),
allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true);
allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true,
/*allocShape=*/allocTy.getAllocShape());
auto view = builder.createWithStage<ttg::MemDescSubviewOp>(
loc, stage, clusterId, subviewTy, alloc, copyOffsets);
Operation *copy = builder.createWithStage<ttg::AsyncCopyGlobalToLocalOp>(
Expand Down Expand Up @@ -232,7 +233,8 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp,
copyOffsets[0] = insertIdx;
ttg::MemDescType subviewTy = ttg::MemDescType::get(
allocTy.getShape().drop_front(), allocTy.getElementType(),
allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true);
allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true,
/*allocShape=*/allocTy.getAllocShape());
auto view = builder.createWithStage<ttg::MemDescSubviewOp>(
loc, stage, clusterId, subviewTy, alloc, copyOffsets);

Expand Down Expand Up @@ -526,7 +528,7 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp,
bufferShape.insert(bufferShape.begin(), distance);
Type memdescType = ttg::MemDescType::get(bufferShape, ty.getElementType(),
sharedEnc, sharedMemorySpace,
/*mutableMemory*/ true);
/*mutableMemory=*/true);
Value alloc =
builder.create<ttg::LocalAllocOp>(loadOp->getLoc(), memdescType, Value());
return alloc;
Expand All @@ -544,12 +546,13 @@ static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) {
/*CTASplitNum=*/{1}, /*CTAOrder=*/{0});
auto barrierEncoding =
ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout);
Type barrierMemDescType = ttg::MemDescType::get(
auto barrierMemDescType = ttg::MemDescType::get(
{distance}, builder.getI64Type(), barrierEncoding, sharedMemorySpace,
/*mutableMemory=*/true);
Type singleBarrierMemDescType =
ttg::MemDescType::get({1}, builder.getI64Type(), barrierEncoding,
sharedMemorySpace, /*mutableMemory=*/true);
Type singleBarrierMemDescType = ttg::MemDescType::get(
{1}, builder.getI64Type(), barrierEncoding, sharedMemorySpace,
/*mutableMemory=*/true,
/*allocShape=*/barrierMemDescType.getAllocShape());
Value barrierAlloc =
builder.create<ttg::LocalAllocOp>(loc, barrierMemDescType, Value());
for (unsigned i = 0; i < distance; i++) {
Expand Down Expand Up @@ -650,11 +653,11 @@ static void createTMABarrierAndWait(
OpBuilderWithStage builder(forOp);
Attribute sharedMemorySpace =
ttg::SharedMemorySpaceAttr::get(builder.getContext());
auto allocTy = cast<ttg::MemDescType>(barrierAlloc.getType());
ttg::MemDescType barrierTy = ttg::MemDescType::get(
{1}, builder.getI64Type(),
cast<ttg::MemDescType>(barrierAlloc.getType()).getEncoding(),
sharedMemorySpace,
/*mutableMemory=*/true);
{1}, builder.getI64Type(), allocTy.getEncoding(), sharedMemorySpace,
/*mutableMemory=*/true,
/*allocShape=*/allocTy.getAllocShape());
builder.setInsertionPoint(group[0]->loadOp);
Value barrier = builder.createWithStage<ttg::MemDescSubviewOp>(
loc, stage, cluster, barrierTy, barrierAlloc,
Expand Down Expand Up @@ -835,14 +838,14 @@ static void invalidateBarriers(OpBuilder &builder,
Attribute sharedMemorySpace =
ttg::SharedMemorySpaceAttr::get(builder.getContext());
for (Value barrier : barriers) {
int numBarriers = cast<ttg::MemDescType>(barrier.getType()).getShape()[0];
auto allocTy = cast<ttg::MemDescType>(barrier.getType());
int numBarriers = allocTy.getShape()[0];
for (int i = 0; i < numBarriers; i++) {
Value idx = builder.create<arith::ConstantIntOp>(barrier.getLoc(), i, 32);
ttg::MemDescType barrierTy = ttg::MemDescType::get(
{1}, builder.getI64Type(),
cast<ttg::MemDescType>(barrier.getType()).getEncoding(),
sharedMemorySpace,
/*mutableMemory=*/true);
{1}, builder.getI64Type(), allocTy.getEncoding(), sharedMemorySpace,
/*mutableMemory=*/true,
/*allocShape=*/allocTy.getShape());
Value barrierView = builder.create<ttg::MemDescSubviewOp>(
barrier.getLoc(), barrierTy, barrier, idx);
builder.create<ttng::InvalBarrierOp>(barrier.getLoc(), barrierView);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
storeOp.getMaskMutable().assign(mask);
return op;
}
if (auto atomicRMWOp = dyn_cast<tt::AtomicRMWOp>(op)) {
rewriter.setInsertionPoint(atomicRMWOp);
Value mask = getPredMask(rewriter, atomicRMWOp.getPtr().getType(),
atomicRMWOp.getMask(), pred);
atomicRMWOp.getMaskMutable().assign(mask);
return op;
}

assert("don't know how to predicate this op" && false);
return op;
Expand Down
5 changes: 3 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,9 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
builder.create<arith::ConstantIntOp>(v.getLoc(), off, 32));
Value newSmem = builder.create<triton::gpu::MemDescSubviewOp>(
v.getLoc(),
triton::gpu::MemDescType::get(shape, elementType, type.getEncoding(),
type.getMemorySpace()),
triton::gpu::MemDescType::get(
shape, elementType, type.getEncoding(), type.getMemorySpace(),
type.getMutableMemory(), type.getAllocShape()),
v, offsetsVal);

auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct FenceInsertionPass
return;
ModuleOp mod = getOperation();
mod.walk([&](Operation *op) {
if (!isa<ttng::WarpGroupDotOp>(op))
if (!op->hasTrait<OpTrait::DotLike>())
return WalkResult::advance();
OpBuilder builder(op);
auto a = op->getOperand(0);
Expand Down
Loading
Loading