From 8d42d211841b4241a08d9d0d2bb6b77fe6e261c0 Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Tue, 3 Dec 2024 11:45:39 -0500
Subject: [PATCH 1/7] [IR] Improve `ttg.memdesc` (#5296)

- Add an `allocShape` field to denote the shape of a memory descriptor
  when it was allocated. The value is propagated to all of its
  descendants created through `subview` ops.
- Make the `encoding` and `memorySpace` fields required instead of
  optional.
- Implement the `getAlias` function for `#ttg.shared_memory` to shorten
  its printed form in `.mlir` files.
---
 .../Dialect/TritonGPU/IR/TritonGPUTypes.td    |  24 ++-
 lib/Dialect/TritonGPU/IR/Dialect.cpp          |   6 +
 lib/Dialect/TritonGPU/IR/Types.cpp            |  72 ++++++---
 .../Pipeliner/MatmulLoopPipeline.cpp          |  35 +++--
 lib/Dialect/TritonGPU/Transforms/Prefetch.cpp |   5 +-
 python/test/unit/language/test_core.py        |  15 +-
 test/Conversion/amd/compute-base-ptr.mlir     |   5 +-
 .../decompose-unsupported-conversions.mlir    |   6 +-
 test/Conversion/amd/tritongpu_to_llvm.mlir    |  11 +-
 .../amd/tritongpu_wmma_dot_to_llvm.mlir       |  20 +--
 test/Conversion/tritongpu_to_llvm.mlir        | 120 +++++++++------
 test/Conversion/tritongpu_to_llvm_hopper.mlir |  45 +++---
 test/Conversion/tritonnvidiagpu_to_llvm.mlir  |  25 +--
 test/Triton/invalid.mlir                      |   7 +-
 test/TritonGPU/accumulator-init.mlir          |  61 ++++----
 .../amd/amd-reorder-instructions.mlir         | 143 +++++++++---------
 test/TritonGPU/amd/amd-sched-2nd-load.mlir    |  64 ++++----
 test/TritonGPU/amd/optimize-lds-usage.mlir    |  34 +++--
 test/TritonGPU/canonicalize.mlir              |  38 ++---
 test/TritonGPU/coalesce-async-copy.mlir       |  10 +-
 test/TritonGPU/combine.mlir                   |  19 ++-
 test/TritonGPU/dot-operands.mlir              |  45 +++---
 test/TritonGPU/fence-inserstion.mlir          |  14 +-
 test/TritonGPU/invalid.mlir                   |  52 +++++--
 test/TritonGPU/loop-pipeline-cuda.mlir        |  27 ++--
 test/TritonGPU/loop-pipeline-hip.mlir         |  21 +--
 ... => loop-pipeline-hopper-remove-wait.mlir} |  13 +-
 test/TritonGPU/loop-pipeline-hopper.mlir      |  92 ++++++-----
 test/TritonGPU/loop-pipeline.mlir             |  18 ++-
 test/TritonGPU/ops.mlir                       |  12 ++
 test/TritonGPU/prefetch.mlir                  |  53 ++++---
 test/TritonGPU/reduce-data-duplication.mlir   |   2 +-
 test/TritonGPU/reorder-instructions.mlir      |  50 +++---
 .../samples/descriptor-matmul-pipeline.mlir   |  85 ++++++-----
 test/TritonGPU/tritongpu_ops.mlir             |  11 --
 test/TritonNvidiaGPU/membar.mlir              |  36 +++--
 36 files changed, 745 insertions(+), 551 deletions(-)
 rename test/TritonGPU/{pipeline-hopper-remove-wait.mlir => loop-pipeline-hopper-remove-wait.mlir} (97%)
 delete mode 100644 test/TritonGPU/tritongpu_ops.mlir

diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td
index 766d5a9bd713..8061a98797b7 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td
@@ -49,16 +49,19 @@ def TTG_MemDescType : TTG_TypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> {
     "Type":$elementType,
     "Attribute":$encoding,
     "Attribute":$memorySpace,
-    "bool":$mutable_memory
+    "bool":$mutableMemory,
+    ArrayRefParameter<"int64_t">:$allocShape
   );
+
   let extraClassDeclaration = [{
     MemDescType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
                           Type elementType) const {
-      return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory());
+      return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory(), getAllocShape());
     }
     bool hasRank() const { return true; }
   }];
+
   let builders = [
     TypeBuilderWithInferredContext<(ins
       "llvm::ArrayRef<int64_t>":$shape,
@@ -66,7 +69,7 @@ def TTG_MemDescType : TTG_TypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> {
       "Attribute":$encoding,
       "Attribute":$memorySpace
     ), [{
-      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false);
+      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false, /*allocShape=*/shape);
     }]>,
     TypeBuilderWithInferredContext<(ins
       "llvm::ArrayRef<int64_t>":$shape,
       "Type":$elementType,
@@ -75,10 +78,23 @@ def TTG_MemDescType : TTG_TypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> {
       "Attribute":$memorySpace,
       "bool":$mutableMemory
     ), [{
-      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory);
+      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, /*allocShape=*/shape);
+    }]>,
+    TypeBuilderWithInferredContext<(ins
+      "llvm::ArrayRef<int64_t>":$shape,
+      "Type":$elementType,
+      "Attribute":$encoding,
+      "Attribute":$memorySpace,
+      "bool":$mutableMemory,
+      "llvm::ArrayRef<int64_t>":$allocShape
+    ), [{
+      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, allocShape);
     }]>
   ];

   let hasCustomAssemblyFormat = 1;
+
+  let genVerifyDecl = 1;
 }

diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 0799bd6df18a..df2c634f9e91 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2459,6 +2459,7 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
   using OpAsmDialectInterface::OpAsmDialectInterface;

   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
+    // Encoding attributes
     if (auto mmaAttr = mlir::dyn_cast<MmaEncodingTrait>(attr)) {
       os << "mma";
       return AliasResult::FinalAlias;
     }
@@ -2475,6 +2476,11 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
       os << "slice";
       return AliasResult::FinalAlias;
     } */
+    // Memory space attributes
+    if (auto smem = mlir::dyn_cast<SharedMemorySpaceAttr>(attr)) {
+      os << "smem";
+      return AliasResult::FinalAlias;
+    }
     return OpAsmDialectInterface::getAlias(attr, os);
   }
 };

diff --git a/lib/Dialect/TritonGPU/IR/Types.cpp b/lib/Dialect/TritonGPU/IR/Types.cpp
index fe87626203f9..ef9c6c4a3067 100644
--- a/lib/Dialect/TritonGPU/IR/Types.cpp
+++ b/lib/Dialect/TritonGPU/IR/Types.cpp
@@ -30,47 +30,54 @@ void TokenType::print(AsmPrinter &printer) const {
 static constexpr llvm::StringRef kMutableMemory = "mutable";

 Type MemDescType::parse(AsmParser &parser) {
-  if (parser.parseLess())
+  if (failed(parser.parseLess()))
     return Type();

-  SmallVector<int64_t> dimensions;
-  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false))
+  SmallVector<int64_t> dimensions; // required
+  if (failed(parser.parseDimensionList(dimensions, /*allowDynamic=*/false)))
     return Type();

-  // Parse the element type.
-  Type elementType;
-  if (parser.parseType(elementType))
+  Type elementType; // required
+  if (failed(parser.parseType(elementType)))
     return Type();

-  Attribute encoding;
-  if (succeeded(parser.parseOptionalComma())) {
-    if (parser.parseAttribute(encoding))
-      return Type();
-  }
-  bool mutableMemory = false;
-  Attribute memorySpace;
+  Attribute encoding; // required
+  if (failed(parser.parseComma()) || failed(parser.parseAttribute(encoding)))
+    return Type();
+
+  Attribute memorySpace; // required
+  if (failed(parser.parseComma()) || failed(parser.parseAttribute(memorySpace)))
+    return Type();
+
+  bool mutableMemory = false;      // optional
+  SmallVector<int64_t> allocShape; // optional
   if (succeeded(parser.parseOptionalComma())) {
-    if (failed(parser.parseOptionalKeyword(kMutableMemory))) {
-      if (parser.parseAttribute(memorySpace))
-        return Type();
-    } else {
+    if (succeeded(parser.parseOptionalKeyword(kMutableMemory))) {
       mutableMemory = true;
-    }
-  }
-  if (mutableMemory == false && succeeded(parser.parseOptionalComma())) {
-    if (parser.parseOptionalKeyword(kMutableMemory))
+      if (succeeded(parser.parseOptionalComma())) {
+        if (failed(parser.parseDimensionList(allocShape, /*allowDynamic=*/false,
+                                             /*withTrailingX=*/false))) {
+          return Type();
+        }
+      }
+    } else if (failed(parser.parseDimensionList(allocShape,
+                                                /*allowDynamic=*/false,
+                                                /*withTrailingX=*/false))) {
       return Type();
-    mutableMemory = true;
+    }
   }
+
   if (parser.parseGreater())
     return Type();
+
   return MemDescType::get(parser.getContext(), dimensions, elementType,
-                          encoding, memorySpace, mutableMemory);
+                          encoding, memorySpace, mutableMemory, allocShape.empty() ? dimensions : allocShape);
 }

 void MemDescType::print(AsmPrinter &printer) const {
   printer << "<";
-  for (auto dim : getShape())
+  auto shape = getShape();
+  for (auto dim : shape)
     printer << dim << "x";
   printer << getElementType();
   if (getEncoding())
@@ -79,9 +86,26 @@ void MemDescType::print(AsmPrinter &printer) const {
     printer << ", " << getMemorySpace();
   if (getMutableMemory())
     printer << ", " << kMutableMemory;
+  auto allocShape = getAllocShape();
+  if (allocShape != shape) {
+    printer << ", " << allocShape[0];
+    for (auto dim : allocShape.drop_front(1)) {
+      printer << "x" << dim;
+    }
+  }
   printer << ">";
 }

+LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
+                                  ArrayRef<int64_t> shape, Type elementType,
+                                  Attribute encoding, Attribute memorySpace,
+                                  bool mutableMemory,
+                                  ArrayRef<int64_t> allocShape) {
+  if (allocShape.size() < shape.size())
+    return emitError() << "alloc shape must have at least as many dimensions as shape";
+  return success();
+}
+
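For reference, the textual forms that the revised `parse`/`print` pair round-trips look like this (a minimal sketch in `.mlir` syntax; the `#shared0` layout and the concrete shapes are illustrative placeholders, not taken from this patch):

  #shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
  #smem = #ttg.shared_memory
  // Encoding and memory space are now required; allocShape defaults to the
  // view shape and is elided when the two are equal.
  !ttg.memdesc<16x32xf32, #shared0, #smem>
  // `mutable` and a distinct allocation shape are optional and print last:
  // a 16x32 view carved (e.g. by ttg.memdesc_subview) out of a 2x16x32 allocation.
  !ttg.memdesc<16x32xf32, #shared0, #smem, mutable, 2x16x32>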
//===----------------------------------------------------------------------===// // Triton Dialect //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index b370704be6bc..9aa6a8f8d3fa 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -144,7 +144,8 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc, triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); ttg::MemDescType subviewTy = ttg::MemDescType::get( allocTy.getShape().drop_front(), allocTy.getElementType(), - allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true, + /*allocShape=*/allocTy.getAllocShape()); auto view = builder.createWithStage( loc, stage, clusterId, subviewTy, alloc, copyOffsets); Operation *copy = builder.createWithStage( @@ -232,7 +233,8 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, copyOffsets[0] = insertIdx; ttg::MemDescType subviewTy = ttg::MemDescType::get( allocTy.getShape().drop_front(), allocTy.getElementType(), - allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true, + /*allocShape=*/allocTy.getAllocShape()); auto view = builder.createWithStage( loc, stage, clusterId, subviewTy, alloc, copyOffsets); @@ -526,7 +528,7 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, bufferShape.insert(bufferShape.begin(), distance); Type memdescType = ttg::MemDescType::get(bufferShape, ty.getElementType(), sharedEnc, sharedMemorySpace, - /*mutableMemory*/ true); + /*mutableMemory=*/true); Value alloc = builder.create(loadOp->getLoc(), memdescType, Value()); return alloc; @@ -544,12 +546,13 @@ static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) { /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); auto barrierEncoding = ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout); - Type barrierMemDescType = ttg::MemDescType::get( + auto barrierMemDescType = ttg::MemDescType::get( {distance}, builder.getI64Type(), barrierEncoding, sharedMemorySpace, /*mutableMemory=*/true); - Type singleBarrierMemDescType = - ttg::MemDescType::get({1}, builder.getI64Type(), barrierEncoding, - sharedMemorySpace, /*mutableMemory=*/true); + Type singleBarrierMemDescType = ttg::MemDescType::get( + {1}, builder.getI64Type(), barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true, + /*allocShape=*/barrierMemDescType.getAllocShape()); Value barrierAlloc = builder.create(loc, barrierMemDescType, Value()); for (unsigned i = 0; i < distance; i++) { @@ -650,11 +653,11 @@ static void createTMABarrierAndWait( OpBuilderWithStage builder(forOp); Attribute sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(builder.getContext()); + auto allocTy = cast(barrierAlloc.getType()); ttg::MemDescType barrierTy = ttg::MemDescType::get( - {1}, builder.getI64Type(), - cast(barrierAlloc.getType()).getEncoding(), - sharedMemorySpace, - /*mutableMemory=*/true); + {1}, builder.getI64Type(), allocTy.getEncoding(), sharedMemorySpace, + /*mutableMemory=*/true, + /*allocShape=*/allocTy.getAllocShape()); builder.setInsertionPoint(group[0]->loadOp); Value barrier = builder.createWithStage( loc, stage, cluster, barrierTy, barrierAlloc, @@ -835,14 
+838,14 @@ static void invalidateBarriers(OpBuilder &builder, Attribute sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(builder.getContext()); for (Value barrier : barriers) { - int numBarriers = cast(barrier.getType()).getShape()[0]; + auto allocTy = cast(barrier.getType()); + int numBarriers = allocTy.getShape()[0]; for (int i = 0; i < numBarriers; i++) { Value idx = builder.create(barrier.getLoc(), i, 32); ttg::MemDescType barrierTy = ttg::MemDescType::get( - {1}, builder.getI64Type(), - cast(barrier.getType()).getEncoding(), - sharedMemorySpace, - /*mutableMemory=*/true); + {1}, builder.getI64Type(), allocTy.getEncoding(), sharedMemorySpace, + /*mutableMemory=*/true, + /*allocShape=*/allocTy.getShape()); Value barrierView = builder.create( barrier.getLoc(), barrierTy, barrier, idx); builder.create(barrier.getLoc(), barrierView); diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 9ad71270c2b0..c11f2f8e5ee7 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -136,8 +136,9 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, builder.create(v.getLoc(), off, 32)); Value newSmem = builder.create( v.getLoc(), - triton::gpu::MemDescType::get(shape, elementType, type.getEncoding(), - type.getMemorySpace()), + triton::gpu::MemDescType::get( + shape, elementType, type.getEncoding(), type.getMemorySpace(), + type.getMutableMemory(), type.getAllocShape()), v, offsetsVal); auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 91be7d217d84..2daa8aaf07d6 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5328,20 +5328,22 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t layouts = f""" #src = {src_layout} #dst = {dst_layout} + #smem = #ttg.shared_memory """ if interm_layout is None else f""" #src = {src_layout} #interm = {interm_layout} #dst = {dst_layout} + #smem = #ttg.shared_memory """ conversion = f""" %12 = ttg.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> %13 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> """ if interm_layout is None else f""" - %15 = ttg.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !ttg.memdesc<{M}x{N}xi32, #interm, #ttg.shared_memory> - %16 = ttg.local_load %15 : !ttg.memdesc<{M}x{N}xi32, #interm, #ttg.shared_memory> -> tensor<{M}x{N}xi32, #src> - %17 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !ttg.memdesc<{M}x{N}xf16, #interm, #ttg.shared_memory> - %18 = ttg.local_load %17 : !ttg.memdesc<{M}x{N}xf16, #interm, #ttg.shared_memory> -> tensor<{M}x{N}xf16, #src> + %15 = ttg.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !ttg.memdesc<{M}x{N}xi32, #interm, #smem> + %16 = ttg.local_load %15 : !ttg.memdesc<{M}x{N}xi32, #interm, #smem> -> tensor<{M}x{N}xi32, #src> + %17 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !ttg.memdesc<{M}x{N}xf16, #interm, #smem> + %18 = ttg.local_load %17 : !ttg.memdesc<{M}x{N}xf16, #interm, #smem> -> tensor<{M}x{N}xf16, #src> %12 = ttg.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> %13 = ttg.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> @@ -5405,6 +5407,7 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: layouts = f""" #dist = 
{dist_layout} #shared = {shared_layout} + #smem = #ttg.shared_memory """ ir = layouts + f""" module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ @@ -5433,8 +5436,8 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: %17 = tt.broadcast %15 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist> %18 = tt.addptr %16, %17 : tensor<{M}x{N}x{K}x!tt.ptr, #dist>, tensor<{M}x{N}x{K}xi32, #dist> %19 = tt.load %18 : tensor<{M}x{N}x{K}x!tt.ptr, #dist> - %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory> - %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory> -> tensor<{M}x{N}x{K}xi32, #dist> + %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #smem> + %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #smem> -> tensor<{M}x{N}x{K}xi32, #dist> %22 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> %23 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> %24 = tt.expand_dims %23 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist> diff --git a/test/Conversion/amd/compute-base-ptr.mlir b/test/Conversion/amd/compute-base-ptr.mlir index 84b0ffce2eec..4c74e95d8ad0 100644 --- a/test/Conversion/amd/compute-base-ptr.mlir +++ b/test/Conversion/amd/compute-base-ptr.mlir @@ -3,14 +3,15 @@ #blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}> #shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @local_load_offset tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) { %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1) - %1 = ttg.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> loc(#loc2) + %1 = ttg.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> loc(#loc2) // This catches base ptr calculation in the computeBasePtr, checks if the gep has correct element type. 
// CHECK: llvm.getelementptr {{.*}} (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 local_load:3:0 - %2 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3) + %2 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3) tt.return } } diff --git a/test/Conversion/amd/decompose-unsupported-conversions.mlir b/test/Conversion/amd/decompose-unsupported-conversions.mlir index c0d4ea1edbda..983d16e8d6bf 100644 --- a/test/Conversion/amd/decompose-unsupported-conversions.mlir +++ b/test/Conversion/amd/decompose-unsupported-conversions.mlir @@ -5,10 +5,11 @@ // CHECK: #[[$SHARED:.+]] = #ttg.shared<{{.*}}> // CHECK-LABEL: wmma_to_wmma_dot_op #mma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1130", "ttg.threads-per-warp" = 32 : i32} { tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) { // CHECK: %[[SRC_BLOCKED:.+]] = ttg.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]> - // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<16x16xf16, #[[$SHARED]], #ttg.shared_memory> + // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<16x16xf16, #[[$SHARED]], #smem> // CHECK-NEXT: %[[DST_DOT_OP:.+]] = ttg.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> tt.return @@ -22,10 +23,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: #[[$SHARED:.+]] = #ttg.shared<{{.*}}> // CHECK-LABEL: wmma_to_wmma_dot3d_op #mma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) { // CHECK: %[[SRC_BLOCKED:.+]] = ttg.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]> - // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<2x16x16xf16, #[[$SHARED]], #ttg.shared_memory> + // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<2x16x16xf16, #[[$SHARED]], #smem> // CHECK-NEXT: %[[DST_DOT_OP:.+]] = ttg.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>> %0 = ttg.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> tt.return diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index bd0e86bc1aaf..8f4fbee399b4 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -40,23 +40,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> #dotop1 = #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}> #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = 
#ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: small_mfma_tensor_conversions tt.func public @small_mfma_tensor_conversions(%arg0: tensor<16x16xf16, #mfma>, %arg1: tensor<16x16x!tt.ptr, #mfma>) { // CHECK-NOT: ttg.convert_layout - %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #mfma>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #mfma>) -> !ttg.memdesc<16x16xf16, #shared, #smem> // CHECK-4: store {{.*}} vector<4xf16> - %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #dotop0> + %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #dotop0> // CHECK-2: load {{.*}} vector<4xf16> - %2 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #dotop1> + %2 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #dotop1> // CHECK-8: load {{.*}} vector<1xf16> - %3 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #mfma> + %3 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #mfma> // CHECK-4: load {{.*}} vector<4xf16> %4 = tt.fp_to_fp %3 : tensor<16x16xf16, #mfma> -> tensor<16x16xf32, #mfma> %5 = tt.dot %1, %2, %4 : tensor<16x16xf16, #dotop0> * tensor<16x16xf16, #dotop1> -> tensor<16x16xf32, #mfma> // Store result to prevent DCE from removing all conversion related code - %6 = ttg.local_alloc %5 : (tensor<16x16xf32, #mfma>) -> !ttg.memdesc<16x16xf32, #shared, #ttg.shared_memory> + %6 = ttg.local_alloc %5 : (tensor<16x16xf32, #mfma>) -> !ttg.memdesc<16x16xf32, #shared, #smem> tt.return } } diff --git a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir index ecb1c2c32dc4..68eb76afdb72 100644 --- a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir @@ -4,24 +4,25 @@ #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> #mma1 = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}> #mma2 = #ttg.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: wmma1_dot_operand - tt.func @wmma1_dot_operand(%arg0: !ttg.memdesc<64x64xf16, #shared>) { + tt.func @wmma1_dot_operand(%arg0: !ttg.memdesc<64x64xf16, #shared, #smem>) { // 2 CTA * 4 rep * load_per_thread_per_instr // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> - %0 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> + %0 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> // CHECK-COUNT-128: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> - %1 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> + %1 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> tt.return } // CHECK-LABEL: wmma2_dot_operand - tt.func @wmma2_dot_operand(%arg0: !ttg.memdesc<64x64xf16, #shared>) { + tt.func @wmma2_dot_operand(%arg0: 
!ttg.memdesc<64x64xf16, #shared, #smem>) { // 2 CTA * 4 rep * load_per_thread_per_instr // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16> - %0 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> + %0 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 8}>> // CHECK-COUNT-64: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> - %1 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> + %1 = ttg.local_load %arg0 : !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 8}>> tt.return } @@ -168,13 +169,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0], hasLeadingOffset = false}> #mma1 = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 1, 4]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: wmma_dot_operand3d - tt.func @wmma_dot_operand3d(%arg0: !ttg.memdesc<4x16x32xf16, #shared>) { + tt.func @wmma_dot_operand3d(%arg0: !ttg.memdesc<4x16x32xf16, #shared, #smem>) { // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16> - %0 = ttg.local_load %arg0 : !ttg.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> + %0 = ttg.local_load %arg0 : !ttg.memdesc<4x16x32xf16, #shared, #smem> -> tensor<4x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> // CHECK-COUNT-32: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xf16> - %1 = ttg.local_load %arg0 : !ttg.memdesc<4x16x32xf16, #shared> -> tensor<4x16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> + %1 = ttg.local_load %arg0 : !ttg.memdesc<4x16x32xf16, #shared, #smem> -> tensor<4x16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> tt.return } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 17c17a0bee14..32dff6e49f83 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -440,6 +440,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- #shared0 = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: basic_alloc_tensor @@ -447,7 +448,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: llvm.mlir.addressof @global_smem // CHECK-NEXT: llvm.getelementptr // CHECK-NEXT: llvm.mlir.constant - %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory, mutable> + %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #shared0, #smem, mutable> tt.return } } @@ -455,6 +456,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- #shared0 = #ttg.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: 
llvm.mlir.global external @global_smem // CHECK-LABEL: basic_subview @@ -477,8 +479,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: llvm.getelementptr %index = arith.constant 1 : i32 %zero = arith.constant 0 : i32 - %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16x32xf32, #shared0, #ttg.shared_memory, mutable> - %1 = ttg.memdesc_subview %0[%index, %zero, %zero] : !ttg.memdesc<128x16x32xf32, #shared0, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x32xf32, #shared0, #ttg.shared_memory, mutable> + %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16x32xf32, #shared0, #smem, mutable> + %1 = ttg.memdesc_subview %0[%index, %zero, %zero] : !ttg.memdesc<128x16x32xf32, #shared0, #smem, mutable> -> !ttg.memdesc<16x32xf32, #shared0, #smem, mutable> tt.return } } @@ -500,6 +502,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #slice1d0 = #ttg.slice<{dim = 0, parent = #blocked1}> #shared1D = #ttg.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0], hasLeadingOffset = true}> #shared2D = #ttg.shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: basic_insert_slice_async_1d tt.func @basic_insert_slice_async_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { @@ -509,10 +512,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { %24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1d0> %59 = tt.addptr %58, %24 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> %66 = tt.addptr %59, %cst_2 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> - %71 = ttg.local_alloc : () -> !ttg.memdesc<2x64xi64, #shared2D, #ttg.shared_memory, mutable> + %71 = ttg.local_alloc : () -> !ttg.memdesc<2x64xi64, #shared2D, #smem, mutable> %subview = ttg.memdesc_subview %71[%c0_i32, %c0_i32] : - !ttg.memdesc<2x64xi64, #shared2D, #ttg.shared_memory, mutable> -> - !ttg.memdesc<64xi64, #shared1D, #ttg.shared_memory, mutable> + !ttg.memdesc<2x64xi64, #shared2D, #smem, mutable> -> + !ttg.memdesc<64xi64, #shared1D, #smem, mutable> // CHECK: llvm.inline_asm has_side_effects asm_dialect = att // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 @@ -523,7 +526,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.commit_group - %73 = ttg.async_copy_global_to_local %66, %subview : tensor<64x!tt.ptr, #slice1d0> -> !ttg.memdesc<64xi64, #shared1D, #ttg.shared_memory, mutable> + %73 = ttg.async_copy_global_to_local %66, %subview : tensor<64x!tt.ptr, #slice1d0> -> !ttg.memdesc<64xi64, #shared1D, #smem, mutable> ttg.async_commit_group %73 tt.return } @@ -539,6 +542,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { #slice3d0 = #ttg.slice<{dim = 0, parent=#block3}> #AL = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #A = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 
1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v4 tt.func @basic_insert_slice_async_v4(%arg0: !tt.ptr {tt.divisibility = 32 : i32}) { @@ -556,14 +560,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x64x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr, #AL>, tensor<16x64xi32, #AL> - %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x64xf32, #A, #ttg.shared_memory, mutable> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x64xf32, #A, #smem, mutable> %index = arith.constant 1 : i32 // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10;" // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10;" // CHECK: llvm.inline_asm has_side_effects asm_dialect = att // CHECK-SAME: cp.async.commit_group - %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr, #AL> -> !ttg.memdesc<16x64xf32, #A, #ttg.shared_memory, mutable> + %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr, #AL> -> !ttg.memdesc<16x64xf32, #A, #smem, mutable> ttg.async_commit_group tt.return } @@ -579,6 +583,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #slice3d0 = #ttg.slice<{dim = 0, parent=#block3}> #AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #A = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v1 tt.func @basic_insert_slice_async_v1(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { @@ -596,7 +601,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x32xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x32x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr, #AL>, tensor<16x32xi32, #AL> - %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x32xf32, #A, #ttg.shared_memory, mutable> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<16x32xf32, #A, #smem, mutable> %index = arith.constant 1 : i32 // CHECK: llvm.inline_asm @@ -609,7 +614,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.commit_group - %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr, #AL> -> !ttg.memdesc<16x32xf32, #A, #ttg.shared_memory, mutable> + %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr, #AL> -> !ttg.memdesc<16x32xf32, #A, #smem, mutable> ttg.async_commit_group tt.return } @@ -624,6 +629,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #slice3d0 = #ttg.slice<{dim = 0, parent=#block3}> #AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], 
order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #A = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v1_multictas tt.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { @@ -641,7 +647,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<32x32xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> - %tensor = ttg.local_alloc : () -> !ttg.memdesc<32x32xf32, #A, #ttg.shared_memory, mutable> + %tensor = ttg.local_alloc : () -> !ttg.memdesc<32x32xf32, #A, #smem, mutable> %index = arith.constant 1 : i32 // CHECK: llvm.mlir.constant(0 : i32) : i32 @@ -665,7 +671,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.commit_group - %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr, #AL> -> !ttg.memdesc<32x32xf32, #A, #ttg.shared_memory, mutable> + %a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr, #AL> -> !ttg.memdesc<32x32xf32, #A, #smem, mutable> ttg.async_commit_group tt.return } @@ -770,17 +776,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { #mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_dot tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { - %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory> - %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory> + %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> + %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> // CHECK: llvm.inline_asm // CHECK: ldmatrix.sync.aligned.m8n8.x4 // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 - %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory> -> tensor<16x16xf16, #dot_operand_a> - %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory> -> tensor<16x16xf16, #dot_operand_b> + %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> + %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> // CHECK: llvm.inline_asm @@ -809,15 +816,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { #mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape 
= [16, 8]}> #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=4}> #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=4}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { // CHECK-LABEL: convert_dot_fp8 tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) { - %AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #ttg.shared_memory> - %BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #ttg.shared_memory> + %AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> + %BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 - %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #ttg.shared_memory> -> tensor<16x16xf8E5M2, #dot_operand_a> - %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #ttg.shared_memory> -> tensor<16x16xf8E5M2, #dot_operand_b> + %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a> + %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> // CHECK: llvm.inline_asm @@ -1046,6 +1054,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // ----- #blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #shared0 = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: convert_layout_blocked_shared @@ -1054,7 +1063,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-SAME: !llvm.ptr<3> // CHECK: llvm.store // CHECK-SAME: !llvm.ptr<3> - %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #ttg.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> tt.return } } @@ -1109,13 +1118,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { #mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=2}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory>, %b:!ttg.memdesc<32x256xf16, #shared, #ttg.shared_memory>) { + %a:!ttg.memdesc<128x32xf16, #shared, #smem>, %b:!ttg.memdesc<32x256xf16, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 - %a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory> -> tensor<128x32xf16, #dot_operand_a> 
- %b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #ttg.shared_memory> -> tensor<32x256xf16, #dot_operand_b> + %a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a> + %b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b> %28 = tt.dot %a_mat, %b_mat, %cst : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma> %38 = ttg.convert_layout %28 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> @@ -1133,13 +1143,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#blocked}> #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#blocked}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_fmadot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!ttg.memdesc<32x16xf32, #shared, #ttg.shared_memory>, %b:!ttg.memdesc<16x32xf32, #shared, #ttg.shared_memory>) { + %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> // CHECK: llvm.intr.fmuladd - %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #ttg.shared_memory> -> tensor<32x16xf32, #dot_operand_a> - %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #ttg.shared_memory> -> tensor<16x32xf32, #dot_operand_b> + %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a> + %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b> %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = ieee : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> @@ -1156,10 +1167,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}> #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=1}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: matmul_tf32dot tt.func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!ttg.memdesc<32x16xf32, #shared, #ttg.shared_memory>, %b:!ttg.memdesc<16x32xf32, #shared, #ttg.shared_memory>) { + %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 @@ -1167,8 +1179,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 // CHECK-SAME: (i32, i32, i32, i32) - %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #ttg.shared_memory> -> tensor<32x16xf32, #dot_operand_a> - %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #ttg.shared_memory> -> tensor<16x32xf32, #dot_operand_b> + %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, 
#shared, #smem> -> tensor<32x16xf32, #dot_operand_a> + %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b> // CHECK: llvm.inline_asm // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 @@ -1387,12 +1399,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- #blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #shared0 = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: test_base_index_cache tt.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) { // CHECK: nvvm.read.ptx.sreg.tid.x - %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #ttg.shared_memory> - %1 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #ttg.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> + %1 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> tt.return } } @@ -1400,14 +1413,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // ----- #blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #shared0 = #ttg.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: test_index_cache_different_block tt.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) { // CHECK: nvvm.read.ptx.sreg.tid.x - %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #ttg.shared_memory> + %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> cf.cond_br %arg1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 - %1 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #ttg.shared_memory> + %1 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem> cf.br ^bb2 ^bb2: // 2 preds: ^bb0, ^bb1 tt.return @@ -1644,20 +1658,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}> #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=2}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @i16_mma_layout(%f16_inp: tensor<16x16xf16, #blocked0>, %i16_inp: tensor<16x16xi16, #blocked0>) { // CHECK-LABEL: @i16_mma_layout - %f16_shared = ttg.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory> - %i16_shared = ttg.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !ttg.memdesc<16x16xi16, #shared0, 
#ttg.shared_memory> + %f16_shared = ttg.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> + %i16_shared = ttg.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !ttg.memdesc<16x16xi16, #shared0, #smem> // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 - %f16_dot = ttg.local_load %f16_shared : !ttg.memdesc<16x16xf16, #shared0, #ttg.shared_memory> -> tensor<16x16xf16, #dot_operand_a> - %i16_dot = ttg.local_load %i16_shared : !ttg.memdesc<16x16xi16, #shared0, #ttg.shared_memory> -> tensor<16x16xi16, #dot_operand_b> + %f16_dot = ttg.local_load %f16_shared : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> + %i16_dot = ttg.local_load %i16_shared : !ttg.memdesc<16x16xi16, #shared0, #smem> -> tensor<16x16xi16, #dot_operand_b> // CHECK: llvm.sitofp %{{.*}} : i16 to f16 @@ -1715,13 +1730,14 @@ module attributes {"ttg.target" = "cuda:75", "ttg.num-ctas" = 1 : i32, "ttg.num- #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @vectorize_shmem_load // CHECK: llvm.load // CHECK-SAME: {alignment = 8 : i64} : !llvm.ptr<3> -> vector<8xi8> // CHECK-NOT: llvm.load - tt.func public @vectorize_shmem_load(%shmem : !ttg.memdesc<16x16xi8, #shared, #ttg.shared_memory>) { - %0 = ttg.local_load %shmem : !ttg.memdesc<16x16xi8, #shared, #ttg.shared_memory> -> tensor<16x16xi8, #blocked> + tt.func public @vectorize_shmem_load(%shmem : !ttg.memdesc<16x16xi8, #shared, #smem>) { + %0 = ttg.local_load %shmem : !ttg.memdesc<16x16xi8, #shared, #smem> -> tensor<16x16xi8, #blocked> tt.return } } @@ -1730,13 +1746,14 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> #blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @vectorize_shmem_store // CHECK: llvm.store // CHECK-SAME: {alignment = 64 : i64} : vector<16xi32>, !llvm.ptr<3> // CHECK-NOT: llvm.store tt.func public @vectorize_shmem_store(%block : tensor<64x64xi32, #blocked>) { - %0 = ttg.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !ttg.memdesc<64x64xi32, #shared, #ttg.shared_memory> + %0 = ttg.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !ttg.memdesc<64x64xi32, #shared, #smem> tt.return } } @@ -1756,14 +1773,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:80", 
"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: test_local_load_bf16 // CHECK: llvm.extractelement {{.*}} : vector<8xbf16> tt.func public @test_local_load_bf16() { %c0_i32 = arith.constant 0 : i32 - %19 = ttg.local_alloc : () -> !ttg.memdesc<1x1x2048xbf16, #shared, #ttg.shared_memory, mutable> - %22 = ttg.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x1x2048xbf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<1x2048xbf16, #shared, #ttg.shared_memory, mutable> - %39 = ttg.local_load %22 : !ttg.memdesc<1x2048xbf16, #shared, #ttg.shared_memory, mutable> -> tensor<1x2048xbf16, #blocked> + %19 = ttg.local_alloc : () -> !ttg.memdesc<1x1x2048xbf16, #shared, #smem, mutable> + %22 = ttg.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x1x2048xbf16, #shared, #smem, mutable> -> !ttg.memdesc<1x2048xbf16, #shared, #smem, mutable> + %39 = ttg.local_load %22 : !ttg.memdesc<1x2048xbf16, #shared, #smem, mutable> -> tensor<1x2048xbf16, #blocked> %40 = arith.extf %39 : tensor<1x2048xbf16, #blocked> to tensor<1x2048xf32, #blocked> tt.return } @@ -1772,13 +1790,14 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: test_local_store // CHECK: llvm.store tt.func public @test_local_store(%arg0: tensor<1xf32, #blocked>) { %c0_i32 = arith.constant 0 : i32 - %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> - ttg.local_store %arg0, %0 : tensor<1xf32, #blocked> -> !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> + %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1xf32, #shared, #smem, mutable> + ttg.local_store %arg0, %0 : tensor<1xf32, #blocked> -> !ttg.memdesc<1xf32, #shared, #smem, mutable> tt.return } } @@ -1786,14 +1805,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: test_local_store_subview // CHECK: llvm.store tt.func public @test_local_store_subview(%arg0: tensor<1xf32, #blocked>) { %c0_i32 = arith.constant 0 : i32 - %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> - %sv = ttg.memdesc_subview %0[%c0_i32] : !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> - ttg.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> + %0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1xf32, #shared, #smem, mutable> + %sv = ttg.memdesc_subview %0[%c0_i32] : !ttg.memdesc<1xf32, #shared, #smem, mutable> -> !ttg.memdesc<1xf32, #shared, #smem, mutable> + ttg.local_store %arg0, %sv : tensor<1xf32, 
#blocked> -> !ttg.memdesc<1xf32, #shared, #smem, mutable> tt.return } } diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 1b64ee70056c..56e078463d20 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -3,9 +3,10 @@ #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> #shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: @dot_high_precision_acc - tt.func @dot_high_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) { + tt.func @dot_high_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #smem>, %c: tensor<128x256xf32, #mma>) { // CHECK: nvgpu.wgmma // CHECK-COUNT-128: llvm.fadd // CHECK: nvgpu.wgmma @@ -16,7 +17,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-COUNT-128: llvm.fadd %m = ttng.warp_group_dot %a, %b, %c {maxNumImpreciseAcc = 32 : i32, inputPrecision = 0 : i32} : - !ttg.memdesc<128x128xf8E5M2, #shared> * !ttg.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + !ttg.memdesc<128x128xf8E5M2, #shared, #smem> * !ttg.memdesc<128x256xf8E5M2, #shared1, #smem> -> tensor<128x256xf32, #mma> tt.return } } @@ -26,9 +27,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> #shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: @dot_low_precision_acc - tt.func @dot_low_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) { + tt.func @dot_low_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #smem>, %c: tensor<128x256xf32, #mma>) { // CHECK: nvgpu.wgmma // CHECK-NOT: llvm.fadd // CHECK: nvgpu.wgmma @@ -40,7 +42,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK: llvm.return %m = ttng.warp_group_dot %a, %b, %c {maxNumImpreciseAcc = 129 : i32, inputPrecision = 0 : i32} : - !ttg.memdesc<128x128xf8E5M2, #shared> * !ttg.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + !ttg.memdesc<128x128xf8E5M2, #shared, #smem> * !ttg.memdesc<128x256xf8E5M2, #shared1, #smem> -> tensor<128x256xf32, #mma> tt.return } } @@ -50,9 +52,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { #mma = 
#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> #shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: @dot_mix_precision_acc - tt.func @dot_mix_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) { + tt.func @dot_mix_precision_acc(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #smem>, %c: tensor<128x256xf32, #mma>) { // CHECK: nvgpu.wgmma // CHECK-NOT: llvm.fadd // CHECK: nvgpu.wgmma @@ -64,7 +67,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // CHECK: llvm.return %m = ttng.warp_group_dot %a, %b, %c {maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} : - !ttg.memdesc<128x128xf8E5M2, #shared> * !ttg.memdesc<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + !ttg.memdesc<128x128xf8E5M2, #shared, #smem> * !ttg.memdesc<128x256xf8E5M2, #shared1, #smem> -> tensor<128x256xf32, #mma> tt.return } } @@ -74,14 +77,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_zero_acc // Generate a wgmma with 2 sources. // CHECK: nvgpu.wgmma %{{.*}}, %{{.*}} { - tt.func @dot_zero_acc(%a: !ttg.memdesc<128x64xf16, #shared>, %b: !ttg.memdesc<64x64xf16, #shared1>) { + tt.func @dot_zero_acc(%a: !ttg.memdesc<128x64xf16, #shared, #smem>, %b: !ttg.memdesc<64x64xf16, #shared1, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %m = ttng.warp_group_dot %a, %b, %cst {inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} : - !ttg.memdesc<128x64xf16, #shared> * !ttg.memdesc<64x64xf16, #shared1> -> tensor<128x64xf32, #mma> + !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma> tt.return } } @@ -90,16 +94,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_reg_operand_A // Generate a wgmma where the first operand is a struct. 
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> - tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !ttg.memdesc<64x64xf16, #shared>) { + tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !ttg.memdesc<64x64xf16, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> %opA = ttg.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %m = ttng.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }: - tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> tt.return } } @@ -109,15 +114,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> #shared = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_reg_operand_A_fp8 // Generate a wgmma where the first operand is a struct. 
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} - tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %b: !ttg.memdesc<128x256xf8E5M2, #shared>) { + tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %b: !ttg.memdesc<128x256xf8E5M2, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1> %m = ttng.warp_group_dot %a, %b, %cst { maxNumImpreciseAcc = 1073741824 : i32, inputPrecision = 0 : i32 } : - tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !ttg.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> + tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !ttg.memdesc<128x256xf8E5M2, #shared, #smem> -> tensor<128x256xf32, #mma1> tt.return } } @@ -127,12 +133,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: dot_reg_operand_upcast - tt.func @dot_reg_operand_upcast(%a_desc: !ttg.memdesc<128x64xi8, #shared>, %b: !ttg.memdesc<64x64xf16, #shared>, %acc: tensor<128x64xf32, #mma>) { - %a_dotop = ttg.local_load %a_desc : !ttg.memdesc<128x64xi8, #shared> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + tt.func @dot_reg_operand_upcast(%a_desc: !ttg.memdesc<128x64xi8, #shared, #smem>, %b: !ttg.memdesc<64x64xf16, #shared, #smem>, %acc: tensor<128x64xf32, #mma>) { + %a_dotop = ttg.local_load %a_desc : !ttg.memdesc<128x64xi8, #shared, #smem> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %a_casted = arith.sitofp %a_dotop : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %res = ttng.warp_group_dot %a_casted, %b, %acc : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + %res = ttng.warp_group_dot %a_casted, %b, %acc : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> tt.return } } @@ -218,13 +225,14 @@ module attributes {"ttg.num-ctas" = 1 : 
i32, "ttg.num-warps" = 4 : i32} { #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 32]}> #shared = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: dot_zero_acc_operand // CHECK-COUNT-128: llvm.fadd - tt.func @dot_zero_acc_operand(%a: !ttg.memdesc<128x128xf8E5M2, #shared>, %b: !ttg.memdesc<128x128xf8E5M2, #shared1>) { + tt.func @dot_zero_acc_operand(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %b: !ttg.memdesc<128x128xf8E5M2, #shared1, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> %m = ttng.warp_group_dot %a, %b, %cst {maxNumImpreciseAcc = 64 : i32, inputPrecision = 0 : i32} : - !ttg.memdesc<128x128xf8E5M2, #shared> * !ttg.memdesc<128x128xf8E5M2, #shared1> -> tensor<128x128xf32, #mma> + !ttg.memdesc<128x128xf8E5M2, #shared, #smem> * !ttg.memdesc<128x128xf8E5M2, #shared1, #smem> -> tensor<128x128xf32, #mma> tt.return } } @@ -234,12 +242,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#smem = #ttg.shared_memory // CHECK-LABEL: distribute_to_shared_st_matrix module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @distribute_to_shared_st_matrix(%a: tensor<128x128xf16, #mma>) { // CHECK-COUNT-16: nvgpu.stmatrix // CHECK: llvm.return - %b = ttg.local_alloc %a {allocation.offset = 0 : i32} : (tensor<128x128xf16, #mma>) -> !ttg.memdesc<128x128xf16, #shared, mutable> + %b = ttg.local_alloc %a {allocation.offset = 0 : i32} : (tensor<128x128xf16, #mma>) -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> tt.return } } diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index 52c30b28fc66..127c4951e383 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -1,11 +1,12 @@ // RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s #shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: init_barrier - tt.func @init_barrier(%alloc: !ttg.memdesc<1xi64, #shared0>) { + tt.func @init_barrier(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>) { // CHECK: "@$0 mbarrier.init.shared::cta.b64 [$1], 1;", "b,r" %{{.*}}, %{{.*}} : (i1, !llvm.ptr<3>) -> !llvm.void - ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem> tt.return } } @@ -13,13 +14,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- #shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: wait_barrier - tt.func @wait_barrier(%alloc: !ttg.memdesc<1xi64, 
#shared0>, %phase: i32) { + tt.func @wait_barrier(%alloc: !ttg.memdesc<1xi64, #shared0, #smem>, %phase: i32) { // CHECK: waitLoop: // CHECK: mbarrier.try_wait.parity.shared.b64 // CHECK: @!P1 bra.uni waitLoop - ttng.wait_barrier %alloc, %phase : !ttg.memdesc<1xi64, #shared0> + ttng.wait_barrier %alloc, %phase : !ttg.memdesc<1xi64, #shared0, #smem> tt.return } } @@ -29,14 +31,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: tma_copy_global_to_local // CHECK: elect.sync // CHECK: "@$0 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4}], [$5];", "b,r,l,r,r,r" {{.*}} : (i1, !llvm.ptr<3>, !llvm.ptr<1>, i32, i32, !llvm.ptr<3>) -> !llvm.void // CHECK-NOT: cp.async.bulk.tensor.2d.shared // CHECK: return - tt.func @tma_copy_global_to_local(%tma: !tt.ptr, %alloc: !ttg.memdesc<128x128xf32, #shared1, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0>, %pred: i1) { - ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.ptr, !ttg.memdesc<1xi64, #shared0> -> !ttg.memdesc<128x128xf32, #shared1, mutable> + tt.func @tma_copy_global_to_local(%tma: !tt.ptr, %alloc: !ttg.memdesc<128x128xf32, #shared1, #smem, mutable>, %x: i32, %barrier: !ttg.memdesc<1xi64, #shared0, #smem>, %pred: i1) { + ttng.async_tma_copy_global_to_local %tma[%x, %x] %alloc, %barrier, %pred : !tt.ptr, !ttg.memdesc<1xi64, #shared0, #smem> -> !ttg.memdesc<128x128xf32, #shared1, #smem, mutable> tt.return } } @@ -44,14 +47,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- #shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: tma_copy_local_to_global // CHECK: elect.sync // CHECK: "@$0 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [$1, {$2, $3}], [$4];", "b,l,r,r,r" {{.*}} : (i1, !llvm.ptr<1>, i32, i32, !llvm.ptr<3>) -> !llvm.void // CHECK-NOT: cp.async.bulk.tensor.2d.global.shared::cta.bulk_group // CHECK: cp.async.bulk.commit_group - tt.func @tma_copy_local_to_global(%tma: !tt.ptr, %alloc: !ttg.memdesc<128x128xf32, #shared1>, %x: i32) { - ttng.async_tma_copy_local_to_global %tma[%x, %x] %alloc : , <128x128xf32, #shared1> + tt.func @tma_copy_local_to_global(%tma: !tt.ptr, %alloc: !ttg.memdesc<128x128xf32, #shared1, #smem>, %x: i32) { + ttng.async_tma_copy_local_to_global %tma[%x, %x] %alloc : !tt.ptr, !ttg.memdesc<128x128xf32, #shared1, #smem> tt.return } } @@ -71,11 +75,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- #shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: expect_barrier // CHECK: @$0 mbarrier.arrive.expect_tx.shared.b64 _, [$1], 16384; - tt.func @expect_barrier(%barrier: !ttg.memdesc<1xi64, #shared0, mutable>, %pred: i1) { - ttng.barrier_expect %barrier, 16384, %pred 
: <1xi64, #shared0, #smem, mutable>
     tt.return
   }
 }
diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir
index ce660d4228a7..3e130c29031d 100644
--- a/test/Triton/invalid.mlir
+++ b/test/Triton/invalid.mlir
@@ -276,10 +276,11 @@ tt.func public @fn(%arg0: tensor<2x4x8x16xf32, #blocked>, %arg1: tensor<16x32x64
 #shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 2, 0, 3]}>
 #shared2 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 2], CTASplitNum = [2, 4], CTAOrder = [0, 1], hasLeadingOffset = true}>
 #shared3 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [4, 2], CTAOrder = [1, 0], hasLeadingOffset = true}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} {
-tt.func public @fn(%arg0: !ttg.memdesc<2x4x8x16xf32, #shared>, %arg1: !ttg.memdesc<16x32xf32, #shared2>) {
-  %a = ttg.memdesc_trans %arg0 {order = array<i32: 1, 3, 2, 0>} : !ttg.memdesc<2x4x8x16xf32, #shared> -> !ttg.memdesc<4x16x8x2xf32, #shared1>
-  %b = ttg.memdesc_trans %arg1 {order = array<i32: 1, 0>} : !ttg.memdesc<16x32xf32, #shared2> -> !ttg.memdesc<32x16xf32, #shared3>
+tt.func public @fn(%arg0: !ttg.memdesc<2x4x8x16xf32, #shared, #smem>, %arg1: !ttg.memdesc<16x32xf32, #shared2, #smem>) {
+  %a = ttg.memdesc_trans %arg0 {order = array<i32: 1, 3, 2, 0>} : !ttg.memdesc<2x4x8x16xf32, #shared, #smem> -> !ttg.memdesc<4x16x8x2xf32, #shared1, #smem>
+  %b = ttg.memdesc_trans %arg1 {order = array<i32: 1, 0>} : !ttg.memdesc<16x32xf32, #shared2, #smem> -> !ttg.memdesc<32x16xf32, #shared3, #smem>
   tt.return
 }
 } // end module
diff --git a/test/TritonGPU/accumulator-init.mlir b/test/TritonGPU/accumulator-init.mlir
index c5302913c47b..7ed7db0c1e3e 100644
--- a/test/TritonGPU/accumulator-init.mlir
+++ b/test/TritonGPU/accumulator-init.mlir
@@ -6,19 +6,20 @@
 #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}>
 #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
 #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: @constant_init
 // CHECK-DAG: %[[FALSE:.+]] = arith.constant false
 // CHECK: ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]]
-  tt.func @constant_init(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> {
+  tt.func @constant_init(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> {
     %c0_i32 = arith.constant 0 : i32
     %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
     %c1_i32 = arith.constant 1 : i32
     %c8_i32 = arith.constant 8 : i32
     %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) ->
(tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = ttng.warp_group_dot %A, %B, %cst_2 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %cst_2 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> scf.yield %acc: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> @@ -27,14 +28,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @constant_init_integer // CHECK-DAG: %[[FALSE:.+]] = arith.constant false // CHECK: ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] - tt.func @constant_init_integer(%A: !ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xi8, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xi32, #mma1> { + tt.func @constant_init_integer(%A: !ttg.memdesc<128x64xi8, #shared, #smem>, %B: !ttg.memdesc<64x16xi8, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xi32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0> : tensor<128x16xi32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xi32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = ttng.warp_group_dot %A, %B, %cst_2 : !ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xi8, #shared1, #ttg.shared_memory> -> tensor<128x16xi32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %cst_2 : !ttg.memdesc<128x64xi8, #shared, #smem> * !ttg.memdesc<64x16xi8, #shared1, #smem> -> tensor<128x16xi32, #mma1> scf.yield %acc: tensor<128x16xi32, #mma1> } tt.return %17 : tensor<128x16xi32, #mma1> @@ -53,14 +54,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: else // CHECK: scf.yield %[[ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] - tt.func @if_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, 
%B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } else { @@ -84,14 +85,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: else // CHECK: scf.yield %[[ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] - tt.func @if_after_mma_invert(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_after_mma_invert(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %acc : tensor<128x16xf32, #mma1> } else { @@ -115,7 +116,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: scf.yield %[[ACC]] // CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[TRUE]] - tt.func @if_before_mma(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_before_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 @@ -127,7 +128,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { } else { scf.yield %arg4 : tensor<128x16xf32, #mma1> } - %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> scf.yield %acc: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> @@ -146,7 +147,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, 
"ttg.num-warps" = 4 : i32} { // CHECK: scf.yield %[[ACC]] // CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[TRUE]] - tt.func @if_before_mma_invert(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_before_mma_invert(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 @@ -158,7 +159,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { } else { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } - %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> scf.yield %acc: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> @@ -173,14 +174,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] // CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] // CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] - tt.func @sel_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @sel_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> scf.yield %acc_: tensor<128x16xf32, #mma1> } @@ -196,7 +197,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] // CHECK: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, 
{{.*}}, %[[ACC]], %[[USE_ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[TRUE]] - tt.func @sel_before_mma(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @sel_before_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 @@ -204,7 +205,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xf32, #mma1> - %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> scf.yield %acc: tensor<128x16xf32, #mma1> } tt.return %17 : tensor<128x16xf32, #mma1> @@ -230,7 +231,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: else // CHECK: scf.yield %[[ACC_NEXT]] // CHECK: scf.yield {{.*}}, %[[TRUE]] - tt.func @if_before_and_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_before_and_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 @@ -242,7 +243,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { } else { scf.yield %arg4 : tensor<128x16xf32, #mma1> } - %acc = ttng.warp_group_dot %A, %B, %acc_0 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %acc_0 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_1 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } else { @@ -270,14 +271,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: else // CHECK: scf.yield %[[ACC_CND]] // CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] - tt.func @two_ifs_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: 
!tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @two_ifs_after_mma(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_0 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } else { @@ -297,14 +298,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @non_zero_init // CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc - tt.func @non_zero_init(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @non_zero_init(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> scf.yield %acc_: tensor<128x16xf32, #mma1> } @@ -313,14 +314,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @zero_init_dist_2 // CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc - tt.func @zero_init_dist_2(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @zero_init_dist_2(%A: 
!ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %cst_2) -> (tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = ttng.warp_group_dot %A, %B, %arg5 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg5 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> scf.yield %acc_, %arg4: tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } @@ -329,7 +330,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @if_defines_alternative // CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc - tt.func @if_defines_alternative(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func @if_defines_alternative(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { %c0_i32 = arith.constant 0 : i32 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> @@ -337,7 +338,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %c8_i32 = arith.constant 8 : i32 %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { %cnd = arith.cmpi slt, %arg3, %ext : i32 - %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { scf.yield %cst_2 : tensor<128x16xf32, #mma1> } else { @@ -351,14 +352,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @non_cond_override // CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc - tt.func @non_cond_override(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + tt.func 
@non_cond_override(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> {
     %c0_i32 = arith.constant 0 : i32
     %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
     %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1>
     %c1_i32 = arith.constant 1 : i32
     %c8_i32 = arith.constant 8 : i32
     %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 {
-      %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1>
+      %acc = ttng.warp_group_dot %A, %B, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1>
       %acc_ = arith.addf %acc, %cst_3 : tensor<128x16xf32, #mma1>
       scf.yield %acc_: tensor<128x16xf32, #mma1>
     }
@@ -368,14 +369,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // If the condition is a tensor, skip the optimization.
 // CHECK-LABEL: @negative_sel_tensor
 // CHECK-NOT: %[[ACC_NEXT:.+]] = ttng.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !ttg.memdesc
-  tt.func @negative_sel_tensor(%A: !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory>, %B: !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory>, %cnd: tensor<128x16xi1, #mma1>) -> tensor<128x16xf32, #mma1> {
+  tt.func @negative_sel_tensor(%A: !ttg.memdesc<128x64xf16, #shared, #smem>, %B: !ttg.memdesc<64x16xf16, #shared1, #smem>, %cnd: tensor<128x16xi1, #mma1>) -> tensor<128x16xf32, #mma1> {
     %c0_i32 = arith.constant 0 : i32
     %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
     %c1_i32 = arith.constant 1 : i32
     %c8_i32 = arith.constant 8 : i32
     %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 {
       %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1>
-      %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1>
+      %acc = ttng.warp_group_dot %A, %B, %acc_ : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1>
       scf.yield %acc: tensor<128x16xf32, #mma1>
     }
     tt.return %17 : tensor<128x16xf32, #mma1>
diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir
index 8fa6d6fe1215..3e9863c675bf 100644
--- a/test/TritonGPU/amd/amd-reorder-instructions.mlir
+++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir
@@ -14,6 +14,7 @@
 #blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
 #mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
     %c0_i32 = arith.constant 0 : i32
@@ -34,10 +35,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
       %73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
       %74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
-      %75 = ttg.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory>
-      %76 = ttg.local_load %75 : !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
-      %77 = ttg.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #ttg.shared_memory>
-      %78 = ttg.local_load %77 : !ttg.memdesc<128x128xf16, #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #ttg.shared_memory> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
+      %75 = ttg.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !ttg.memdesc<256x128xf16, #shared, #smem>
+      %76 = ttg.local_load %75 : !ttg.memdesc<256x128xf16, #shared, #smem> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      %77 = ttg.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
+      %78 = ttg.local_load %77 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
       %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
       %107 = arith.addi %arg26, %c128_i64 : i64
       scf.yield %107 : i64
@@ -58,6 +59,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 #blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
 #mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx90a", "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
     %c0_i32 = arith.constant 0 : i32
@@ -78,10 +80,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
       %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma>
       %73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
       %74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
-      %75 = ttg.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory>
-      %76 = ttg.local_load %75 : !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
-      %77 = ttg.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #ttg.shared_memory>
-      %78 =
ttg.local_load %77 : !ttg.memdesc<128x128xf16, #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #ttg.shared_memory> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
+    %75 = ttg.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !ttg.memdesc<256x128xf16, #shared, #smem>
+    %76 = ttg.local_load %75 : !ttg.memdesc<256x128xf16, #shared, #smem> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+    %77 = ttg.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
+    %78 = ttg.local_load %77 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
     %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
     %107 = arith.addi %arg26, %c128_i64 : i64
     scf.yield %107 : i64
@@ -94,7 +96,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
 #shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
-
+#smem = #ttg.shared_memory
 // CHECK-LABEL: order_load_alloc_local_load_local_store
 // CHECK: %[[LOAD:.+]] = tt.load
 // CHECK: %[[ALLOC:.+]] = ttg.local_alloc
@@ -104,10 +106,10 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
   tt.func public @order_load_alloc_local_load_local_store(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked>) attributes {noinline = false} {
     %9 = tt.load %arg0 : tensor<32x32x!tt.ptr<f32>, #blocked>
     %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
-    %10 = ttg.local_alloc : () -> !ttg.memdesc<32x32xf32, #shared, mutable>
-    ttg.local_store %9, %10 : tensor<32x32xf32, #blocked> -> !ttg.memdesc<32x32xf32, #shared, mutable>
+    %10 = ttg.local_alloc : () -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
+    ttg.local_store %9, %10 : tensor<32x32xf32, #blocked> -> !ttg.memdesc<32x32xf32, #shared, #smem, mutable>
     %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %11 = ttg.local_load %10 : !ttg.memdesc<32x32xf32, #shared, mutable> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %11 = ttg.local_load %10 : !ttg.memdesc<32x32xf32, #shared, #smem, mutable> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
     %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
     %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
     tt.store %arg0, %13 : tensor<32x32x!tt.ptr<f32>, #blocked>
@@ -175,6 +177,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 #shared2 = #ttg.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}>
 #shared3 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
 #shared4 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32, ttg.target = "hip:gfx942"} {
 // CHECK-LABEL: tt.func @matmul_loop
@@ -222,22 +225,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked>
     %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked>
     %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
-    %10 = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #shared, #ttg.shared_memory, mutable>
-    %11 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared1, #ttg.shared_memory, mutable>
+    %10 = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable>
+    %11 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable>
     %12 = arith.cmpi slt, %arg0, %arg1 : index
     %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1>
     %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr<f16>, #blocked1>
     %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked>
     %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr<f16>, #blocked>
-    %17 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>
-    ttg.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>
-    %18 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>
-    ttg.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>
-    %19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>) {
+    %17 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>
+    ttg.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>
+    %18 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
+    ttg.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
+    %19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>) {
      %20 = arith.subi %arg1, %arg2 : index
      %21 = arith.cmpi slt, %arg5, %20 : index
-      %22 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
-      %23 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %22 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %23 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %24 = arith.mulf %23, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %25 = tt.dot %22, %24, %arg8 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
      %26 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<128x32xi32, #blocked1>
@@ -249,14 +252,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
      %32 = arith.addi %arg9, %c1_i32 : i32
      %33 = arith.cmpi slt, %32, %c1_i32 : i32
      %34 = arith.select %33, %32, %c0_i32 : i32
-      %35 = ttg.memdesc_subview %10[%34, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>
-      ttg.local_store %29, %35 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>
-      %36 = ttg.memdesc_subview %11[%34, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>
-      ttg.local_store %31, %36 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>
-      scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>
+      %35 = ttg.memdesc_subview %10[%34, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>
+      ttg.local_store %29, %35 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>
+      %36 = ttg.memdesc_subview %11[%34, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
+      ttg.local_store %31, %36 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
+      scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
     }
-    ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #ttg.shared_memory, mutable>
-    ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #ttg.shared_memory, mutable>
+    ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable>
+    ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable>
     tt.return %19#2 : tensor<128x128xf32, #mma>
   }
@@ -313,8 +316,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked>
     %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked>
     %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
-    %10 = ttg.local_alloc : () -> !ttg.memdesc<2x128x32xf16, #shared, #ttg.shared_memory, mutable>
-    %11 = ttg.local_alloc : () -> !ttg.memdesc<2x32x128xf16, #shared1, #ttg.shared_memory, mutable>
+    %10 = ttg.local_alloc : () -> !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable>
+    %11 = ttg.local_alloc : () -> !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable>
     %12 = arith.cmpi slt, %arg0, %arg1 : index
     %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1>
     %14 = tt.load %4, %13 : tensor<128x32x!tt.ptr<f16>, #blocked1>
@@ -328,16 +331,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     %22 = tt.load %19, %21 : tensor<128x32x!tt.ptr<f16>, #blocked1>
     %23 = tt.splat %18 : i1 -> tensor<32x128xi1, #blocked>
     %24 = tt.load %20, %23, %cst_3 : tensor<32x128x!tt.ptr<f16>, #blocked>
-    %25 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x32xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>
-    ttg.local_store %14, %25 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>
-    %26 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>
-    ttg.local_store %16, %26 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>
-    %27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) {
+    %25 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>
+    ttg.local_store %14, %25 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>
+    %26 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
+    ttg.local_store %16, %26 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
+    %27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) {
      %28 = arith.muli %arg2, %c2 : index
      %29 = arith.subi %arg1, %28 : index
      %30 = arith.cmpi slt, %arg5, %29 : index
-      %31 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
-      %32 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %31 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %32 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %33 = arith.mulf %32, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %34 = tt.dot %31, %33, %arg8 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
      %35 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<128x32xi32, #blocked1>
@@ -349,14 +352,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
      %41 = arith.addi %arg9, %c1_i32 : i32
      %42 = arith.cmpi slt, %41, %c2_i32 : i32
      %43 = arith.select %42, %41, %c0_i32 : i32
-      %44 = ttg.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x32xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>
-      ttg.local_store %arg12, %44 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>
-      %45 = ttg.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>
-      ttg.local_store %arg13, %45 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>
-      scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>
+      %44 = ttg.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>
+      ttg.local_store %arg12, %44 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>
+      %45 = ttg.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
+      ttg.local_store %arg13, %45 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
+      scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>
     }
-    ttg.local_dealloc %10 : !ttg.memdesc<2x128x32xf16, #shared, #ttg.shared_memory, mutable>
-    ttg.local_dealloc %11 : !ttg.memdesc<2x32x128xf16, #shared1, #ttg.shared_memory, mutable>
+    ttg.local_dealloc %10 : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable>
+    ttg.local_dealloc %11 : !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable>
     tt.return %27#2 : tensor<128x128xf32, #mma>
   }
@@ -404,8 +407,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     %c0 = arith.constant 0 : index
     %c1_i32 = arith.constant 1 : i32
     %cst_0 = arith.constant dense<1> : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
-    %0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable>
-    %1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable>
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable>
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable>
     %2 = arith.cmpi sgt, %arg1, %c0 : index
     %3 = tt.splat %2 : i1 -> tensor<16xi1, #ttg.slice<{dim = 1, parent = #blocked}>>
     %4 = tt.load %arg3, %3 : tensor<16x!tt.ptr<i64>, #ttg.slice<{dim = 1, parent = #blocked}>>
@@ -421,17 +424,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     %14 = tt.load %12, %13 : tensor<16x16x!tt.ptr<f16>, #blocked>
     %15 = tt.splat %5 : i1 -> tensor<16xi1, #ttg.slice<{dim = 1, parent = #blocked}>>
     %16 = tt.load %6, %15 : tensor<16x!tt.ptr<i64>, #ttg.slice<{dim = 1, parent = #blocked}>>
-    %17 = ttg.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>
-    ttg.local_store %8, %17 : tensor<16x16xf16, #blocked1> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>
-    %18 = ttg.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>
-    ttg.local_store %14, %18 : tensor<16x16xf16, #blocked> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>
-    %19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x!tt.ptr<i64>, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>>) {
+    %17 = ttg.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
+    ttg.local_store %8, %17 : tensor<16x16xf16, #blocked1> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
+    %18 = ttg.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
+    ttg.local_store %14, %18 : tensor<16x16xf16, #blocked> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
+    %19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x!tt.ptr<i64>, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>>) {
      %20 = arith.subi %arg1, %c2 : index
      %21 = arith.cmpi slt, %arg6, %20 : index
      %22 = arith.subi %arg1, %c1 : index
      %23 = arith.cmpi slt, %arg6, %22 : index
-      %24 = ttg.local_load %arg11 : !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
-      %25 = ttg.local_load %arg12 : !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %24 = ttg.local_load %arg11 : !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %25 = ttg.local_load %arg12 : !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %26 = tt.dot %24, %25, %arg7 : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
      %27 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x16xi32, #blocked1>
      %28 = tt.addptr %arg9, %cst_0 : tensor<16x!tt.ptr<i64>, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
@@ -448,14 +451,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
      %39 = arith.addi %arg10, %c1_i32 : i32
      %40 = arith.cmpi slt, %39, %c1_i32 : i32
      %41 = arith.select %40, %39, %c0_i32 : i32
-      %42 = ttg.memdesc_subview %0[%41, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>
-      ttg.local_store %30, %42 : tensor<16x16xf16, #blocked1> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>
-      %43 = ttg.memdesc_subview %1[%41, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>
-      ttg.local_store %36, %43 : tensor<16x16xf16, #blocked> -> !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>
-      scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x!tt.ptr<i64>, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>, !ttg.memdesc<16x16xf16, #shared2, #ttg.shared_memory, mutable>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>>
+      %42 = ttg.memdesc_subview %0[%41, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
+      ttg.local_store %30, %42 : tensor<16x16xf16, #blocked1> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
+      %43 = ttg.memdesc_subview %1[%41, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
+      ttg.local_store %36, %43 : tensor<16x16xf16, #blocked> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
+      scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x!tt.ptr<i64>, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>>
     }
-    ttg.local_dealloc %0 : !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable>
-    ttg.local_dealloc %1 : !ttg.memdesc<1x16x16xf16, #shared2, #ttg.shared_memory, mutable>
+    ttg.local_dealloc %0 : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable>
+    ttg.local_dealloc %1 : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable>
     tt.return %19#0 : tensor<16x16xf32, #mma>
   }
 }
@@ -463,18 +466,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 // -----
 // CHECK-LABEL: sink_convert_dealloc
-// CHECK-COUNT-2: ttg.local_dealloc %{{.+}} : !ttg.memdesc<4x128x64xf16, #shared, mutable>
+// CHECK-COUNT-2: ttg.local_dealloc %{{.+}} : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable>
 // CHECK: ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1>
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
 #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @sink_convert_dealloc(%arg0: tensor<32x32xf32, #blocked>) attributes {noinline = false} {
-    %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, mutable>
-    %1 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, mutable>
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable>
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable>
     %2 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1>
-    ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, mutable>
-    ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, mutable>
+    ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable>
+    ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable>
     %3 = arith.addf %2, %2 : tensor<32x32xf32, #blocked1>
     tt.return
   }
 }
@@ -488,14 +492,15 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
 #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @anchor_barrier(%arg0: tensor<32x32x!tt.ptr<f16>, #blocked>) attributes {noinline = false} {
-    %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, mutable>
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable>
     gpu.barrier
     %2 = tt.load %arg0 : tensor<32x32x!tt.ptr<f16>, #blocked>
-    %1 = ttg.local_alloc %2 : (tensor<32x32xf16, #blocked>) -> !ttg.memdesc<4x128x64xf16, #shared, mutable>
-    ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, mutable>
-    ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, mutable>
+    %1 = ttg.local_alloc %2 : (tensor<32x32xf16, #blocked>) -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable>
+    ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable>
+    ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable>
     tt.return
   }
 }
diff --git a/test/TritonGPU/amd/amd-sched-2nd-load.mlir b/test/TritonGPU/amd/amd-sched-2nd-load.mlir
index 248a04a3c0da..24139f66be5e 100644
--- a/test/TritonGPU/amd/amd-sched-2nd-load.mlir
+++ b/test/TritonGPU/amd/amd-sched-2nd-load.mlir
@@ -10,6 +10,7 @@
 #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
 #dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
 #dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
+#smem = #ttg.shared_memory
 
 // Category 1: Single dot with two loads, we make sure the optimization is applied when tile size is large enough
 // The following tile sizes should apply the optimization
@@ -30,18 +31,18 @@
 // CHECK-NEXT: ttg.local_store %[[tileA]]
 // CHECK-NEXT: ttg.local_store %[[tileB]]
 module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable>, %B_LDS: !ttg.memdesc<128x256xf16, #shared1, #ttg.shared_memory, mutable>) {
+  tt.func public @sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !ttg.memdesc<256x128xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<128x256xf16, #shared1, #smem, mutable>) {
     %c0 = arith.constant 0 : i32
     %c1 = arith.constant 1 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
     %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
       %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
-      %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0>
+      %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> tensor<256x128xf16, #dotOp0>
       %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr<f16>, #blocked1>
-      %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<128x256xf16, #dotOp1>
+      %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf16, #dotOp1>
       %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
-      ttg.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable>
-      ttg.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !ttg.memdesc<128x256xf16, #shared1, #ttg.shared_memory, mutable>
+      ttg.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #shared, #smem, mutable>
+      ttg.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !ttg.memdesc<128x256xf16, #shared1, #smem, mutable>
       scf.yield %3 : tensor<256x256xf32, #mma>
     }
     tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr<f32>, #mma>
@@ -58,6 +59,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}
 #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
 #dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
 #dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
+#smem = #ttg.shared_memory
 
 // Should apply: tile size 256x256x64 with single dot
 // CHECK-LABEL: sink_2nd_load_256x256x64
@@ -69,18 +71,18 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}
 // CHECK-NEXT: ttg.local_store %[[tileA]]
 // CHECK-NEXT: ttg.local_store %[[tileB]]
 module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<64x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, %B_LDS: !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>) {
+  tt.func public @sink_2nd_load_256x256x64(%A_ptr: tensor<256x64x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<64x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !ttg.memdesc<256x64xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>) {
     %c0 = arith.constant 0 : i32
     %c1 = arith.constant 1 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
     %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
       %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
-      %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0>
+      %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x64xf16, #shared, #smem, mutable> -> tensor<256x64xf16, #dotOp0>
       %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr<f16>, #blocked1>
-      %2 = ttg.local_load %B_LDS : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1>
+      %2 = ttg.local_load %B_LDS : !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<64x256xf16, #dotOp1>
       %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
-      ttg.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
-      ttg.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>
+      ttg.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable>
+      ttg.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>
       scf.yield %3 : tensor<256x256xf32, #mma>
     }
     tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr<f32>, #mma>
@@ -97,6 +99,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}
 #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
 #dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
 #dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
+#smem = #ttg.shared_memory
 
 // Should NOT apply: tile size 256x64x128 with single dot
 // CHECK-LABEL: sink_2nd_load_256x64x128
@@ -108,18 +111,18 @@
 // CHECK-NEXT: ttg.local_store %[[tileA]]
 // CHECK-NEXT: ttg.local_store %[[tileB]]
 module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x64x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr<f32>, #mma>, %A_LDS: !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable>, %B_LDS: !ttg.memdesc<128x64xf16, #shared1, #ttg.shared_memory, mutable>) {
+  tt.func public @sink_2nd_load_256x64x128(%A_ptr: tensor<256x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x64x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x64x!tt.ptr<f32>, #mma>, %A_LDS: !ttg.memdesc<256x128xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>) {
     %c0 = arith.constant 0 : i32
     %c1 = arith.constant 1 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma>
     %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x64xf32, #mma>) : i32 {
       %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
-      %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x128xf16, #dotOp0>
+      %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> tensor<256x128xf16, #dotOp0>
       %5 = tt.load %B_ptr : tensor<128x64x!tt.ptr<f16>, #blocked1>
-      %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x64xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #dotOp1>
+      %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #dotOp1>
       %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x64xf16, #dotOp1> -> tensor<256x64xf32, #mma>
-      ttg.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #shared, #ttg.shared_memory, mutable>
-      ttg.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared1, #ttg.shared_memory, mutable>
+      ttg.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #shared, #smem, mutable>
+      ttg.local_store %5, %B_LDS : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>
       scf.yield %3 : tensor<256x64xf32, #mma>
     }
     tt.store %C_ptr, %0#0: tensor<256x64x!tt.ptr<f32>, #mma>
@@ -136,6 +139,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}
 #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
 #dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
 #dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
+#smem = #ttg.shared_memory
 
 // Should NOT apply: tile size 256x256x32 with single dot
 // CHECK-LABEL: sink_2nd_load_256x256x32
@@ -147,18 +151,18 @@
 // CHECK-NEXT: ttg.local_store %[[tileA]]
 // CHECK-NEXT: ttg.local_store %[[tileB]]
 module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<32x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !ttg.memdesc<256x32xf16, #shared, #ttg.shared_memory, mutable>, %B_LDS: !ttg.memdesc<32x256xf16, #shared1, #ttg.shared_memory, mutable>) {
+  tt.func public @sink_2nd_load_256x256x32(%A_ptr: tensor<256x32x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<32x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !ttg.memdesc<256x32xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<32x256xf16, #shared1, #smem, mutable>) {
     %c0 = arith.constant 0 : i32
     %c1 = arith.constant 1 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
     %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
       %4 = tt.load %A_ptr : tensor<256x32x!tt.ptr<f16>, #blocked>
-      %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x32xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x32xf16, #dotOp0>
+      %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x32xf16, #shared, #smem, mutable> -> tensor<256x32xf16, #dotOp0>
       %5 = tt.load %B_ptr : tensor<32x256x!tt.ptr<f16>, #blocked1>
-      %2 = ttg.local_load %B_LDS : !ttg.memdesc<32x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<32x256xf16, #dotOp1>
+      %2 = ttg.local_load %B_LDS : !ttg.memdesc<32x256xf16, #shared1, #smem, mutable> -> tensor<32x256xf16, #dotOp1>
       %3 = tt.dot %1, %2, %arg1 : tensor<256x32xf16, #dotOp0> * tensor<32x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
-      ttg.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !ttg.memdesc<256x32xf16, #shared, #ttg.shared_memory, mutable>
-      ttg.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !ttg.memdesc<32x256xf16, #shared1, #ttg.shared_memory, mutable>
+      ttg.local_store %4, %A_LDS : tensor<256x32xf16, #blocked> -> !ttg.memdesc<256x32xf16, #shared, #smem, mutable>
+      ttg.local_store %5, %B_LDS : tensor<32x256xf16, #blocked1> -> !ttg.memdesc<32x256xf16, #shared1, #smem, mutable>
       scf.yield %3 : tensor<256x256xf32, #mma>
     }
     tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr<f32>, #mma>
@@ -175,6 +179,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}
 #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
 #dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
 #dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
+#smem = #ttg.shared_memory
 
 // Category 2: single dot with two loads and tile size is large enough (128x128x128).
 // We make sure the move is legal.
@@ -188,18 +193,18 @@
 // CHECK-NEXT: tt.dot
 // CHECK-NEXT: ttg.local_store %[[tileA]]
 module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x128x!tt.ptr<f16>, #blocked>, %B_ptr2: tensor<128x128x!tt.ptr<f16>, #blocked>, %C_ptr: tensor<128x128x!tt.ptr<f32>, #mma>, %A_LDS: !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>, %B_LDS: !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory, mutable>) {
+  tt.func public @sink_2nd_load_128x128x128_user_before_dot(%A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x128x!tt.ptr<f16>, #blocked>, %B_ptr2: tensor<128x128x!tt.ptr<f16>, #blocked>, %C_ptr: tensor<128x128x!tt.ptr<f32>, #mma>, %A_LDS: !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<128x128xf16, #shared1, #smem, mutable>) {
     %c0 = arith.constant 0 : i32
     %c1 = arith.constant 1 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
     %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<128x128xf32, #mma>) : i32 {
       %4 = tt.load %A_ptr : tensor<128x128x!tt.ptr<f16>, #blocked>
-      %1 = ttg.local_load %A_LDS : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x128xf16, #dotOp0>
+      %1 = ttg.local_load %A_LDS : !ttg.memdesc<128x128xf16, #shared, #smem, mutable> -> tensor<128x128xf16, #dotOp0>
       %5 = tt.load %B_ptr : tensor<128x128x!tt.ptr<f16>, #blocked>
-      %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<128x128xf16, #dotOp1>
+      %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x128xf16, #shared1, #smem, mutable> -> tensor<128x128xf16, #dotOp1>
       tt.store %B_ptr, %5 : tensor<128x128x!tt.ptr<f16>, #blocked>
       %3 = tt.dot %1, %2, %arg1 : tensor<128x128xf16, #dotOp0> * tensor<128x128xf16, #dotOp1> -> tensor<128x128xf32, #mma>
-      ttg.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory, mutable>
+      ttg.local_store %4, %A_LDS : tensor<128x128xf16, #blocked> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
       scf.yield %3 : tensor<128x128xf32, #mma>
     }
     tt.store %C_ptr, %0#0: tensor<128x128x!tt.ptr<f32>, #mma>
@@ -228,20 +233,21 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}
 #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
 #dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
 #dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @sink_2nd_load_256x256x64_two_dot(%A_ptr: tensor<256x64x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<64x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, %B_LDS: !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>) {
+  tt.func public @sink_2nd_load_256x256x64_two_dot(%A_ptr: tensor<256x64x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<64x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !ttg.memdesc<256x64xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>) {
     %c0 = arith.constant 0 : i32
     %c1 = arith.constant 1 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
     %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
       %4 = tt.load %A_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
       %5 = tt.load %B_ptr : tensor<64x256x!tt.ptr<f16>, #blocked1>
-      %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #dotOp0>
-      %2 = ttg.local_load %B_LDS : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #dotOp1>
+      %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x64xf16, #shared, #smem, mutable> -> tensor<256x64xf16, #dotOp0>
+      %2 = ttg.local_load %B_LDS : !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<64x256xf16, #dotOp1>
       %3 = tt.dot %1, %2, %arg1 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
       %6 = tt.dot %1, %2, %3 : tensor<256x64xf16, #dotOp0> * tensor<64x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
-      ttg.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
-      ttg.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>
+      ttg.local_store %4, %A_LDS : tensor<256x64xf16, #blocked> -> !ttg.memdesc<256x64xf16, #shared, #smem, mutable>
+      ttg.local_store %5, %B_LDS : tensor<64x256xf16, #blocked1> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>
       scf.yield %3 : tensor<256x256xf32, #mma>
     }
     tt.store %C_ptr, %0#0: tensor<256x256x!tt.ptr<f32>, #mma>
diff --git a/test/TritonGPU/amd/optimize-lds-usage.mlir b/test/TritonGPU/amd/optimize-lds-usage.mlir
index e9d71ed908b5..38f2f21eeef6 100644
--- a/test/TritonGPU/amd/optimize-lds-usage.mlir
+++ b/test/TritonGPU/amd/optimize-lds-usage.mlir
@@ -11,11 +11,12 @@
 #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
 #shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @alloc_convert_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>) attributes {noinline = false} {
-    %1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
+    %1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
     %2 = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma>
-    %3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    %3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
     tt.return
   }
 }
@@ -33,11 +34,12 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
 #shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @alloc_convert_small_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf16, #blocked>) attributes {noinline = false} {
-    %1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
+    %1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
     %2 = ttg.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma>
-    %3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    %3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
     tt.return
   }
 }
@@ -55,11 +57,12 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 #blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
 #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
 #shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2], hasLeadingOffset = false}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @alloc_convert_3d_load(%arg0: tensor<1x128x128xf16, #blocked>, %arg1: tensor<1x128x128xf16, #blocked>) attributes {noinline = false} {
-    %1 = ttg.local_alloc %arg0 : (tensor<1x128x128xf16, #blocked>) -> !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory>
+    %1 = ttg.local_alloc %arg0 : (tensor<1x128x128xf16, #blocked>) -> !ttg.memdesc<1x128x128xf16, #shared, #smem>
     %2 = ttg.convert_layout %arg1 : tensor<1x128x128xf16, #blocked> -> tensor<1x128x128xf16, #mma>
-    %3 = ttg.local_load %1 : !ttg.memdesc<1x128x128xf16, #shared, #ttg.shared_memory> -> tensor<1x128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    %3 = ttg.local_load %1 : !ttg.memdesc<1x128x128xf16, #shared, #smem> -> tensor<1x128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
     tt.return
   }
 }
@@ -79,11 +82,12 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 #blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
 #shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @alloc_convert_32k_limit(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<64x128xf16, #blocked>) attributes {noinline = false} {
-    %1 = ttg.local_alloc %arg0 : (tensor<64x128xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #shared, #ttg.shared_memory>
+    %1 = ttg.local_alloc %arg0 : (tensor<64x128xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #shared, #smem>
     %2 = ttg.convert_layout %arg1 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #mma>
-    %3 = ttg.local_load %1 : !ttg.memdesc<64x128xf16, #shared, #ttg.shared_memory> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, kWidth = 4, parent = #mma}>>
+    %3 = ttg.local_load %1 : !ttg.memdesc<64x128xf16, #shared, #smem> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, kWidth = 4, parent = #mma}>>
     tt.return
   }
 }
@@ -98,23 +102,24 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 // CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
 // CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}})
-// CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #ttg.shared_memory>
+// CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem>
 // CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]>
 // CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]>
 // CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>>
-// CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #ttg.shared_memory> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>>
+// CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>>
 #blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma1 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
 #mma2 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
 #dotop1 = #ttg.dot_op<{opIdx=0, parent=#mma1, kWidth=4}>
 #dotop2 = #ttg.dot_op<{opIdx=0, parent=#mma2, kWidth=4}>
 #shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @mfma_dot_shortcut(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>, %arg2: tensor<256x128xf16, #mma2>) attributes {noinline = false} {
-    %alloc = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
+    %alloc = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
     %convert_1 = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma1>
     %convert_2 = ttg.convert_layout %arg2 : tensor<256x128xf16, #mma2> -> tensor<256x128xf16, #dotop2>
-    %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf16, #dotop1>
+    %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #dotop1>
     tt.return
   }
 }
@@ -129,11 +134,12 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
 #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
 #shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @convert_1d(%arg0: tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) attributes {noinline = false} {
-    %alloc = ttg.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !ttg.memdesc<128x128xf32, #shared, #ttg.shared_memory>
+    %alloc = ttg.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !ttg.memdesc<128x128xf32, #shared, #smem>
     %1 = ttg.convert_layout %arg0 : tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked>
-    %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf32, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma>
+    %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf32, #shared, #smem> -> tensor<128x128xf32, #mma>
     tt.return
   }
 }
diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir
index 70147ddfdfff..b47005b56978 100644
--- a/test/TritonGPU/canonicalize.mlir
+++ b/test/TritonGPU/canonicalize.mlir
@@ -69,10 +69,11 @@ tt.func @test_canonicalize_convert_histogram(%arg0: tensor<256xi32, #blocked1>)
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.compute-capability" = 80} {
 tt.func @test_canonicalize_convert_local_load() -> tensor<256xi32, #blocked1> {
-  %0 = ttg.local_alloc : () -> !ttg.memdesc<256xi32, #shared, mutable>
-  %1 = ttg.local_load %0 : !ttg.memdesc<256xi32, #shared, mutable> -> tensor<256xi32, #blocked>
+  %0 = ttg.local_alloc : () -> !ttg.memdesc<256xi32, #shared, #smem, mutable>
+  %1 = ttg.local_load %0 : !ttg.memdesc<256xi32, #shared, #smem, mutable> -> tensor<256xi32, #blocked>
   gpu.barrier
   %2 = ttg.convert_layout %1 : tensor<256xi32, #blocked> -> tensor<256xi32, #blocked1>
   tt.return %2 : tensor<256xi32, #blocked1>
@@ -83,17 +84,18 @@ tt.func @test_canonicalize_convert_local_load() -> tensor<256xi32, #blocked1> {
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #shared = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   // CHECK-LABEL: local_alloc_nofold1
-  tt.func @local_alloc_nofold1(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> {
+  tt.func @local_alloc_nofold1(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> {
     // CHECK: %[[ARG:.+]] = ttg.local_alloc
    // CHECK-NEXT: %[[ARG2:.+]] = ttg.local_load %[[ARG]]
    // CHECK-NEXT: %[[ARG3:.+]] = ttg.local_alloc %[[ARG2]]
    // CHECK-NEXT: tt.return %[[ARG3]]
-    %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable>
-    %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #blocked>
-    %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory>
-    tt.return %2 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory>
+    %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem, mutable>
+    %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem, mutable> -> tensor<16x16xf16, #blocked>
+    %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem>
+    tt.return %2 : !ttg.memdesc<16x16xf16, #shared, #smem>
   }
 } // end module
@@ -103,17 +105,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #shared = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #shared1 = #ttg.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   // CHECK-LABEL: local_alloc_nofold2
-  tt.func @local_alloc_nofold2(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared1, #ttg.shared_memory> {
+  tt.func @local_alloc_nofold2(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared1, #smem> {
     // CHECK: %[[ARG:.+]] = ttg.local_alloc
    // CHECK-NEXT: %[[ARG2:.+]] = ttg.local_load %[[ARG]]
    // CHECK-NEXT: %[[ARG3:.+]] = ttg.local_alloc %[[ARG2]]
    // CHECK-NEXT: tt.return %[[ARG3]]
-    %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory>
-    %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #blocked>
-    %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared1, #ttg.shared_memory>
-    tt.return %2 : !ttg.memdesc<16x16xf16, #shared1, #ttg.shared_memory>
+    %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem>
+    %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #blocked>
+    %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared1, #smem>
+    tt.return %2 : !ttg.memdesc<16x16xf16, #shared1, #smem>
   }
 } // end module
@@ -122,14 +125,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #shared = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
-  tt.func @local_alloc_fold(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> {
+  tt.func @local_alloc_fold(%arg0: tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> {
     // CHECK-LABEL: local_alloc_fold
    // CHECK-NEXT: %[[ARG:.+]] = ttg.local_alloc
    // CHECK-NEXT: tt.return %[[ARG]]
-    %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory>
-    %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #blocked>
-    %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory>
-    tt.return %2 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory>
+    %0 = ttg.local_alloc %arg0 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem>
+    %1 = ttg.local_load %0 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #blocked>
+    %2 = ttg.local_alloc %1 : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem>
+    tt.return %2 : !ttg.memdesc<16x16xf16, #shared, #smem>
   }
 } // end module
diff --git a/test/TritonGPU/coalesce-async-copy.mlir b/test/TritonGPU/coalesce-async-copy.mlir
index 0190238da135..e0e4f0077b07 100644
--- a/test/TritonGPU/coalesce-async-copy.mlir
+++ b/test/TritonGPU/coalesce-async-copy.mlir
@@ -7,13 +7,14 @@
 // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
 #blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 tt.func @async_copy_i8(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
-  %view: !ttg.memdesc<64x16xi8, #shared, #ttg.shared_memory, mutable>,
+  %view: !ttg.memdesc<64x16xi8, #shared, #smem, mutable>,
   %mask: tensor<64x16xi1, #blocked>,
   %other: tensor<64x16xi8, #blocked>) {
-  %token = ttg.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #ttg.shared_memory, mutable>
+  %token = ttg.async_copy_global_to_local %input, %view mask %mask other %other: tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #smem, mutable>
   tt.return
 }
 }
@@ -25,11 +26,12 @@
 // CHECK: %{{.*}} = ttg.async_copy_global_to_local %{{.*}}: tensor<64x16x!tt.ptr<i8>, #[[NEW_BLOCKED]]>
 #blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 tt.func @async_copy_i8_no_mask_or_other(%input: tensor<64x16x!tt.ptr<i8>, #blocked>,
-  %view: !ttg.memdesc<64x16xi8, #shared, #ttg.shared_memory, mutable>) {
-  %token = ttg.async_copy_global_to_local %input, %view : tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #ttg.shared_memory, mutable>
+  %view: !ttg.memdesc<64x16xi8, #shared, #smem, mutable>) {
+  %token = ttg.async_copy_global_to_local %input, %view : tensor<64x16x!tt.ptr<i8>, #blocked> -> <64x16xi8, #shared, #smem, mutable>
   tt.return
 }
 }
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index a980e19efe62..7c956192b171 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -1534,6 +1534,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32,
 #blocked5 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
 #shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} {
   tt.func public @reduce_cvt3(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
     %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
@@ -1561,9 +1562,9 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32,
     %20 = ttg.convert_layout %16 : tensor<32x32x!tt.ptr<f16>, #blocked> -> tensor<32x32x!tt.ptr<f16>, #blocked4>
     %21 = tt.load %20 : tensor<32x32x!tt.ptr<f16>, #blocked4>
     %22 = ttg.convert_layout %21 : tensor<32x32xf16, #blocked4> -> tensor<32x32xf16, #blocked>
-    %23 = ttg.local_alloc %22 : (tensor<32x32xf16, #blocked>) -> !ttg.memdesc<32x32xf16, #shared>
-    %24 = ttg.memdesc_trans %23 {order=array<i32: 1, 0>} : !ttg.memdesc<32x32xf16, #shared> -> !ttg.memdesc<32x32xf16, #shared1>
-    %25 = ttg.local_load %24 : !ttg.memdesc<32x32xf16, #shared1> -> tensor<32x32xf16, #blocked>
+    %23 = ttg.local_alloc %22 : (tensor<32x32xf16, #blocked>) -> !ttg.memdesc<32x32xf16, #shared, #smem>
+    %24 = ttg.memdesc_trans %23 {order=array<i32: 1, 0>} : !ttg.memdesc<32x32xf16, #shared, #smem> -> !ttg.memdesc<32x32xf16, #shared1, #smem>
+    %25 = ttg.local_load %24 : !ttg.memdesc<32x32xf16, #shared1, #smem> -> tensor<32x32xf16, #blocked>
     %26 = ttg.convert_layout %19 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked5}>>
     %27 = ttg.convert_layout %25 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked5}>>
     %28 = ttg.convert_layout %cst : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked5>
@@ -1946,6 +1947,7 @@ tt.func public @yield_outside_loop2(%arg0: i32, %arg1: i32) -> (i32, i32) {
 #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
 #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}>
 #shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
 // CHECK: [[$BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 // CHECK: [[$MMA:#.*]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
@@ -1993,10 +1995,10 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-
     %67 = tt.load %66 : tensor<32x64x!tt.ptr<f16>, #blocked>
     %68 = tt.addptr %17, %65 : tensor<256x64x!tt.ptr<f16>, #blocked>, tensor<256x64xi32, #blocked>
     %69 = tt.load %68 : tensor<256x64x!tt.ptr<f16>, #blocked>
-    %70 = ttg.local_alloc %69 : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #shared>
-    %71 = ttg.memdesc_trans %70 {order=array<i32: 1, 0>} : !ttg.memdesc<256x64xf16, #shared> -> !ttg.memdesc<64x256xf16, #shared1>
+    %70 = ttg.local_alloc %69 : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #shared, #smem>
+    %71 = ttg.memdesc_trans %70 {order=array<i32: 1, 0>} : !ttg.memdesc<256x64xf16, #shared, #smem> -> !ttg.memdesc<64x256xf16, #shared1, #smem>
     %72 = ttg.convert_layout %67 : tensor<32x64xf16, #blocked> -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>>
-    %73 = ttg.local_load %71 : !ttg.memdesc<64x256xf16, #shared1> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>>
+    %73 = ttg.local_load %71 : !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>>
     %74 = ttg.convert_layout %arg8 : tensor<32x256xf32, #blocked3> -> tensor<32x256xf32, #mma>
     %75 = ttg.convert_layout %72 : tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked3}>> -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
     %76 = ttg.convert_layout %73 : tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked3}>> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
@@ -2692,12 +2694,13 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 32, 1, 1], warpsPerCTA = [1, 1, 1, 1, 4], order = [4, 3, 2, 1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [2, 1, 16, 1, 1], warpsPerCTA = [1, 2, 2, 1, 1], order = [4, 0, 3, 2, 1]}>
 #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 0, 1, 2, 3], hasLeadingOffset = false}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: lift_convert_to_local_load
   // CHECK-NOT: convert_layout
   // CHECK: tt.return
-  tt.func public @lift_convert_to_local_load(%arg0 : !ttg.memdesc<2x1x32x4x4xi8, #shared, #ttg.shared_memory, mutable>) -> tensor<2x4x32x1x4xi8, #blocked2> {
-    %1 = ttg.local_load %arg0 : !ttg.memdesc<2x1x32x4x4xi8, #shared, #ttg.shared_memory, mutable> -> tensor<2x1x32x4x4xi8, #blocked>
+  tt.func public @lift_convert_to_local_load(%arg0 : !ttg.memdesc<2x1x32x4x4xi8, #shared, #smem, mutable>) -> tensor<2x4x32x1x4xi8, #blocked2> {
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<2x1x32x4x4xi8, #shared, #smem, mutable> -> tensor<2x1x32x4x4xi8, #blocked>
     %2 = tt.trans %1 {order = array<i32: 0, 3, 2, 1, 4>} : tensor<2x1x32x4x4xi8, #blocked> -> tensor<2x4x32x1x4xi8, #blocked1>
     %3 = ttg.convert_layout %2 : tensor<2x4x32x1x4xi8, #blocked1> -> tensor<2x4x32x1x4xi8, #blocked2>
     tt.return %3 : tensor<2x4x32x1x4xi8, #blocked2>
diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir
index 4346e1697af7..17fe2bfaa6ed 100644
--- a/test/TritonGPU/dot-operands.mlir
+++ b/test/TritonGPU/dot-operands.mlir
@@ -159,13 +159,14 @@ tt.func @update_kwidth_slice(
 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
 #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
 #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
 // CHECK: tt.func @mma_v3_reg_operand_A
 // CHECK: %[[A:.+]] = ttg.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
-// CHECK: ttng.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>
-tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !ttg.memdesc<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
-  %A = ttg.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !ttg.memdesc<128x64xf16, #shared1>
-  %r = ttng.warp_group_dot %A, %arg1, %arg2 : !ttg.memdesc<128x64xf16, #shared1> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>
+// CHECK: ttng.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma>
+tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !ttg.memdesc<64x64xf16, #shared, #smem>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
+  %A = ttg.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !ttg.memdesc<128x64xf16, #shared1, #smem>
+  %r = ttng.warp_group_dot %A, %arg1, %arg2 : !ttg.memdesc<128x64xf16, #shared1, #smem> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma>
   tt.return %r : tensor<128x64xf32, #mma>
 }
 }
@@ -175,13 +176,14 @@
 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
 #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
 #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
 // CHECK: tt.func @mma_v3_reg_operand_A_fp8
 // CHECK: %[[A:.+]] = ttg.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
-// CHECK: ttng.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !ttg.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma>
-tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !ttg.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
-  %A = ttg.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !ttg.memdesc<128x64xf8E5M2, #shared1>
-  %r = ttng.warp_group_dot %A, %arg1, %arg2 : !ttg.memdesc<128x64xf8E5M2, #shared1> * !ttg.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma>
+// CHECK: ttng.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !ttg.memdesc<64x64xf8E5M2, #shared, #smem> -> tensor<128x64xf32, #mma>
+tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !ttg.memdesc<64x64xf8E5M2, #shared, #smem>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
+  %A = ttg.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !ttg.memdesc<128x64xf8E5M2, #shared1, #smem>
+  %r = ttng.warp_group_dot %A, %arg1, %arg2 : !ttg.memdesc<128x64xf8E5M2, #shared1, #smem> * !ttg.memdesc<64x64xf8E5M2, #shared, #smem> -> tensor<128x64xf32, #mma>
   tt.return %r : tensor<128x64xf32, #mma>
 }
 }
@@ -189,6 +191,7 @@
 // -----
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
 // CHECK: tt.func @a_impl
 // CHECK-NOT: %[[SELECT:.*]] = arith.select {{.*}} :
tensor<128x128xi1, #ttg.dot_op<{{.*}}>, tensor<128x128xf16, #ttg.dot_op<{{.*}}> @@ -215,17 +218,18 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num- #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_push_elementwise // CHECK: %[[A_BLOCK:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> // CHECK: %[[A_DOTOP:.*]] = ttg.convert_layout %[[A_BLOCK]] : tensor<128x64xbf16, #blocked> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // CHECK: %[[A_CASTED:.*]] = tt.fp_to_fp %[[A_DOTOP]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -// CHECK: %[[R:.*]] = ttng.warp_group_dot %[[A_CASTED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> - tt.func @mma_v3_reg_push_elementwise(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !ttg.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ +// CHECK: %[[R:.*]] = ttng.warp_group_dot %[[A_CASTED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !ttg.memdesc<64x64xf16, #shared, #smem>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %a_bf16 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> %a = tt.fp_to_fp %a_bf16 : tensor<128x64xbf16, #blocked> -> tensor<128x64xf16, #blocked> - %dota = ttg.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1> - %r = ttng.warp_group_dot %dota, %dotb, %dotc : !ttg.memdesc<128x64xf16, #shared1> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + %dota = ttg.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + %r = ttng.warp_group_dot %dota, %dotb, %dotc : !ttg.memdesc<128x64xf16, #shared1, #smem> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } @@ -236,6 +240,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_push_elementwise_chained // CHECK: %[[CST_DOTOP:.*]] = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> @@ -244,15 +249,15 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK: 
%[[A_CASTED:.*]] = arith.sitofp %[[A_DOTOP]] : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // CHECK: %[[A_SCALED:.*]] = arith.mulf %[[A_CASTED]], %[[CST_DOTOP]] : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // CHECK: %[[A_NEGATED:.*]] = arith.negf %[[A_SCALED]] : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -// CHECK: %[[R:.*]] = ttng.warp_group_dot %[[A_NEGATED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> - tt.func @mma_v3_reg_push_elementwise_chained(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !ttg.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ +// CHECK: %[[R:.*]] = ttng.warp_group_dot %[[A_NEGATED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise_chained(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !ttg.memdesc<64x64xf16, #shared, #smem>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked> %a_i8 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> %a_f16 = arith.sitofp %a_i8 : tensor<128x64xi8, #blocked> to tensor<128x64xf16, #blocked> %a_scaled = arith.mulf %a_f16, %cst : tensor<128x64xf16, #blocked> %a_negated = arith.negf %a_scaled : tensor<128x64xf16, #blocked> - %dota = ttg.local_alloc %a_negated: (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1> - %r = ttng.warp_group_dot %dota, %dotb, %dotc : !ttg.memdesc<128x64xf16, #shared1> * !ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + %dota = ttg.local_alloc %a_negated: (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + %r = ttng.warp_group_dot %dota, %dotb, %dotc : !ttg.memdesc<128x64xf16, #shared1, #smem> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } @@ -264,15 +269,16 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: mma_reorder_transpose // CHECK: ttg.local_alloc // CHECK: ttg.memdesc_trans // CHECK: ttng.warp_group_dot - tt.func @mma_reorder_transpose(%t: tensor<64x128xf16, #blocked1>, %dotb: !ttg.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + tt.func @mma_reorder_transpose(%t: tensor<64x128xf16, #blocked1>, %dotb: !ttg.memdesc<64x64xf16, #shared, #smem>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %a = tt.trans %t {order = array} : tensor<64x128xf16, #blocked1> -> tensor<128x64xf16, #blocked> - %dota = ttg.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1> - %r = ttng.warp_group_dot %dota, %dotb, %dotc : !ttg.memdesc<128x64xf16, #shared1> * 
!ttg.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + %dota = ttg.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + %r = ttng.warp_group_dot %dota, %dotb, %dotc : !ttg.memdesc<128x64xf16, #shared1, #smem> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> tt.return %r : tensor<128x64xf32, #mma> } } @@ -282,6 +288,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- #blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: mmav2_reorder_transpose // CHECK: ttg.local_alloc diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir index c85df2ff64ed..5780bf672f9e 100644 --- a/test/TritonGPU/fence-inserstion.mlir +++ b/test/TritonGPU/fence-inserstion.mlir @@ -5,14 +5,15 @@ #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: matmul_like_fence tt.func public @matmul_like_fence(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked2>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %0 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared> - %1 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked2>) -> !ttg.memdesc<128x64xf16, #shared1> + %0 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %1 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked2>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> // CHECK: ttng.fence_async_shared - %2 = ttng.warp_group_dot %0, %1, %cst : !ttg.memdesc<128x128xf16, #shared> * !ttg.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %2 = ttng.warp_group_dot %0, %1, %cst : !ttg.memdesc<128x128xf16, #shared, #smem> * !ttg.memdesc<128x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma> tt.return } } @@ -24,6 +25,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: fence_outside_loop tt.func public @fence_outside_loop(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked>) { @@ -31,15 +33,15 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- 
%c64_i32 = arith.constant 64 : i32 %c0_i32 = arith.constant 0 : i32 %c32_i32 = arith.constant 32 : i32 - %0 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared> - %1 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1> + %0 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %1 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> // CHECK: ttng.fence_async_shared // CHECK: scf.for // CHECK-NOT: ttng.fence_async_shared // CHECK: ttng.warp_group_dot scf.for %iv0 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { scf.for %iv1 = %c0_i32 to %c64_i32 step %c32_i32 : i32 { - %2 = ttng.warp_group_dot %0, %1, %cst : !ttg.memdesc<128x128xf16, #shared> * !ttg.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %2 = ttng.warp_group_dot %0, %1, %cst : !ttg.memdesc<128x128xf16, #shared, #smem> * !ttg.memdesc<128x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma> } } tt.return diff --git a/test/TritonGPU/invalid.mlir b/test/TritonGPU/invalid.mlir index 0a494006af12..41ff5cc763a5 100644 --- a/test/TritonGPU/invalid.mlir +++ b/test/TritonGPU/invalid.mlir @@ -1,45 +1,77 @@ // RUN: triton-opt --split-input-file %s --verify-diagnostics -tt.func public @subview_element_ty(%arg0: !ttg.memdesc<8x16xf32>) { +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @miss_encoding(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 - // expected-error @+1 {{element type}} + // expected-error @+1 {{,}} %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32> -> !ttg.memdesc<8x16xf16> tt.return } // ----- -tt.func public @too_many_offsets(%arg0: !ttg.memdesc<8x16xf32>) { +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @miss_memory_space(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { + %zero = arith.constant 0 : i32 + // expected-error @+1 {{,}} + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared> -> !ttg.memdesc<8x16xf16> + tt.return +} + +// ----- + +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @subview_element_ty(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { + %zero = arith.constant 0 : i32 + // expected-error @+1 {{element type}} + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x16xf16, #shared, #smem> + tt.return +} + +// ----- + +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @too_many_offsets(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{offsets}} - %a = ttg.memdesc_subview %arg0[%zero, %zero, %zero] : !ttg.memdesc<8x16xf32> -> !ttg.memdesc + %a = ttg.memdesc_subview %arg0[%zero, %zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc tt.return } // ----- -tt.func public @too_few_offsets(%arg0: !ttg.memdesc<8x16xf32>) { +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @too_few_offsets(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{offsets}} - %a = ttg.memdesc_subview %arg0[%zero] : !ttg.memdesc<8x16xf32> -> 
!ttg.memdesc + %a = ttg.memdesc_subview %arg0[%zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc tt.return } // ----- -tt.func public @result_rank_too_large(%arg0: !ttg.memdesc<8x16xf32>) { +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @result_rank_too_large(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{result rank}} - %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32> -> !ttg.memdesc<3x8x16xf32> + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<3x8x16xf32, #shared, #smem> tt.return } // ----- -tt.func public @result_dim_too_large(%arg0: !ttg.memdesc<8x16xf32>) { +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @result_dim_too_large(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 // expected-error @+1 {{result shape}} - %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32> -> !ttg.memdesc<32xf32> + %a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<32xf32, #shared, #smem> tt.return } diff --git a/test/TritonGPU/loop-pipeline-cuda.mlir b/test/TritonGPU/loop-pipeline-cuda.mlir index 4842f7ed2809..539ea317c20c 100644 --- a/test/TritonGPU/loop-pipeline-cuda.mlir +++ b/test/TritonGPU/loop-pipeline-cuda.mlir @@ -5,6 +5,7 @@ #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: tt.func @load_two_users tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { @@ -49,9 +50,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> %23 = ttg.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #ttg.shared_memory> - %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #ttg.shared_memory> -> !ttg.memdesc<16x64xf16, #shared1, #ttg.shared_memory> - %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #ttg.shared_memory> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem> + %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #smem> -> !ttg.memdesc<16x64xf16, #shared1, #smem> + %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, 
kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } @@ -68,6 +69,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> @@ -140,9 +142,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> %71 = ttg.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %72 = ttg.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !ttg.memdesc<32x64xf32, #shared, #ttg.shared_memory> - %73 = ttg.memdesc_trans %72 {order=array} : !ttg.memdesc<32x64xf32, #shared, #ttg.shared_memory> -> !ttg.memdesc<64x32xf32, #shared1, #ttg.shared_memory> - %74 = ttg.local_load %73 : !ttg.memdesc<64x32xf32, #shared1, #ttg.shared_memory> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %72 = ttg.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !ttg.memdesc<32x64xf32, #shared, #smem> + %73 = ttg.memdesc_trans %72 {order=array} : !ttg.memdesc<32x64xf32, #shared, #smem> -> !ttg.memdesc<64x32xf32, #shared1, #smem> + %74 = ttg.local_load %73 : !ttg.memdesc<64x32xf32, #shared1, #smem> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> %77 = ttg.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> @@ -167,11 +169,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @matmul_tma -// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #{{.+}}, #ttg.shared_memory, mutable> -// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #{{.+}}, #ttg.shared_memory, mutable> -// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3xi64, 
#{{.+}}, #ttg.shared_memory, mutable> +// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #{{.+}}, #smem, mutable> +// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #{{.+}}, #smem, mutable> +// CHECK-DAG: ttg.local_alloc : () -> !ttg.memdesc<3xi64, #{{.+}}, #smem, mutable> // CHECK-COUNT-3: ttng.init_barrier // CHECK-COUNT-4: ttng.async_tma_copy_global_to_local // CHECK: scf.for @@ -187,10 +190,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> %0:2 = scf.for %arg3 = %c0_i32 to %c256_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 { %1 = tt.experimental_descriptor_load %arg0[%c0_i32, %arg5] : !tt.tensordesc> -> tensor<128x64xf16, #blocked> - %2 = ttg.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> + %2 = ttg.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem> %3 = tt.experimental_descriptor_load %arg1[%arg5, %c0_i32] : !tt.tensordesc> -> tensor<64x256xf16, #blocked1> - %4 = ttg.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory> - %5 = ttng.warp_group_dot %2, %4, %arg4 { inputPrecision = 0 : i32 } : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory> -> tensor<128x256xf32, #mma> + %4 = ttg.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem> + %5 = ttng.warp_group_dot %2, %4, %arg4 { inputPrecision = 0 : i32 } : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma> %6 = arith.addi %arg5, %c64_i32 : i32 scf.yield %5, %6 : tensor<128x256xf32, #mma>, i32 } diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index 69868cf50cf8..a0453ab928db 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -5,6 +5,7 @@ #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @load_two_users tt.func @load_two_users(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { @@ -47,9 +48,9 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> %23 = ttg.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #ttg.shared_memory, mutable> - %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #ttg.shared_memory, mutable> -> 
!ttg.memdesc<16x64xf16, #shared1, #ttg.shared_memory, mutable> - %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<16x64xf16, #shared1, #smem, mutable> + %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #smem, mutable> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } @@ -67,6 +68,7 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma> @@ -139,9 +141,9 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> %71 = ttg.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %72 = ttg.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !ttg.memdesc<32x64xf32, #shared, #ttg.shared_memory, mutable> - %73 = ttg.memdesc_trans %72 {order=array} : !ttg.memdesc<32x64xf32, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x32xf32, #shared1, #ttg.shared_memory, mutable> - %74 = ttg.local_load %73 : !ttg.memdesc<64x32xf32, #shared1, #ttg.shared_memory, mutable> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %72 = ttg.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !ttg.memdesc<32x64xf32, #shared, #smem, mutable> + %73 = ttg.memdesc_trans %72 {order=array} : !ttg.memdesc<32x64xf32, #shared, #smem, mutable> -> !ttg.memdesc<64x32xf32, #shared1, #smem, mutable> + %74 = ttg.local_load %73 : !ttg.memdesc<64x32xf32, #shared1, #smem, mutable> -> tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> %77 = ttg.convert_layout %75 : tensor<64x32xf32, #mma> -> 
tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> @@ -245,6 +247,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> #shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}> #shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1201", "ttg.threads-per-warp" = 32 : i32} { tt.func public @loop_with_dot_and_transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: i32, %arg4: tensor<32x32x!tt.ptr, #blocked1>, %arg5: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %c1_i32 = arith.constant 1 : i32 @@ -252,9 +255,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> %0 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %cst) -> (tensor<32x32xf32, #blocked>) : i32 { %2 = tt.load %arg4 : tensor<32x32x!tt.ptr, #blocked1> - %3 = ttg.local_alloc %2 : (tensor<32x32xf32, #blocked1>) -> !ttg.memdesc<32x32xf32, #shared, #ttg.shared_memory> - %4 = ttg.memdesc_trans %3 {order = array} : !ttg.memdesc<32x32xf32, #shared, #ttg.shared_memory> -> !ttg.memdesc<32x32xf32, #shared1, #ttg.shared_memory> - %5 = ttg.local_load %4 : !ttg.memdesc<32x32xf32, #shared1, #ttg.shared_memory> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %3 = ttg.local_alloc %2 : (tensor<32x32xf32, #blocked1>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %4 = ttg.memdesc_trans %3 {order = array} : !ttg.memdesc<32x32xf32, #shared, #smem> -> !ttg.memdesc<32x32xf32, #shared1, #smem> + %5 = ttg.local_load %4 : !ttg.memdesc<32x32xf32, #shared1, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> %6 = ttg.convert_layout %2 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> %7 = tt.dot %6, %5, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf32, #blocked> scf.yield %7 : tensor<32x32xf32, #blocked> diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/loop-pipeline-hopper-remove-wait.mlir similarity index 97% rename from test/TritonGPU/pipeline-hopper-remove-wait.mlir rename to test/TritonGPU/loop-pipeline-hopper-remove-wait.mlir index 03c2e0732392..0846a44f6c1f 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/loop-pipeline-hopper-remove-wait.mlir @@ -7,6 +7,7 @@ #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: two_dependent_dot tt.func public @two_dependent_dot(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr 
{tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg11: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg12: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg13: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg14: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg15: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg16: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg17: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg20: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg21: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { @@ -108,20 +109,20 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- %110 = tt.broadcast %109 : tensor<64x128xi64, #blocked> -> tensor<64x128xi64, #blocked> %111 = tt.addptr %101, %110 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi64, #blocked> %112 = tt.load %111 : tensor<64x128x!tt.ptr, #blocked> - %113 = ttg.local_alloc %38 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared> - %114 = ttg.local_alloc %90 : (tensor<128x64xf16, #blocked2>) -> !ttg.memdesc<128x64xf16, #shared1> - %115 = ttng.warp_group_dot %113, %114, %cst :!ttg.memdesc<128x128xf16, #shared> * !ttg.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> + %113 = ttg.local_alloc %38 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem> + %114 = ttg.local_alloc %90 : (tensor<128x64xf16, #blocked2>) -> !ttg.memdesc<128x64xf16, #shared1, #smem> + %115 = ttng.warp_group_dot %113, %114, %cst :!ttg.memdesc<128x128xf16, #shared, #smem> * !ttg.memdesc<128x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma> %116 = arith.truncf %115 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> - %117 = ttg.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #shared> + %117 = ttg.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #shared, #smem> %118 = ttg.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // The first dot gets converted to dot-async + wait. The second one // doesn't have a wait because the first wait is sufficient. 
// CHECK: ttng.warp_group_dot - // CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32} + // CHECK: ttng.warp_group_dot_wait {{.*}}, {{.*}} {pendings = 0 : i32} // CHECK: ttng.warp_group_dot // CHECK-NOT: ttng.warp_group_dot_wait // CHECK: scf.yield - %119 = ttng.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> + %119 = ttng.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x128xf16, #shared, #smem> -> tensor<128x128xf32, #mma1> %120 = arith.mulf %arg24, %arg25 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> %121 = arith.addf %120, %arg25 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> %122 = arith.extsi %c0_i32 : i32 to i64 diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 776caf099b48..138cebcf2a1b 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -10,6 +10,7 @@ #C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#smem = #ttg.shared_memory // CHECK-LABEL: tt.func @matmul_loop // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 @@ -19,11 +20,11 @@ // CHECK: %[[BBUFFER:.*]] = ttg.local_alloc // CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !ttg.memdesc<2x128x32xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> -// CHECK: %[[T_A0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, #ttg.shared_memory, mutable> +// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 2x128x32> +// CHECK: %[[T_A0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, #smem, mutable, 2x128x32> // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK-DAG: %[[BSUB:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_B0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, #ttg.shared_memory, mutable> +// CHECK: %[[T_B0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, #smem, mutable, 2x32x128> // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] // CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] @@ -333,8 +334,8 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // %a = tt.load %a_tileptr : !tt.ptr, 1> // %b = tt.load %b_tileptr : !tt.ptr, 1> // -// %sa = ttg.local_alloc %a : (tensor<128x32xf16, #BA>) -> !ttg.memdesc<128x32xf16, #SA, 
#ttg.shared_memory> -// %sb = ttg.local_alloc %b : (tensor<32x128xf16, #BB>) -> !ttg.memdesc<32x128xf16, #SB, #ttg.shared_memory> +// %sa = ttg.local_alloc %a : (tensor<128x32xf16, #BA>) -> !ttg.memdesc<128x32xf16, #SA, #smem> +// %sb = ttg.local_alloc %b : (tensor<32x128xf16, #BB>) -> !ttg.memdesc<32x128xf16, #SB, #smem> // %c = ttng.warp_group_dot %sa, %sb, %prev_c : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> // // %a_tileptr_next = tt.advance %a_tileptr, [%c0, %c32_i32] : !tt.ptr, 1> @@ -354,6 +355,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: dot_chained_single_load tt.func @dot_chained_single_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x64xf32, #mma> { @@ -393,13 +395,13 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK: scf.yield %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>) : i32 { %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> - %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> - %21 = ttng.warp_group_dot %19, %20, %cst_2 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %21 = ttng.warp_group_dot %19, %20, %cst_2 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %23 = ttg.memdesc_trans %20 {order=array} : !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> !ttg.memdesc<16x64xf16, #shared, #ttg.shared_memory> + %23 = ttg.memdesc_trans %20 {order=array} : !ttg.memdesc<64x16xf16, #shared1, #smem> -> !ttg.memdesc<16x64xf16, #shared, #smem> %24 = ttg.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> - %25 = ttng.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf32, #mma> + %25 = ttng.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked> } @@ -445,9 +447,9 @@ module attributes {"ttg.target" = 
"cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> - %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> - %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %cnd = arith.cmpi slt, %arg3, %ext : i32 %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> @@ -471,6 +473,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: two_accumulator_escape tt.func @two_accumulator_escape(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>) { @@ -502,8 +505,8 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> - %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> // CHECK: %[[ALLOC1:.+]] = ttg.local_alloc // CHECK: %[[ALLOC2:.+]] = ttg.local_alloc // CHECK: %[[R:.+]]:{{.+}} = scf.for @@ -515,11 +518,11 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK: scf.yield // CHECK: %{{.*}}:2 = ttng.warp_group_dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}> %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16, %arg6 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>) : i32 { - %21 = ttng.warp_group_dot %19, %20, %arg6 : 
!ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %21 = ttng.warp_group_dot %19, %20, %arg6 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %c = ttg.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> - %23 = ttg.memdesc_trans %c {order=array} : !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> !ttg.memdesc<16x64xf16, #shared, #ttg.shared_memory> - %25 = ttng.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf32, #mma> + %c = ttg.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %23 = ttg.memdesc_trans %c {order=array} : !ttg.memdesc<64x16xf16, #shared1, #smem> -> !ttg.memdesc<16x64xf16, #shared, #smem> + %25 = ttng.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1> } @@ -535,6 +538,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> #shared = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 16, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory // Make sure that if one of the load dot operands is not pipelined (and therefore not double buffered) we won't use // async dot.
@@ -577,13 +581,13 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- %22:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %12, %arg6 = %21) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1>) : i32 { %35 = tt.load %arg5 : tensor<128x64x!tt.ptr, #blocked> %36 = tt.load %arg6 : tensor<64x256x!tt.ptr, #blocked1> - %37 = ttg.local_alloc %35 : (tensor<128x64xf8E5M2, #blocked>) -> !ttg.memdesc<128x64xf8E5M2, #shared, #ttg.shared_memory> - %38 = ttg.local_alloc %36 : (tensor<64x256xf8E5M2, #blocked1>) -> !ttg.memdesc<64x256xf8E5M2, #shared1, #ttg.shared_memory> + %37 = ttg.local_alloc %35 : (tensor<128x64xf8E5M2, #blocked>) -> !ttg.memdesc<128x64xf8E5M2, #shared, #smem> + %38 = ttg.local_alloc %36 : (tensor<64x256xf8E5M2, #blocked1>) -> !ttg.memdesc<64x256xf8E5M2, #shared1, #smem> // CHECK: ttg.local_alloc // CHECK: scf.for // CHECK: ttng.warp_group_dot // CHECK-NEXT: ttng.warp_group_dot_wait - %39 = ttng.warp_group_dot %37, %38, %arg4 {maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x64xf8E5M2, #shared, #ttg.shared_memory> * !ttg.memdesc<64x256xf8E5M2, #shared1, #ttg.shared_memory> -> tensor<128x256xf32, #mma> + %39 = ttng.warp_group_dot %37, %38, %arg4 {maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x64xf8E5M2, #shared, #smem> * !ttg.memdesc<64x256xf8E5M2, #shared1, #smem> -> tensor<128x256xf32, #mma> %40 = tt.addptr %arg5, %cst_6 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> %41 = tt.addptr %arg6, %cst_5 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> scf.yield %39, %40, %41 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1> @@ -614,6 +618,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: async_following_sync tt.func @async_following_sync(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>) { @@ -657,8 +662,8 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> - %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> // CHECK: %[[LOOP:[^ :]+]]{{.*}} scf.for {{.*}} iter_args(%[[PREV_DOT2:[^ ]+]] // CHECK-NOT: ttng.warp_group_dot_wait // CHECK: %[[DOT0:.+]] = ttng.warp_group_dot @@ -675,17 +680,17 @@ module attributes {"ttg.target" = 
"cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK: ttng.warp_group_dot_wait %[[LOOP]]#3, %[[LOOP]]#0 {pendings = 0 : i32} %17:4 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%prev_dot2 = %cst_3, %arg5 = %16, %prev_dot1 = %cst_2, %prev_dot0 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { // This one can be async. - %dot0 = ttng.warp_group_dot %19, %20, %prev_dot1 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %dot0 = ttng.warp_group_dot %19, %20, %prev_dot1 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> // This can't be async because its result is modified before it's yielded. - %dot1 = ttng.warp_group_dot %19, %20, %prev_dot1 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %dot1 = ttng.warp_group_dot %19, %20, %prev_dot1 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %dot1.1 = arith.addf %dot1, %dot1 : tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %c = ttg.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> - %23 = ttg.memdesc_trans %c {order=array} : !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> !ttg.memdesc<16x64xf16, #shared, #ttg.shared_memory> + %c = ttg.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %23 = ttg.memdesc_trans %c {order=array} : !ttg.memdesc<64x16xf16, #shared1, #smem> -> !ttg.memdesc<16x64xf16, #shared, #smem> // This dot can be async even though %prev_dot2 is not used directly by an // async dot, because that use follows the synchronous dot above. 
%prev_dot2.1 = arith.addf %prev_dot2, %prev_dot2 : tensor<128x64xf32, #mma> - %dot2 = ttng.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf32, #mma> + %dot2 = ttng.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !ttg.memdesc<16x64xf16, #shared, #smem> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %dot2, %26, %dot1.1, %dot0 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } @@ -715,11 +720,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_multiple_store_pipeline tt.func public @tma_multiple_store_pipeline(%arg0: tensor<1xf32, #blocked>, %arg1: !tt.tensordesc>, %arg2: i32, %arg3: i32) attributes {noinline = false} { %c0_i32 = arith.constant 0 : i32 - // CHECK: %[[ALLOC:.+]] = ttg.local_alloc : () -> !ttg.memdesc<1xf32, #shared, #ttg.shared_memory, mutable> + // CHECK: %[[ALLOC:.+]] = ttg.local_alloc : () -> !ttg.memdesc<1xf32, #shared, #smem, mutable> // CHECK: scf.for scf.for %arg4 = %c0_i32 to %arg3 step %arg2 : i32 { %1 = arith.divsi %arg4, %arg2 : i32 @@ -749,6 +755,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 32]}> #shared = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 16, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: _kernel_matmul_dependency tt.func public @_kernel_matmul_dependency(%arg0: tensor<128x128x!tt.ptr, #blocked>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) attributes {noinline = false} { @@ -780,10 +787,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %11 = tt.broadcast %10 : tensor<128x1xi32, #blocked1> -> tensor<128x128xi32, #blocked1> %12 = tt.addptr %1, %11 : tensor<128x128x!tt.ptr, #blocked1>, tensor<128x128xi32, #blocked1> %13 = tt.load %arg0 : tensor<128x128x!tt.ptr, #blocked> - %14 = ttg.local_alloc %13 : (tensor<128x128xf8E4M3FNUZ, #blocked>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared> + %14 = ttg.local_alloc %13 : (tensor<128x128xf8E4M3FNUZ, #blocked>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared, #smem> %15 = tt.load %12 : tensor<128x128x!tt.ptr, #blocked1> - %16 = ttg.local_alloc %15 : (tensor<128x128xf8E4M3FNUZ, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1> - %17 = ttng.warp_group_dot %14, %16, %arg9 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x128xf8E4M3FNUZ, #shared> * !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1> 
-> tensor<128x128xf32, #mma> + %16 = ttg.local_alloc %15 : (tensor<128x128xf8E4M3FNUZ, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1, #smem> + %17 = ttng.warp_group_dot %14, %16, %arg9 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x128xf8E4M3FNUZ, #shared, #smem> * !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1, #smem> -> tensor<128x128xf32, #mma> %18 = tt.splat %7 : f32 -> tensor<128x128xf32, #mma> %19 = arith.mulf %17, %18 : tensor<128x128xf32, #mma> %20 = scf.if %6 -> (tensor<128x128xf32, #mma>) { @@ -806,6 +813,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON-LABEL: dot_prologue_epilogue // COMMON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} @@ -852,9 +860,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { scf.yield %arg5 : tensor<64x16x!tt.ptr, #blocked> } %18 = tt.load %inc_ptr : tensor<64x16x!tt.ptr, #blocked> - %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> - %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> - %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> scf.yield %acc_zero : tensor<128x16xf32, #mma1> @@ -878,6 +886,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NOCANON-LABEL: pipeline_downstream_dependencies // CHECK-NOCANON: {{.*}}, {{.*}}, %[[EXT:.*]]: i32, {{.*}} @@ -917,9 +926,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> - %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> 
!ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> - %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma1> + %19 = ttg.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem> + %20 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %acc = ttng.warp_group_dot %19, %20, %arg4 : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma1> %cnd = arith.cmpi slt, %arg3, %ext : i32 %if_ret:2 = scf.if %cnd -> (tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> @@ -942,6 +951,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +#smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: dot_lhs_registers tt.func @dot_lhs_registers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma> { @@ -988,8 +998,8 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- %b_block = tt.load %arg6 : tensor<64x16x!tt.ptr, #blocked> %a_dotop = ttg.convert_layout %a_block : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %a_dotop_mul = arith.mulf %a_dotop, %cst_4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %b_smem = ttg.local_alloc %b_block : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> - %21 = ttng.warp_group_dot %a_dotop_mul, %b_smem, %arg4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x16xf16, #shared1, #ttg.shared_memory> -> tensor<128x16xf32, #mma> + %b_smem = ttg.local_alloc %b_block : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared1, #smem> + %21 = ttng.warp_group_dot %a_dotop_mul, %b_smem, %arg4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x16xf16, #shared1, #smem> -> tensor<128x16xf32, #mma> %25 = tt.addptr %arg5, %cst_3 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> %26 = tt.addptr %arg6, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %21, %25, %26 : tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x16x!tt.ptr, #blocked> diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index ebdccd3b7996..29d61e07a4e9 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -12,6 +12,7 @@ #C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}> #B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}> +#smem = #ttg.shared_memory // CHECK-LABEL: tt.func @matmul_loop // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 @@ -892,6 +893,7 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, #mma = 
#ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> #shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}> #shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // COMMON-LABEL: tt.func @load_two_users_incompatible_layouts tt.func @load_two_users_incompatible_layouts(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) { @@ -930,9 +932,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> %23 = ttg.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #ttg.shared_memory> - %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #ttg.shared_memory> -> !ttg.memdesc<16x64xf16, #shared1, #ttg.shared_memory> - %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #ttg.shared_memory> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = ttg.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem> + %25 = ttg.memdesc_trans %24 {order=array} : !ttg.memdesc<64x16xf16, #shared, #smem> -> !ttg.memdesc<16x64xf16, #shared1, #smem> + %26 = ttg.local_load %25 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } @@ -981,6 +983,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // For HIP, we only pipeline the inner loop for now. 
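// A minimal sketch of the intended shape after the pass (hypothetical structure, inferred from
// the CHECK lines below): only the inner scf.for is rewritten to prefetch operands through
// shared memory with async copies; the outer scf.for stays as written:
// scf.for %i = %c0_i32 to %ub step %c1_i32 : i32 {      // outer loop: left untouched
//   scf.for %j = %c0_i32 to %ub step %c1_i32 : i32 {    // inner loop: pipelined
//     ttg.async_copy_global_to_local ...                // prefetch into a mutable buffer
//   }
// }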
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> @@ -1041,7 +1044,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: scf.for // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} // CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] -// CHECK: %[[IND_BUFFER_0:.*]] = ttg.memdesc_subview {{.*}} : !ttg.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #ttg.shared_memory, mutable> -> !ttg.memdesc<16xi64, #[[$SHARED_LAYOUT]], #ttg.shared_memory, mutable> +// CHECK: %[[IND_BUFFER_0:.*]] = ttg.memdesc_subview {{.*}} : !ttg.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #smem, mutable> -> !ttg.memdesc<16xi64, #[[$SHARED_LAYOUT]], #smem, mutable, 1x16> // CHECK: %[[IND_BUFFER_1:.*]] = ttg.local_load %[[IND_BUFFER_0]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} // CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] @@ -1342,6 +1345,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 2], instrShape = [16, 8]}> #shared = #ttg.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}> #shared1 = #ttg.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [0, 1], hasLeadingOffset = false}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { tt.func public @nested_loops(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> @@ -1361,9 +1365,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} { %9 = tt.addptr %7, %8 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked> scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { %10 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> - %11 = ttg.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !ttg.memdesc<16x16xf32, #shared, #ttg.shared_memory> - %12 = ttg.memdesc_trans %11 {order = array} : !ttg.memdesc<16x16xf32, #shared, #ttg.shared_memory> -> !ttg.memdesc<16x16xf32, #shared1, #ttg.shared_memory> - %13 = ttg.local_load %12 : !ttg.memdesc<16x16xf32, #shared1, #ttg.shared_memory> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %11 = ttg.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !ttg.memdesc<16x16xf32, #shared, #smem> + %12 = ttg.memdesc_trans %11 {order = array} : !ttg.memdesc<16x16xf32, #shared, #smem> -> !ttg.memdesc<16x16xf32, #shared1, #smem> + %13 = ttg.local_load %12 : !ttg.memdesc<16x16xf32, #shared1, #smem> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { %14 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked> %15 = ttg.convert_layout %14 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> diff --git a/test/TritonGPU/ops.mlir 
b/test/TritonGPU/ops.mlir index 1513ac60e89b..0262bad35227 100644 --- a/test/TritonGPU/ops.mlir +++ b/test/TritonGPU/ops.mlir @@ -47,3 +47,15 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w tt.return } } + +// ----- + +#shared0 = #ttg.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: memdesc + // CHECK-SAME: !ttg.memdesc<1x64x16xf16, #{{.+}}> + tt.func @memdesc(%d : !ttg.memdesc<1x64x16xf16, #shared0, #smem>) { + tt.return + } +} diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index ab070a73a547..208516b3bfab 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -9,7 +9,7 @@ #C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #A_OP = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_OP = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> - +#smem = #ttg.shared_memory // CHECK: tt.func @matmul_loop_mixed // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 @@ -48,24 +48,24 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a_init = ttg.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A> + %a_init = ttg.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A, #smem> %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b_init = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B> + %b_init = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B, #smem> - %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A>, !ttg.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C>) { - %a_op_ = ttg.local_load %a : !ttg.memdesc<128x32xf8E5M2, #A> -> tensor<128x32xf8E5M2, #A_OP> + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A, #smem>, !ttg.memdesc<32x128xf16, #B, #smem>, tensor<128x128xf32, #C>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<128x32xf8E5M2, #A, #smem> -> tensor<128x32xf8E5M2, #A_OP> %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP> - %b_op = ttg.local_load %b : !ttg.memdesc<32x128xf16, #B> -> tensor<32x128xf16, #B_OP> + %b_op = ttg.local_load %b : !ttg.memdesc<32x128xf16, #B, #smem> -> tensor<32x128xf16, #B_OP> %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %next_a = ttg.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A> + %next_a = ttg.local_alloc %next_a_ 
: (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A, #smem> %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %next_b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B> + %next_b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B, #smem> - scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A>, !ttg.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C> + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A, #smem>, !ttg.memdesc<32x128xf16, #B, #smem>, tensor<128x128xf32, #C> } tt.return %loop#4 : tensor<128x128xf32, #C> } @@ -103,24 +103,24 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr %b_off = arith.constant dense<4> : tensor<16x128xi32, #BL> %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x16x!tt.ptr, #AL> - %a_init = ttg.local_alloc %a_ : (tensor<128x16xf8E5M2, #AL>) -> !ttg.memdesc<128x16xf8E5M2, #A> + %a_init = ttg.local_alloc %a_ : (tensor<128x16xf8E5M2, #AL>) -> !ttg.memdesc<128x16xf8E5M2, #A, #smem> %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<16x128x!tt.ptr, #BL> - %b_init = ttg.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !ttg.memdesc<16x128xf16, #B> + %b_init = ttg.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !ttg.memdesc<16x128xf16, #B, #smem> - %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !ttg.memdesc<128x16xf8E5M2, #A>, !ttg.memdesc<16x128xf16, #B>, tensor<128x128xf32, #C>) { - %a_op_ = ttg.local_load %a : !ttg.memdesc<128x16xf8E5M2, #A> -> tensor<128x16xf8E5M2, #A_OP> + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !ttg.memdesc<128x16xf8E5M2, #A, #smem>, !ttg.memdesc<16x128xf16, #B, #smem>, tensor<128x128xf32, #C>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<128x16xf8E5M2, #A, #smem> -> tensor<128x16xf8E5M2, #A_OP> %a_op = tt.fp_to_fp %a_op_ : tensor<128x16xf8E5M2, #A_OP> -> tensor<128x16xf16, #A_OP> - %b_op = ttg.local_load %b : !ttg.memdesc<16x128xf16, #B> -> tensor<16x128xf16, #B_OP> + %b_op = ttg.local_load %b : !ttg.memdesc<16x128xf16, #B, #smem> -> tensor<16x128xf16, #B_OP> %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x16xf16, #A_OP> * tensor<16x128xf16, #B_OP> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x16x!tt.ptr, #AL>, tensor<128x16xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<16x128x!tt.ptr, #BL>, tensor<16x128xi32, #BL> %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<128x16x!tt.ptr, #AL> - %next_a = ttg.local_alloc %next_a_ : (tensor<128x16xf8E5M2, #AL>) -> !ttg.memdesc<128x16xf8E5M2, #A> + %next_a = ttg.local_alloc %next_a_ : (tensor<128x16xf8E5M2, #AL>) -> !ttg.memdesc<128x16xf8E5M2, #A, #smem> %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<16x128x!tt.ptr, #BL> - %next_b = ttg.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !ttg.memdesc<16x128xf16, #B> + %next_b = ttg.local_alloc %b_ : (tensor<16x128xf16, #BL>) -> !ttg.memdesc<16x128xf16, #B, #smem> - scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : 
tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !ttg.memdesc<128x16xf8E5M2, #A>, !ttg.memdesc<16x128xf16, #B>, tensor<128x128xf32, #C> + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x16x!tt.ptr, #AL>, tensor<16x128x!tt.ptr, #BL>, !ttg.memdesc<128x16xf8E5M2, #A, #smem>, !ttg.memdesc<16x128xf16, #B, #smem>, tensor<128x128xf32, #C> } tt.return %loop#4 : tensor<128x128xf32, #C> } @@ -183,6 +183,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #C = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = false}> #A_OP = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_OP = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> +#smem = #ttg.shared_memory // CHECK: tt.func @matmul_loop_mixed_amd // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 @@ -221,27 +222,25 @@ tt.func @matmul_loop_mixed_amd(%lb : index, %ub : index, %step : index, %A : !tt %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> %a_ = tt.load %a_ptr_init, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %a_init = ttg.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A> + %a_init = ttg.local_alloc %a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A, #smem> %b_ = tt.load %b_ptr_init, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %b_init = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B> + %b_init = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B, #smem> - %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A>, !ttg.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C>) { - %a_op_ = ttg.local_load %a : !ttg.memdesc<128x32xf8E5M2, #A> -> tensor<128x32xf8E5M2, #A_OP> + %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A, #smem>, !ttg.memdesc<32x128xf16, #B, #smem>, tensor<128x128xf32, #C>) { + %a_op_ = ttg.local_load %a : !ttg.memdesc<128x32xf8E5M2, #A, #smem> -> tensor<128x32xf8E5M2, #A_OP> %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP> - %b_op = ttg.local_load %b : !ttg.memdesc<32x128xf16, #B> -> tensor<32x128xf16, #B_OP> + %b_op = ttg.local_load %b : !ttg.memdesc<32x128xf16, #B, #smem> -> tensor<32x128xf16, #B_OP> %c = tt.dot %a_op, %b_op, %prev_c : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other : tensor<128x32x!tt.ptr, #AL> - %next_a = ttg.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A> + %next_a = ttg.local_alloc %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> !ttg.memdesc<128x32xf8E5M2, #A, #smem> %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr, #BL> - %next_b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B> + %next_b = ttg.local_alloc %b_ : (tensor<32x128xf16, #BL>) -> !ttg.memdesc<32x128xf16, #B, #smem> - scf.yield 
%next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A>, !ttg.memdesc<32x128xf16, #B>, tensor<128x128xf32, #C> + scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, !ttg.memdesc<128x32xf8E5M2, #A, #smem>, !ttg.memdesc<32x128xf16, #B, #smem>, tensor<128x128xf32, #C> } tt.return %loop#4 : tensor<128x128xf32, #C> } } // end module - -// ----- diff --git a/test/TritonGPU/reduce-data-duplication.mlir b/test/TritonGPU/reduce-data-duplication.mlir index bbb0de2ad1e1..e293ab724847 100644 --- a/test/TritonGPU/reduce-data-duplication.mlir +++ b/test/TritonGPU/reduce-data-duplication.mlir @@ -2,7 +2,7 @@ // CHECK: #[[$SHARED:.*]] = #ttg.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false} // CHECK-LABEL: apply_swizzle -// CHECK: %{{.*}} = ttg.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !ttg.memdesc<16x256xf16, #[[$SHARED]], #ttg.shared_memory> +// CHECK: %{{.*}} = ttg.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !ttg.memdesc<16x256xf16, #[[$SHARED]], #smem> #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}> diff --git a/test/TritonGPU/reorder-instructions.mlir b/test/TritonGPU/reorder-instructions.mlir index c95a1cedc5ad..700ed22be2b1 100644 --- a/test/TritonGPU/reorder-instructions.mlir +++ b/test/TritonGPU/reorder-instructions.mlir @@ -8,13 +8,14 @@ #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @convert_cannot_hoist(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %9 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %10 = ttg.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared> - %11 = ttg.local_load %10 : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %10 = ttg.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %11 = ttg.local_load %10 : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> @@ -26,20 +27,21 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} // CHECK-LABEL: sink_convert_dealloc // CHECK: ttg.async_wait {num = 0 : i32} -// CHECK: ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, mutable> -// CHECK: ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, mutable> +// CHECK: ttg.local_dealloc %0 : 
!ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> +// CHECK: ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> // CHECK: %3 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @sink_convert_dealloc(%arg0: tensor<32x32xf32, #blocked>) attributes {noinline = false} { - %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, mutable> - %1 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, mutable> + %0 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + %1 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> %2 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1> ttg.async_wait {num = 0 : i32} - ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, mutable> - ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, mutable> + ttg.local_dealloc %0 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> + ttg.local_dealloc %1 : !ttg.memdesc<4x128x64xf16, #shared, #smem, mutable> %3 = arith.addf %2, %2 : tensor<32x32xf32, #blocked1> tt.return } @@ -48,22 +50,23 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} // ----- // CHECK-LABEL: sink_convert_idx_1 -// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> -// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> // CHECK: tt.dot #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @sink_convert_idx_1(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %B = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %BS = ttg.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared> - %BD = ttg.local_load %BS : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %BS = ttg.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %BD = ttg.local_load %BS : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %A = tt.load %arg0 : 
tensor<32x32x!tt.ptr, #blocked> - %AS = ttg.local_alloc %A : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared> - %AD = ttg.local_load %AS : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %AS = ttg.local_alloc %A : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %AD = ttg.local_load %AS : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> %12 = tt.dot %AD, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> tt.store %arg0, %13 : tensor<32x32x!tt.ptr, #blocked> @@ -75,27 +78,28 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} // check that we don't sink convert_layout if it has multiple users // CHECK-LABEL: convert_cannot_sink -// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> // CHECK: tt.dot -// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> +// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> // CHECK: tt.dot #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2]}> #shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @convert_cannot_sink(%arg0: tensor<32x32x!tt.ptr, #blocked>) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> %B = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %BS = ttg.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared> - %BD = ttg.local_load %BS : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %BS = ttg.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %BD = ttg.local_load %BS : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %A0 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %AS0 = ttg.local_alloc %A0 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared> - %AD0 = ttg.local_load %AS0 : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, 
#ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> %12 = tt.dot %AD0, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> %A1 = tt.load %arg0 : tensor<32x32x!tt.ptr, #blocked> - %AS1 = ttg.local_alloc %A1 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared> - %AD1 = ttg.local_load %AS1 : !ttg.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %AS1 = ttg.local_alloc %A1 : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem> + %AD1 = ttg.local_load %AS1 : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> %13 = tt.dot %AD1, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma> tt.return } diff --git a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir index a73d4259c185..2e95f5024f41 100644 --- a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir +++ b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir @@ -10,6 +10,7 @@ // CHECK: #[[$ATTR_2:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> // CHECK: #[[$ATTR_3:.+]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> // CHECK: #[[$ATTR_4:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +// CHECK: #[[$ATTR_5:.+]] = #ttg.shared_memory // To regenerate this test case, run the command // triton-opt test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \ // utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="\bmodule" \ @@ -54,33 +55,33 @@ // CHECK: %[[VAL_40:.*]] = arith.muli %[[VAL_33]], %[[VAL_11]] : i32 // CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_5]], %[[VAL_18]] : i32 // CHECK: %[[VAL_42:.*]] = arith.divsi %[[VAL_41]], %[[VAL_13]] : i32 -// CHECK: %[[VAL_43:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_44:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_45:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_46:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: ttng.init_barrier %[[VAL_46]], 1 : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_47:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: ttng.init_barrier %[[VAL_47]], 1 : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_48:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_6]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: ttng.init_barrier %[[VAL_48]], 1 : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: 
%[[VAL_43:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_44:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_45:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_46:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_46]], 1 : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_47:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_47]], 1 : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_48:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_6]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.init_barrier %[[VAL_48]], 1 : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> // CHECK: %[[VAL_49:.*]] = arith.cmpi sgt, %[[VAL_42]], %[[VAL_12]] : i32 -// CHECK: %[[VAL_50:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: ttng.barrier_expect %[[VAL_50]], 49152, %[[VAL_49]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_51:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_12]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_50:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_50]], 49152, %[[VAL_49]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_51:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_12]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> // CHECK: %[[VAL_52:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_35]] : !tt.tensordesc> to !tt.ptr -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_52]]{{\[}}%[[VAL_39]], %[[VAL_12]]] %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : , <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> <128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_53:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_12]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_52]]{{\[}}%[[VAL_39]], %[[VAL_12]]] %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : , <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> <128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_53:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_12]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> // CHECK: %[[VAL_54:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_36]] 
: !tt.tensordesc> to !tt.ptr -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_54]]{{\[}}%[[VAL_40]], %[[VAL_12]]] %[[VAL_53]], %[[VAL_50]], %[[VAL_49]] : , <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> <256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_54]]{{\[}}%[[VAL_40]], %[[VAL_12]]] %[[VAL_53]], %[[VAL_50]], %[[VAL_49]] : , <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> <256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> // CHECK: %[[VAL_55:.*]] = arith.cmpi sgt, %[[VAL_42]], %[[VAL_15]] : i32 -// CHECK: %[[VAL_56:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: ttng.barrier_expect %[[VAL_56]], 49152, %[[VAL_55]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_57:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_15]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_56:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_56]], 49152, %[[VAL_55]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_57:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_15]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> // CHECK: %[[VAL_58:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_35]] : !tt.tensordesc> to !tt.ptr -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_58]]{{\[}}%[[VAL_39]], %[[VAL_13]]] %[[VAL_57]], %[[VAL_56]], %[[VAL_55]] : , <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> <128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_59:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_15]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_58]]{{\[}}%[[VAL_39]], %[[VAL_13]]] %[[VAL_57]], %[[VAL_56]], %[[VAL_55]] : , <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> <128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_59:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_15]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> // CHECK: %[[VAL_60:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_36]] : !tt.tensordesc> to !tt.ptr -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_60]]{{\[}}%[[VAL_40]], %[[VAL_13]]] %[[VAL_59]], %[[VAL_56]], %[[VAL_55]] : , <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> <256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_60]]{{\[}}%[[VAL_40]], %[[VAL_13]]] %[[VAL_59]], %[[VAL_56]], %[[VAL_55]] : , <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> <256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> // CHECK: %[[VAL_61:.*]]:5 = scf.for %[[VAL_62:.*]] = %[[VAL_12]] to %[[VAL_42]] step %[[VAL_15]] iter_args(%[[VAL_63:.*]] = %[[VAL_19]], %[[VAL_64:.*]] = %[[VAL_13]], %[[VAL_65:.*]] 
= %[[VAL_15]], %[[VAL_66:.*]] = %[[VAL_8]], %[[VAL_67:.*]] = %[[VAL_12]]) -> (tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32) : i32 { // CHECK: %[[VAL_68:.*]] = arith.subi %[[VAL_42]], %[[VAL_6]] : i32 // CHECK: %[[VAL_69:.*]] = arith.cmpi slt, %[[VAL_62]], %[[VAL_68]] : i32 @@ -89,37 +90,37 @@ // CHECK: %[[VAL_72:.*]] = arith.select %[[VAL_71]], %[[VAL_70]], %[[VAL_12]] : i32 // CHECK: %[[VAL_73:.*]] = arith.xori %[[VAL_67]], %[[VAL_15]] : i32 // CHECK: %[[VAL_74:.*]] = arith.select %[[VAL_71]], %[[VAL_67]], %[[VAL_73]] : i32 -// CHECK: %[[VAL_75:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_72]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: ttng.wait_barrier %[[VAL_75]], %[[VAL_74]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_76:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_72]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_77:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_72]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_78:.*]] = ttg.memdesc_trans %[[VAL_76]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_79:.*]] = ttng.warp_group_dot %[[VAL_77]], %[[VAL_78]], %[[VAL_63]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> * !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #ttg.shared_memory, mutable> -> tensor<128x256xf32, #[[$ATTR_1]]> -// CHECK: %[[VAL_80:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_79]], %[[VAL_77]], %[[VAL_78]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_1]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_75:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_72]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.wait_barrier %[[VAL_75]], %[[VAL_74]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_76:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_72]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_77:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_72]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_78:.*]] = ttg.memdesc_trans %[[VAL_76]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> -> !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_79:.*]] = ttng.warp_group_dot %[[VAL_77]], %[[VAL_78]], %[[VAL_63]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> -> tensor<128x256xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_80:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_79]], %[[VAL_77]], 
%[[VAL_78]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_1]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_81:.*]] = arith.addi %[[VAL_64]], %[[VAL_13]] : i32 // CHECK: %[[VAL_82:.*]] = arith.addi %[[VAL_65]], %[[VAL_15]] : i32 // CHECK: %[[VAL_83:.*]] = arith.cmpi slt, %[[VAL_82]], %[[VAL_7]] : i32 // CHECK: %[[VAL_84:.*]] = arith.select %[[VAL_83]], %[[VAL_82]], %[[VAL_12]] : i32 -// CHECK: %[[VAL_85:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_84]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: ttng.barrier_expect %[[VAL_85]], 49152, %[[VAL_69]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_86:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_84]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_85:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_84]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_85]], 49152, %[[VAL_69]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_86:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_84]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> // CHECK: %[[VAL_87:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_35]] : !tt.tensordesc> to !tt.ptr -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_87]]{{\[}}%[[VAL_39]], %[[VAL_81]]] %[[VAL_86]], %[[VAL_85]], %[[VAL_69]] : , <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> <128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_88:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_84]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_87]]{{\[}}%[[VAL_39]], %[[VAL_81]]] %[[VAL_86]], %[[VAL_85]], %[[VAL_69]] : , <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> <128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> +// CHECK: %[[VAL_88:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_84]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> // CHECK: %[[VAL_89:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_36]] : !tt.tensordesc> to !tt.ptr -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_89]]{{\[}}%[[VAL_40]], %[[VAL_81]]] %[[VAL_88]], %[[VAL_85]], %[[VAL_69]] : , <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> <256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_89]]{{\[}}%[[VAL_40]], %[[VAL_81]]] %[[VAL_88]], %[[VAL_85]], %[[VAL_69]] : , <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> -> <256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> // CHECK: scf.yield %[[VAL_80]]#0, %[[VAL_81]], %[[VAL_84]], %[[VAL_72]], %[[VAL_74]] : tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32 // CHECK: } // CHECK: %[[VAL_90:.*]] = ttng.warp_group_dot_wait %[[VAL_91:.*]]#0 {pendings = 0 : i32} : tensor<128x256xf32, 
#[[$ATTR_1]]> // CHECK: %[[VAL_92:.*]] = ttg.async_wait {num = 0 : i32} -// CHECK: %[[VAL_93:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: ttng.inval_barrier %[[VAL_93]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_94:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: ttng.inval_barrier %[[VAL_94]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: %[[VAL_95:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_6]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: ttng.inval_barrier %[[VAL_95]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -// CHECK: ttg.local_dealloc %[[VAL_43]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -// CHECK: ttg.local_dealloc %[[VAL_44]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_93:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_93]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_94:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_94]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: %[[VAL_95:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_6]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_95]] : <1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3> +// CHECK: ttg.local_dealloc %[[VAL_43]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttg.local_dealloc %[[VAL_44]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_96:.*]] = arith.truncf %[[VAL_90]] : tensor<128x256xf32, #[[$ATTR_1]]> to tensor<128x256xf16, #[[$ATTR_1]]> // CHECK: %[[VAL_97:.*]] = ttg.convert_layout %[[VAL_96]] : tensor<128x256xf16, #[[$ATTR_1]]> -> tensor<128x256xf16, #[[$ATTR_0]]> // CHECK: tt.experimental_descriptor_store %[[VAL_38]]{{\[}}%[[VAL_39]], %[[VAL_40]]], %[[VAL_97]] : !tt.tensordesc<tensor<128x256xf16>>, tensor<128x256xf16, #[[$ATTR_0]]> diff --git a/test/TritonGPU/tritongpu_ops.mlir b/test/TritonGPU/tritongpu_ops.mlir deleted file mode 100644 index d3dc5277e17f..000000000000 --- a/test/TritonGPU/tritongpu_ops.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: triton-opt %s | triton-opt | FileCheck %s - -#shared0 = #ttg.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset=true}> - -module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: memdesc - // CHECK-SAME: !ttg.memdesc<1x64x16xf16, #{{.+}}> - tt.func @memdesc(%d : !ttg.memdesc<1x64x16xf16, #shared0>) { - tt.return - } -} diff --git a/test/TritonNvidiaGPU/membar.mlir b/test/TritonNvidiaGPU/membar.mlir index 2085cbf21323..a042e282b374 100644
--- a/test/TritonNvidiaGPU/membar.mlir +++ b/test/TritonNvidiaGPU/membar.mlir @@ -2,6 +2,7 @@ #shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: init_barrier // CHECK: local_alloc @@ -9,8 +10,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: init_barrier tt.func @init_barrier() { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> - ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem, mutable> tt.return } } @@ -19,6 +20,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: inval_barrier // CHECK: local_alloc @@ -28,9 +30,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: inval_barrier tt.func @inval_barrier() { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> - ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> - ttng.inval_barrier %alloc : !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.inval_barrier %alloc : !ttg.memdesc<1xi64, #shared0, #smem, mutable> tt.return } } @@ -39,6 +41,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: barrier_expect // CHECK: local_alloc @@ -48,9 +51,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: barrier_expect tt.func @barrier_expect(%pred : i1) { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> - ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> - ttng.barrier_expect %alloc, 16384, %pred : <1xi64, #shared0, #ttg.shared_memory, mutable> + %alloc = ttg.local_alloc %cst 
: (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.barrier_expect %alloc, 16384, %pred : <1xi64, #shared0, #smem, mutable> tt.return } } @@ -59,6 +62,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: wait_barrier // CHECK: local_alloc // CHECK: local_alloc // CHECK-NEXT: gpu.barrier // CHECK-NEXT: init_barrier @@ -68,9 +72,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: wait_barrier tt.func @wait_barrier(%phase : i32) { %cst = arith.constant dense<0> : tensor<1xi64, #blocked0> - %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> - ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #ttg.shared_memory, mutable> - ttng.wait_barrier %alloc, %phase : <1xi64, #shared0, #ttg.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared0, #smem, mutable> + ttng.wait_barrier %alloc, %phase : <1xi64, #shared0, #smem, mutable> tt.return } } @@ -80,6 +84,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { #shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @tma_load(%arg0: !tt.tensordesc<tensor<128x64xf16>>, %arg1: i32) -> tensor<128x64xf16, #blocked0> { // CHECK-LABEL: tma_load // CHECK: local_alloc // CHECK: local_dealloc // CHECK-NEXT: gpu.barrier // CHECK-NEXT: init_barrier %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> - %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared0, #ttg.shared_memory, mutable> - ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared0, #ttg.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> + ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.tensordesc<tensor<128x64xf16>> -> tensor<128x64xf16, #blocked0> tt.return %l : tensor<128x64xf16, #blocked0> } @@ -100,6 +105,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #shared0 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tma_store // CHECK: ttg.local_alloc // CHECK-NEXT: gpu.barrier // CHECK-NEXT: ttg.local_dealloc @@ -108,8 +114,8
@@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-NEXT: ttg.local_alloc tt.func public @tma_store(%arg0: !tt.tensordesc<tensor<128x256xf32>>, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked0>) { %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0> - %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared0, #ttg.shared_memory, mutable> - ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared0, #ttg.shared_memory, mutable> + %alloc = ttg.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> + ttg.local_dealloc %alloc : !ttg.memdesc<128x64xi64, #shared0, #smem, mutable> tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.tensordesc<tensor<128x256xf32>>, tensor<128x256xf32, #blocked0> tt.return } From 01fb036c6100925f72bd18a63ad271dad26329a3 Mon Sep 17 00:00:00 2001 From: Shucai Xiao Date: Tue, 3 Dec 2024 14:05:33 -0600 Subject: [PATCH 2/7] [Pipeliner] Handle masking for atomic_rmw (#5231) This commit adds support for atomic_rmw in the predicateOp function, so that atomic operations are masked correctly when the pipeliner schedules them. --- .../Pipeliner/PipeliningUtility.cpp | 7 ++ test/TritonGPU/loop-pipeline-hip.mlir | 68 +++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index 2305f30beb06..5fd98355f9c5 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -80,6 +80,13 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, storeOp.getMaskMutable().assign(mask); return op; } + if (auto atomicRMWOp = dyn_cast<tt::AtomicRMWOp>(op)) { + rewriter.setInsertionPoint(atomicRMWOp); + Value mask = getPredMask(rewriter, atomicRMWOp.getPtr().getType(), + atomicRMWOp.getMask(), pred); + atomicRMWOp.getMaskMutable().assign(mask); + return op; + } assert("don't know how to predicate this op" && false); return op; diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index a0453ab928db..d33335f5a519 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -266,3 +266,71 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } + +// ----- + +// Check that the stream pipeliner updates atomic op in the k-loop correctly +// CHECK-LABEL: _triton_gemm_kernel_atomic_rmw +// CHECK: scf.for +// CHECK: tt.atomic_rmw fadd, acq_rel, gpu +// CHECK: tt.dot +// CHECK: scf.yield + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @_triton_gemm_kernel_atomic_rmw(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg3: i32 {tt.divisibility = 16 : i32} loc(unknown), %arg4: i32 {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { %cst =
arith.constant dense<32> : tensor<32x32xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c31_i32 = arith.constant 31 : i32 + %c32_i32 = arith.constant 32 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %2 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked> + %3 = arith.muli %1, %2 : tensor<32x1xi32, #blocked> + %4 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %6 = tt.broadcast %3 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked> + %7 = tt.broadcast %5 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> + %8 = arith.addi %6, %7 : tensor<32x32xi32, #blocked> + %9 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked> + %10 = tt.addptr %9, %8 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked> + %11 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked> + %12 = tt.addptr %11, %8 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked> + %13 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<32x1x!tt.ptr<f16>, #blocked> + %14 = tt.addptr %13, %3 : tensor<32x1x!tt.ptr<f16>, #blocked>, tensor<32x1xi32, #blocked> + %15 = tt.broadcast %14 : tensor<32x1x!tt.ptr<f16>, #blocked> -> tensor<32x32x!tt.ptr<f16>, #blocked> + %16 = tt.addptr %15, %7 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked> + %17 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked> + %18 = arith.cmpi slt, %1, %17 : tensor<32x1xi32, #blocked> + %19 = tt.splat %arg3 : i32 -> tensor<1x32xi32, #blocked> + %20 = arith.cmpi slt, %5, %19 : tensor<1x32xi32, #blocked> + %21 = tt.broadcast %18 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked> + %22 = tt.broadcast %20 : tensor<1x32xi1, #blocked> -> tensor<32x32xi1, #blocked> + %23 = arith.andi %21, %22 : tensor<32x32xi1, #blocked> + %24 = arith.addi %arg3, %c31_i32 : i32 + %25 = arith.divsi %24, %c32_i32 : i32 + %26 = arith.muli %arg4, %c32_i32 : i32 + %27 = tt.splat %26 : i32 -> tensor<32x32xi32, #blocked> + %28:3 = scf.for %arg5 = %c0_i32 to %25 step %c1_i32 iter_args(%arg6 = %cst_0, %arg7 = %10, %arg8 = %12) -> (tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32x!tt.ptr<f16>, #blocked>) : i32 { + %32 = tt.load %arg7 : tensor<32x32x!tt.ptr<f16>, #blocked> + %33 = tt.load %arg8 : tensor<32x32x!tt.ptr<f16>, #blocked> + %34 = triton_gpu.convert_layout %32 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %35 = triton_gpu.convert_layout %33 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %36 = tt.dot %34, %35, %arg6 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma> + %37 = tt.addptr %arg7, %cst : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked> + %38 = tt.addptr %arg8, %27 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked> + %39 = arith.truncf %36 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> + %40 = triton_gpu.convert_layout %39 :
tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked> + %41 = tt.atomic_rmw fadd, acq_rel, gpu, %16, %40, %23 : (tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked> + scf.yield %36, %37, %38 : tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32x!tt.ptr<f16>, #blocked> + } + %29 = arith.truncf %28#0 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> + %30 = triton_gpu.convert_layout %16 : tensor<32x32x!tt.ptr<f16>, #blocked> -> tensor<32x32x!tt.ptr<f16>, #mma> + %31 = triton_gpu.convert_layout %23 : tensor<32x32xi1, #blocked> -> tensor<32x32xi1, #mma> + tt.store %30, %29, %31 : tensor<32x32x!tt.ptr<f16>, #mma> + tt.return + } +} From d3a94e01dc1f3c2f96b0f4657d98f14c76e97915 Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Wed, 4 Dec 2024 00:02:42 +0000 Subject: [PATCH 3/7] [TESTS] Forward fix for CI break (#5323) PR #5231 was authored before the `triton_gpu` -> `ttg` rename and CI is currently broken. --- test/TritonGPU/loop-pipeline-hip.mlir | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index d33335f5a519..49f67bd076d4 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -276,9 +276,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: tt.dot // CHECK: scf.yield -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { tt.func public @_triton_gemm_kernel_atomic_rmw(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg3: i32 {tt.divisibility = 16 : i32} loc(unknown), %arg4: i32 {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { %cst = arith.constant dense<32> : tensor<32x32xi32, #blocked> %c0_i32 = arith.constant 0 : i32 @@ -286,12 +286,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %c31_i32 = arith.constant 31 : i32 %c32_i32 = arith.constant 32 : i32 %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked> %2
= tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked> %3 = arith.muli %1, %2 : tensor<32x1xi32, #blocked> - %4 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> + %4 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked> %6 = tt.broadcast %3 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked> %7 = tt.broadcast %5 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked> %8 = arith.addi %6, %7 : tensor<32x32xi32, #blocked> @@ -317,19 +317,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %28:3 = scf.for %arg5 = %c0_i32 to %25 step %c1_i32 iter_args(%arg6 = %cst_0, %arg7 = %10, %arg8 = %12) -> (tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32x!tt.ptr<f16>, #blocked>) : i32 { %32 = tt.load %arg7 : tensor<32x32x!tt.ptr<f16>, #blocked> %33 = tt.load %arg8 : tensor<32x32x!tt.ptr<f16>, #blocked> - %34 = triton_gpu.convert_layout %32 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %35 = triton_gpu.convert_layout %33 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %36 = tt.dot %34, %35, %arg6 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma> + %34 = ttg.convert_layout %32 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %35 = ttg.convert_layout %33 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %36 = tt.dot %34, %35, %arg6 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma> %37 = tt.addptr %arg7, %cst : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked> %38 = tt.addptr %arg8, %27 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked> %39 = arith.truncf %36 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> - %40 = triton_gpu.convert_layout %39 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked> + %40 = ttg.convert_layout %39 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked> %41 = tt.atomic_rmw fadd, acq_rel, gpu, %16, %40, %23 : (tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked> scf.yield %36, %37, %38 : tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32x!tt.ptr<f16>, #blocked> } %29 = arith.truncf %28#0 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma> - %30 = triton_gpu.convert_layout %16 : tensor<32x32x!tt.ptr<f16>, #blocked> -> tensor<32x32x!tt.ptr<f16>, #mma> - %31 = triton_gpu.convert_layout %23 : tensor<32x32xi1, #blocked> -> tensor<32x32xi1, #mma> + %30 = ttg.convert_layout %16 : tensor<32x32x!tt.ptr<f16>, #blocked> -> tensor<32x32x!tt.ptr<f16>, #mma> + %31 = ttg.convert_layout %23 : tensor<32x32xi1, #blocked> -> tensor<32x32xi1, #mma> tt.store %30, %29, %31 : tensor<32x32x!tt.ptr<f16>, #mma> tt.return } From fa0c2bdfa4b907700624d7dd6ffbba2f9f8e10e4 Mon Sep 17 00:00:00
2001 From: Anatoly Myachev Date: Wed, 4 Dec 2024 01:27:55 +0100 Subject: [PATCH 4/7] Search for `ptxas` only for cuda backend in `supports_tma` function (#5314) For other backends, `ptxas` may not be installed. Signed-off-by: Anatoly Myachev --- python/triton/_internal_testing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index 5ba0be1e34f9..377eed877a51 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -142,11 +142,13 @@ def to_numpy(x): def supports_tma(byval_only=False): + if not is_cuda(): + return False _, cuda_version = _path_to_binary("ptxas") min_cuda_version = (12, 0) if byval_only else (12, 3) cuda_version_tuple = tuple(map(int, cuda_version.split("."))) assert len(cuda_version_tuple) == 2, cuda_version_tuple - return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version + return torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version def tma_skip_msg(byval_only=False): From 1d5e9a2470dbeb9dbd218099226d213786cead9a Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Wed, 4 Dec 2024 04:37:49 +0000 Subject: [PATCH 5/7] [LLVM] Update to llvm/llvm-project@1f20eee6dc36 (#5308) This pulls in the AMDGPU backend support for the gfx950 target. We need to fix the rewrites in `Combine.td` given that https://github.com/llvm/llvm-project/pull/112700 adds a new attribute for denorm mode for `arith.addf`. --------- Co-authored-by: Lei Zhang --- cmake/llvm-hash.txt | 2 +- lib/Dialect/Triton/Transforms/Combine.td | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt index 50d024794663..0952ab984cc9 100644 --- a/cmake/llvm-hash.txt +++ b/cmake/llvm-hash.txt @@ -1 +1 @@ -86b69c31642e98f8357df62c09d118ad1da4e16a +1f20eee6dc367bd202895e3eedb03974a628ef16 diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td index e3588f587757..1f1de2c717bd 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -17,7 +17,7 @@ def CombineDotAddIPattern : Pat< [(Constraint<CPred<"isZero($0)">> $c), (Constraint<CPred<"$0.getDefiningOp()->hasOneUse()">, "dot result has a single use"> $res)]>; def CombineDotAddFPattern : Pat< - (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath), + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath, $denorm), (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), [(Constraint<CPred<"isZero($0)">> $c), (Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc), @@ -29,7 +29,7 @@ def CombineDotAddIRevPattern : Pat< [(Constraint<CPred<"isZero($0)">> $c), (Constraint<CPred<"$0.getDefiningOp()->hasOneUse()">, "dot result has a single use"> $res)]>; def CombineDotAddFRevPattern : Pat< - (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath), + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath, $denorm), (TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)), [(Constraint<CPred<"isZero($0)">> $c), (Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc), From 134b3eb6d75742dd73aa28df0b3ec734acd8e228 Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Wed, 4 Dec 2024 05:19:31 +0000 Subject: [PATCH 6/7] [AMD][BACKEND] Add gfx950 target definitions. (#5281) Enable new arch target since backend support has been added.
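For context, here is a minimal sketch (illustrative only, not part of this patch; the helper name and CDNA classification are assumptions) of how an arch string such as "gfx950" can be classified with LLVM's TargetParser, mirroring the switch extended in TargetUtils.cpp below:

#include "llvm/TargetParser/TargetParser.h"

// Sketch: gfx950 joins the CDNA-family targets recognized by the backend.
static bool isCDNA(llvm::StringRef arch) {
  switch (llvm::AMDGPU::parseArchAMDGCN(arch)) {
  case llvm::AMDGPU::GK_GFX950: // target newly enabled by this patch
  case llvm::AMDGPU::GK_GFX942:
  case llvm::AMDGPU::GK_GFX941:
  case llvm::AMDGPU::GK_GFX940:
    return true;
  default:
    return false;
  }
}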
--- third_party/amd/backend/include/hsa/amd_hsa_elf.h | 2 +- third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/amd/backend/include/hsa/amd_hsa_elf.h b/third_party/amd/backend/include/hsa/amd_hsa_elf.h index 74f15d7d7ab6..65a77f041bf3 100644 --- a/third_party/amd/backend/include/hsa/amd_hsa_elf.h +++ b/third_party/amd/backend/include/hsa/amd_hsa_elf.h @@ -136,7 +136,7 @@ enum : unsigned { EF_AMDGPU_MACH_AMDGCN_GFX942 = 0x04c, EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4D = 0x04d, EF_AMDGPU_MACH_AMDGCN_GFX1201 = 0x04e, - EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4F = 0x04f, + EF_AMDGPU_MACH_AMDGCN_GFX950 = 0x04f, EF_AMDGPU_MACH_AMDGCN_RESERVED_0X50 = 0x050, EF_AMDGPU_MACH_AMDGCN_GFX9_GENERIC = 0x051, EF_AMDGPU_MACH_AMDGCN_GFX10_1_GENERIC = 0x052, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp index 63fb972f7903..7ab6fd68a5d5 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp @@ -11,6 +11,7 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) { // CDNA ISA cases switch (kind) { + case llvm::AMDGPU::GK_GFX950: case llvm::AMDGPU::GK_GFX942: case llvm::AMDGPU::GK_GFX941: case llvm::AMDGPU::GK_GFX940: From 98a9664915aed179b84d92ae39849178e6c4a7c3 Mon Sep 17 00:00:00 2001 From: Pablo Zimmermann Date: Tue, 3 Dec 2024 16:53:05 +0100 Subject: [PATCH 7/7] Generalize dot-like op handling --- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 17 ++++++++--------- .../Transforms/FenceInsertion.cpp | 2 +- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 3b29f73e1d7a..c5dc7e4d03ea 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -44,8 +44,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { return 0; } -SmallVector<unsigned, 3> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape, - int numWarps) { +SmallVector<unsigned, 3> +warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) { auto rank = shape.size(); // Early exit for batched matmul if (rank == 3) @@ -58,9 +58,8 @@ SmallVector<unsigned, 3> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape, auto slices = multiRootGetSlice(dotOp, {filter}, {filter}); bool hasChainedDot = false; for (Operation *op : slices) { - if (isa<DotOp>(op) && (op != dotOp)) { - auto chainedDot = cast<DotOp>(op); - auto resTy = chainedDot.getResult().getType(); + if (op->hasTrait<OpTrait::DotLike>() && op != dotOp) { + auto resTy = cast<RankedTensorType>(op->getResult(0).getType()); if (resTy.getRank() != rank) { continue; } @@ -109,14 +108,14 @@ SmallVector<unsigned, 3> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape, } SmallVector<unsigned, 3> -warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps, +warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps, const SmallVector<unsigned, 3> &instrShape) { SetVector<Operation *> slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); + mlir::getForwardSlice(dotOp->getResult(0), &slices); // Contains a chained dot. We prefer to assign warps to one axis // to facilitate use cases like flash attention, allowing reductions within // the same warp.
- if (llvm::find_if(slices, [](Operation *op) { + if (llvm::find_if(slices, [&](Operation *op) { return op->hasTrait<OpTrait::DotLike>(); }) != slices.end()) return {(unsigned)numWarps, 1}; @@ -171,7 +170,7 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, } SmallVector<unsigned, 3> -getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version, +getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version, int numWarps, const SmallVector<unsigned, 3> &instrShape) { switch (version) { case 2: diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index fb0e7f6fdb18..62ed71c175fd 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -44,7 +44,7 @@ struct FenceInsertionPass return; ModuleOp mod = getOperation(); mod.walk([&](Operation *op) { - if (!isa<ttng::WarpGroupDotOp>(op)) + if (!op->hasTrait<OpTrait::DotLike>()) return WalkResult::advance(); OpBuilder builder(op); auto a = op->getOperand(0);
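To make the generalization concrete, here is a minimal sketch of the trait-based dispatch that replaces the concrete-class checks (the helper name is illustrative; it assumes Triton's DotLike trait lives in mlir::OpTrait, as used in the hunks above):

#include "mlir/IR/Operation.h"

// Sketch: any op carrying the DotLike trait (tt.dot, ttng.warp_group_dot, ...)
// now takes the same code path that previously matched a single op class.
static bool isDotLike(mlir::Operation *op) {
  return op->hasTrait<mlir::OpTrait::DotLike>();
}

With this, passes such as AccelerateMatmul and FenceInsertion treat every dot-like operation uniformly instead of special-casing tt.dot or ttng.warp_group_dot.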