-
Notifications
You must be signed in to change notification settings - Fork 12.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][nvvm] Support predicates in BasicPtxBuilder
#67102
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm ChangesThis PR enhances In PTX programming, instructions can be guarded by predicates as shown below:. Here
This PR introduces the Additionally, this PR implements predicate usage for the following ops:
See for more detail in PTX programing model Patch is 27.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67102.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a528e015523e174..a1cfb305d8d5e50 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -86,6 +86,8 @@ class NVVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
// Basic PTX Builder Interface
//===----------------------------------------------------------------------===//
+def PtxPredicate : Optional<I1>;
+
// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#parameters
def Read : I32EnumAttrCase<"Read", 0, "read">;
def Write : I32EnumAttrCase<"Write", 2, "write">;
@@ -118,36 +120,49 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
is started from the results and they are used as write, followed by the
operands and attributes.
+ `getPredicate` is an optional function for setting a predicate, which
+ always returns a `PtxPredicate` value of type i1. If no predicate is
+ provided, the instruction is unguarded; otherwise, it's guarded by the
+ predicate value. The `PtxPredicate` value must always be the last argument.
+ The provided PTX code by `getPtx` should not include the predicate usage.
+ The interface automatically handles predicate usage in the generated
+ PTX code when necessary.
+
Example:
If we have following Op definition that returns PTX code by `getPtx`.
```tablegen
- def NVVM_MyOp : NVVM_Op<"myop",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
- Results<(outs LLVM_Type:$res)>,
- Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
- ...
+ def NVVM_OpCode : NVVM_PTXBuilder_Op<"opcode">,
+ Arguments<(ins I32:$op1, I32:$op2, PtxPredicate:$predicate)> {
+ let assemblyFormat = "$op1 `,` $op2 (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- return std::string("my.ptx.code %0, %1, %2;");
- }
- } ];
+ std::string $cppClass::getPtx() { return std::string("opcode [%0], %1"); }
+ } ];
+ }
```
- The NVVM Op will look like below:
+ The NVVM Op can be look like one of these:
```mlir
- %0 = my.ptx.code %1, %2 : !llvm.ptr, i32 -> i32
+ %s1 = nvvm.opcode %1, %2 : i32, i32
+ %s2 = nvvm.opcode %1, %2, predicate = %p : i32, i32, i1
```
- The `convert-nvvm-to-llvm` Pass generates the PTX code below. The order of
- arguments are kept the same. The read and write modifiers are set based on
- the input and result types.
+ The `convert-nvvm-to-llvm` Pass generates PTX code with preserved argument
+ order and sets read and write modifiers based on input and result types.
```mlir
- %0 = llvm.inline_asm has_side_effects asm_dialect = att "my.ptx.code %0, %1, %2;", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32
+ %0 = llvm.inline_asm has_side_effects asm_dialect = att "my.opcode %0, %1;", "r, r" %0, %1 : (i32, i32)
+ %0 = llvm.inline_asm has_side_effects asm_dialect = att "@%2 my.opcode %0, %1;", "r, r, b" %0, %1, %p : (i32, i32, i1)
```
-
}];
let methods = [
+ InterfaceMethod<
+ /*desc=*/[{Returns an optional predicate value.}],
+ /*retType=*/"std::optional<::mlir::Value>",
+ /*methodName=*/"getPredicate",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
InterfaceMethod<
/*desc=*/[{
Returns whether the operation has intrinsic support in LLVM.
@@ -211,6 +226,12 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
];
}
+/// Base class that defines BasicPtxBuilderOpInterface.
+class NVVM_PTXBuilder_Op<string mnemonic,
+ list<Trait> traits = [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> :
+ LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
+}
+
//===----------------------------------------------------------------------===//
// NVVM intrinsic operations
//===----------------------------------------------------------------------===//
@@ -334,21 +355,31 @@ def NVVM_ReduxOp :
//===----------------------------------------------------------------------===//
/// mbarrier.init instruction with generic pointer type
-def NVVM_MBarrierInitOp : NVVM_Op<"mbarrier.init">,
- Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count)> {
+def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
+ Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count, PtxPredicate:$predicate)> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count});
}];
- let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
+ let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
+ let extraClassDeclaration = [{
+ bool hasIntrinsic() { if(getPredicate()) return false; return true; }
+ }];
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() { return std::string("mbarrier.init.b64 [%0], %1;"); }
+ }];
}
/// mbarrier.init instruction with shared pointer type
-def NVVM_MBarrierInitSharedOp : NVVM_Op<"mbarrier.init.shared">,
- Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count)> {
+def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">,
+ Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count, PtxPredicate:$predicate)> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
}];
- let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
+ let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
+ let extraClassDeclaration = "bool hasIntrinsic() { return !getPredicate(); }";
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() { return std::string("mbarrier.init.shared.b64 [%0], %1;"); }
+ }];
}
def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">,
@@ -403,26 +434,23 @@ def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.
let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
}
-def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
- Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
- let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
+def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,
+ Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount, PtxPredicate:$predicate)> {
+ let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
}];
}
-def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
- Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {
- let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
+def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,
+ Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount, PtxPredicate:$predicate)> {
+ let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
}];
}
-def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
Arguments<(ins LLVM_i64ptr_any:$addr, I32:$phase, I32:$ticks)> {
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
let extraClassDefinition = [{
@@ -441,8 +469,7 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
}];
}
-def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$phase, I32:$ticks)> {
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
let extraClassDefinition = [{
@@ -596,7 +623,7 @@ def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;
-def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op<"cp.async.shared.global">,
Arguments<(ins LLVM_i8Ptr_shared:$dst,
LLVM_i8Ptr_global:$src,
I32Attr:$size,
@@ -1467,12 +1494,24 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
-def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
+ NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
+ AttrSizedOperandSegments]>,
Arguments<(ins LLVM_i64ptr_shared:$dstMem,
LLVM_i64ptr_any:$tmaDescriptor,
LLVM_i64ptr_shared:$mbar,
- Variadic<I32>:$coordinates)> {
- let assemblyFormat = "$dstMem `,` $tmaDescriptor `,` $mbar `,` `box` `[`$coordinates `]` attr-dict `:` type(operands)";
+ Variadic<I32>:$coordinates,
+ PtxPredicate:$predicate)> {
+ let assemblyFormat = [{
+ $dstMem `,`
+ $tmaDescriptor `,`
+ $mbar `,`
+ `box` `[`$coordinates `]`
+ (`,` `predicate` `=` $predicate^)?
+ attr-dict `:` type(operands)
+ }];
+
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
int dim = getCoordinates().size();
@@ -1490,11 +1529,21 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso
let hasVerifier = 1;
}
-def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
+ NVVM_Op<"cp.async.bulk.tensor.global.shared.cta",
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
+ AttrSizedOperandSegments]>,
Arguments<(ins LLVM_i64ptr_any:$tmaDescriptor,
LLVM_i64ptr_shared:$srcMem,
- Variadic<I32>:$coordinates)> {
- let assemblyFormat = "$tmaDescriptor `,` $srcMem `,` `box` `[`$coordinates `]` attr-dict `:` type(operands)";
+ Variadic<I32>:$coordinates,
+ PtxPredicate:$predicate)> {
+ let assemblyFormat = [{
+ $tmaDescriptor `,`
+ $srcMem `,`
+ `box` `[`$coordinates `]`
+ (`,` `predicate` `=` $predicate^)?
+ attr-dict `:` type(operands)
+ }];
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
int dim = getCoordinates().size();
@@ -1516,8 +1565,7 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.gl
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
-def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
let arguments = (ins);
let description = [{
Enforce an ordering of register accesses between warpgroup level matrix
@@ -1531,8 +1579,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
}];
}
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
let description = [{
@@ -1545,8 +1592,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
}];
}
-def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>{
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
let arguments = (ins I32Attr:$group);
let assemblyFormat = "attr-dict $group";
let description = [{
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index b045089244ff1a7..1f866039e20e564 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -22,6 +22,7 @@
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
+#include <optional>
#define DEBUG_TYPE "nvgpu-to-nvvm"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
@@ -793,9 +794,10 @@ struct NVGPUMBarrierInitLowering
if (isMbarrierShared(op.getBarrier().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
- count);
+ count, Value());
} else {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count);
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
+ Value());
}
return success();
}
@@ -886,12 +888,12 @@ struct NVGPUMBarrierArriveExpectTxLowering
if (isMbarrierShared(op.getBarrier().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
- op, barrier, txcount);
+ op, barrier, txcount, Value());
return success();
}
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(op, barrier,
- txcount);
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
+ op, barrier, txcount, Value());
return success();
}
};
@@ -939,7 +941,7 @@ struct NVGPUTmaAsyncLoadOpLowering
}
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
- op, dest, adaptor.getTensorMapDescriptor(), barrier, coords);
+ op, dest, adaptor.getTensorMapDescriptor(), barrier, coords, Value());
return success();
}
};
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 2d7a441e950045c..df3b1850e8d343f 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -63,6 +63,8 @@ class PtxBuilder {
// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints
char getRegisterType(Type type) {
+ if (type.isInteger(1))
+ return 'b';
if (type.isInteger(16))
return 'h';
if (type.isInteger(32))
@@ -158,6 +160,13 @@ class PtxBuilder {
asmConstraints[asmConstraints.size() - 1] == ',')
asmConstraints.pop_back();
+ // Add the predicate to the asm string.
+ if (op.getPredicate().has_value() && op.getPredicate().value()) {
+ std::string predicateStr = "@%";
+ predicateStr += std::to_string((asmVals.size() - 1));
+ asmStr = predicateStr + " " + asmStr;
+ }
+
// asm keywords expects %, but inline assembly uses $. Replace all % with $
std::replace(asmStr.begin(), asmStr.end(), '%', '$');
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 7ffe1ad2bb2b111..228f249db0a0700 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -4,17 +4,30 @@
// and the generic `convert-to-llvm` pass.
// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
+// CHECK-LABEL: @init_mbarrier
+llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %count : i32, %pred : i1) {
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.shared.b64 [$0], $1;", "r,r,b"
+ nvvm.mbarrier.init.shared %barrier, %count, predicate = %pred : !llvm.ptr<3>, i32, i1
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b"
+ nvvm.mbarrier.init %barrier_gen, %count, predicate = %pred : !llvm.ptr, i32, i1
+ llvm.return
+}
+
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
-llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) {
+llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
+ //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b "
+ nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
llvm.return
}
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic
-llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32) {
+llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32, %pred : i1) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r"
nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r,b"
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr, i32, i1
llvm.return
}
@@ -51,72 +64,93 @@ func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32)
}
// CHECK-LABEL: @tma_load_1d
-func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32) {
+func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3} ], [$2];", "r,l,r,r"
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32
+ // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3}], [$2];", "l,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32,i1
return
}
// CHECK-LABEL: @tma_load_2d
-func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32) {
+func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4} ], [$2];", "r,l,r,r,r"
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32
+ // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4}],...
[truncated]
|
This PR enhances `BasicPtxBuilder` to support predicates in PTX code generation. The `BasicPtxBuilder` interface was initially introduced for generating PTX code automatically for Ops that aren't supported by LLVM core. Predicates, which are typically not supported in LLVM core, are now supported using the same mechanism. In PTX programming, instructions can be guarded by predicates as shown below:. Here `@p` is a predicate register and guard the execution of the instruction. ``` @p ptx.code op1, op2, op3 ``` This PR introduces the `getPredicate` function in the `BasicPtxBuilder` interface to set an optional predicate. When a predicate is provided, the instruction is generated with predicate and guarded, otherwise, predicate is not genearted. Note that the predicate value must always appear as the last argument on the Op definition. Additionally, this PR implements predicate usage for the following ops: - mbarrier.init - mbarrier.init.shared - mbarrier.arrive.expect_tx - mbarrier.arrive.expect_tx.shared - cp.async.bulk.tensor.shared.cluster.global - cp.async.bulk.tensor.global.shared.cta See for more detail in PTX programing model https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-instructions
llvm#67102 introduced predication support in BasicPtxBuilderInterface. The predication is available for any NVVM ops just like PTX. This PR introduces predicate arguments to the following NVGPU Ops. We pass this argument to the BasicPtxBuilderInterface. - mbarrier.init - mbarrier.arrive.expect_tx - tma.async.load
Nobody reviewing these PRs before merging? |
I actually presented this work multiple times internally. The logic is very mechanical, it does what PTX instruction expects, so not complex. After the PR sitting there for nearly a month, I finally went ahead and merged it. I should've reached out to someone for a review or even just a quick glance. My bad |
It would be great to do this externally instead if possible? |
Agreed. Speaking up external presentation, perhaps I can present of the advancements in the NVGPU and NVVM dialects for Hopper at the MLIR Open Design Meeting. |
That looks like a great idea! There is already something planned next week, but can you add yourself here for 11/2: https://docs.google.com/document/d/1y2YlcOVMPocQjSFi3X6gYGRjA0onyqr41ilXji10phw/edit#heading=h.cite1kolful9 ? |
This PR enhances
BasicPtxBuilder
to support predicates in PTX code generation. TheBasicPtxBuilder
interface was initially introduced for generating PTX code automatically for Ops that aren't supported by LLVM core. Predicates, which are typically not supported in LLVM core, are now supported using the same mechanism.In PTX programming, instructions can be guarded by predicates as shown below:. Here
@p
is a predicate register and guard the execution of the instruction.This PR introduces the
getPredicate
function in theBasicPtxBuilder
interface to set an optional predicate. When a predicate is provided, the instruction is generated with predicate and guarded, otherwise, predicate is not genearted. Note that the predicate value must always appear as the last argument on the Op definition.Additionally, this PR implements predicate usage for the following ops:
See for more detail in PTX programing model
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-instructions