Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][nvvm] Support predicates in BasicPtxBuilder #67102

Merged
merged 1 commit into from
Oct 17, 2023
Merged

Conversation

grypp
Copy link
Member

@grypp grypp commented Sep 22, 2023

This PR enhances BasicPtxBuilder to support predicates in PTX code generation. The BasicPtxBuilder interface was initially introduced for generating PTX code automatically for Ops that aren't supported by LLVM core. Predicates, which are typically not supported in LLVM core, are now supported using the same mechanism.

In PTX programming, instructions can be guarded by predicates as shown below:. Here @p is a predicate register and guard the execution of the instruction.

@p ptx.code op1, op2, op3

This PR introduces the getPredicate function in the BasicPtxBuilder interface to set an optional predicate. When a predicate is provided, the instruction is generated with predicate and guarded, otherwise, predicate is not genearted. Note that the predicate value must always appear as the last argument on the Op definition.

Additionally, this PR implements predicate usage for the following ops:

  • mbarrier.init
  • mbarrier.init.shared
  • mbarrier.arrive.expect_tx
  • mbarrier.arrive.expect_tx.shared
  • cp.async.bulk.tensor.shared.cluster.global
  • cp.async.bulk.tensor.global.shared.cta

See for more detail in PTX programing model
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-instructions

@llvmbot
Copy link
Member

llvmbot commented Sep 22, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-llvm

Changes

This PR enhances BasicPtxBuilder to support predicates in PTX code generation. The BasicPtxBuilder interface was initially introduced for generating PTX code automatically for Ops that aren't supported by LLVM core. Predicates, which are typically not supported in LLVM core, are now supported using the same mechanism.

In PTX programming, instructions can be guarded by predicates as shown below:. Here @<!-- -->p is a predicate register and guard the execution of the instruction.

@<!-- -->p ptx.code op1, op2, op3

This PR introduces the getPredicate function in the BasicPtxBuilder interface to set an optional predicate. When a predicate is provided, the instruction is generated with predicate and guarded, otherwise, predicate is not genearted. Note that the predicate value must always appear as the last argument on the Op definition.

Additionally, this PR implements predicate usage for the following ops:

  • mbarrier.init
  • mbarrier.init.shared
  • mbarrier.arrive.expect_tx
  • mbarrier.arrive.expect_tx.shared
  • cp.async.bulk.tensor.shared.cluster.global
  • cp.async.bulk.tensor.global.shared.cta

See for more detail in PTX programing model
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-instructions


Patch is 27.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67102.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+93-47)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+8-6)
  • (modified) mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp (+9)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+46-12)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a528e015523e174..a1cfb305d8d5e50 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -86,6 +86,8 @@ class NVVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
 // Basic PTX Builder Interface
 //===----------------------------------------------------------------------===//
 
+def PtxPredicate : Optional<I1>;
+
 // https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#parameters
 def Read : I32EnumAttrCase<"Read", 0, "read">;
 def Write : I32EnumAttrCase<"Write", 2, "write">;
@@ -118,36 +120,49 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
     is started from the results and they are used as write, followed by the 
     operands and attributes.
 
+    `getPredicate` is an optional function for setting a predicate, which 
+    always returns a `PtxPredicate` value of type i1. If no predicate is 
+    provided, the instruction is unguarded; otherwise, it's guarded by the 
+    predicate value. The `PtxPredicate` value must always be the last argument. 
+    The provided PTX code by `getPtx` should not include the predicate usage.
+    The interface automatically handles predicate usage in the generated
+    PTX code when necessary.
+
     Example:
     If we have following Op definition that returns PTX code by `getPtx`. 
     
     ```tablegen
-      def NVVM_MyOp : NVVM_Op<"myop",
-          [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
-        Results<(outs LLVM_Type:$res)>,
-        Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
-        ...
+      def NVVM_OpCode : NVVM_PTXBuilder_Op<"opcode">,  
+        Arguments<(ins I32:$op1, I32:$op2, PtxPredicate:$predicate)> {
+        let assemblyFormat = "$op1 `,` $op2 (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
         let extraClassDefinition = [{
-          std::string $cppClass::getPtx() { 
-            return std::string("my.ptx.code %0, %1, %2;"); 
-          }
-      } ];
+          std::string $cppClass::getPtx() { return std::string("opcode [%0], %1"); }
+        } ];
+      }
     ```
 
-    The NVVM Op will look like below:
+    The NVVM Op can be look like one of these:
     ```mlir
-      %0 = my.ptx.code %1, %2 : !llvm.ptr, i32 -> i32
+      %s1 = nvvm.opcode %1, %2 : i32, i32
+      %s2 = nvvm.opcode %1, %2, predicate = %p : i32, i32, i1
     ```
 
-    The `convert-nvvm-to-llvm` Pass generates the PTX code below. The order of 
-    arguments are kept the same. The read and write modifiers are set based on
-    the input and result types.
+    The `convert-nvvm-to-llvm` Pass generates PTX code with preserved argument 
+    order and sets read and write modifiers based on input and result types.
     ```mlir
-      %0 = llvm.inline_asm has_side_effects asm_dialect = att "my.ptx.code %0, %1, %2;", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32
+      %0 = llvm.inline_asm has_side_effects asm_dialect = att "my.opcode %0, %1;", "r, r" %0, %1 : (i32, i32)
+      %0 = llvm.inline_asm has_side_effects asm_dialect = att "@%2 my.opcode %0, %1;", "r, r, b" %0, %1, %p : (i32, i32, i1)
     ```
-
   }];
   let methods = [
+    InterfaceMethod<
+        /*desc=*/[{Returns an optional predicate value.}],
+        /*retType=*/"std::optional<::mlir::Value>",
+        /*methodName=*/"getPredicate",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
     InterfaceMethod<
         /*desc=*/[{
           Returns whether the operation has intrinsic support in LLVM.
@@ -211,6 +226,12 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
   ];
 }
 
+/// Base class that defines BasicPtxBuilderOpInterface. 
+class NVVM_PTXBuilder_Op<string mnemonic, 
+  list<Trait> traits = [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> :
+  LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM intrinsic operations
 //===----------------------------------------------------------------------===//
@@ -334,21 +355,31 @@ def NVVM_ReduxOp :
 //===----------------------------------------------------------------------===//
 
 /// mbarrier.init instruction with generic pointer type
-def NVVM_MBarrierInitOp : NVVM_Op<"mbarrier.init">,
-  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count)> {
+def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
+  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$count, PtxPredicate:$predicate)> {
   string llvmBuilder = [{
       createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init, {$addr, $count});
   }];
-  let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
+  let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
+  let extraClassDeclaration = [{
+    bool hasIntrinsic() { if(getPredicate()) return false; return true; }
+  }];
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() { return std::string("mbarrier.init.b64 [%0], %1;"); }
+  }];
 }
 
 /// mbarrier.init instruction with shared pointer type
-def NVVM_MBarrierInitSharedOp : NVVM_Op<"mbarrier.init.shared">,
-  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count)> {
+def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">,
+  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$count, PtxPredicate:$predicate)> {
   string llvmBuilder = [{
       createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
   }];
-  let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
+  let assemblyFormat = "$addr `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
+  let extraClassDeclaration = "bool hasIntrinsic() { return !getPredicate(); }";
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() { return std::string("mbarrier.init.shared.b64 [%0], %1;"); }
+  }];
 }
 
 def NVVM_MBarrierInvalOp : NVVM_Op<"mbarrier.inval">,
@@ -403,26 +434,23 @@ def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.
   let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
 }
 
-def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
-  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
-  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
+def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,  
+  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount, PtxPredicate:$predicate)> {
+  let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
   }];
 }
 
-def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
-  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {    
-  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)";
+def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,  
+  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount, PtxPredicate:$predicate)> {    
+  let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
   }];
 }
 
-def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
+def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,  
   Arguments<(ins LLVM_i64ptr_any:$addr, I32:$phase, I32:$ticks)> {  
   let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
   let extraClassDefinition = [{
@@ -441,8 +469,7 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
   }];
 }
 
-def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,  
+def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,  
   Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$phase, I32:$ticks)> {  
   let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
   let extraClassDefinition = [{
@@ -596,7 +623,7 @@ def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
 
 def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;
 
-def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op<"cp.async.shared.global">,
   Arguments<(ins LLVM_i8Ptr_shared:$dst,
                  LLVM_i8Ptr_global:$src,
                  I32Attr:$size,
@@ -1467,12 +1494,24 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : 
+  NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
+  AttrSizedOperandSegments]>,
   Arguments<(ins  LLVM_i64ptr_shared:$dstMem,
                   LLVM_i64ptr_any:$tmaDescriptor,
                   LLVM_i64ptr_shared:$mbar,
-                  Variadic<I32>:$coordinates)> {
-  let assemblyFormat = "$dstMem `,` $tmaDescriptor `,` $mbar `,` `box` `[`$coordinates `]` attr-dict  `:` type(operands)";
+                  Variadic<I32>:$coordinates,
+                  PtxPredicate:$predicate)> {
+  let assemblyFormat = [{ 
+    $dstMem `,` 
+    $tmaDescriptor `,` 
+    $mbar `,` 
+    `box` `[`$coordinates `]` 
+    (`,` `predicate` `=` $predicate^)? 
+    attr-dict  `:` type(operands)
+  }];
+
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
       int dim = getCoordinates().size();
@@ -1490,11 +1529,21 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso
   let hasVerifier = 1;
 }
 
-def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : 
+  NVVM_Op<"cp.async.bulk.tensor.global.shared.cta", 
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
+  AttrSizedOperandSegments]>,
   Arguments<(ins  LLVM_i64ptr_any:$tmaDescriptor,
                   LLVM_i64ptr_shared:$srcMem,
-                  Variadic<I32>:$coordinates)> {
-  let assemblyFormat = "$tmaDescriptor `,` $srcMem `,` `box` `[`$coordinates `]` attr-dict  `:` type(operands)";
+                  Variadic<I32>:$coordinates,
+                  PtxPredicate:$predicate)> {
+  let assemblyFormat = [{ 
+    $tmaDescriptor `,` 
+    $srcMem `,` 
+    `box` `[`$coordinates `]` 
+    (`,` `predicate` `=` $predicate^)?  
+    attr-dict  `:` type(operands)
+  }];
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
       int dim = getCoordinates().size();
@@ -1516,8 +1565,7 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : NVVM_Op<"cp.async.bulk.tensor.gl
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
   let arguments = (ins);
   let description = [{
     Enforce an ordering of register accesses between warpgroup level matrix 
@@ -1531,8 +1579,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
   }];
 }
 
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
   Arguments<(ins )> {
   let assemblyFormat = "attr-dict";
   let description = [{
@@ -1545,8 +1592,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
   }];
 }
 
-def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", 
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>{
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
   let arguments = (ins I32Attr:$group);
   let assemblyFormat = "attr-dict $group";
   let description = [{
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index b045089244ff1a7..1f866039e20e564 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Pass/Pass.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include <optional>
 
 #define DEBUG_TYPE "nvgpu-to-nvvm"
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
@@ -793,9 +794,10 @@ struct NVGPUMBarrierInitLowering
 
     if (isMbarrierShared(op.getBarrier().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
-                                                              count);
+                                                              count, Value());
     } else {
-      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count);
+      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
+                                                        Value());
     }
     return success();
   }
@@ -886,12 +888,12 @@ struct NVGPUMBarrierArriveExpectTxLowering
 
     if (isMbarrierShared(op.getBarrier().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
-          op, barrier, txcount);
+          op, barrier, txcount, Value());
       return success();
     }
 
-    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(op, barrier,
-                                                                txcount);
+    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
+        op, barrier, txcount, Value());
     return success();
   }
 };
@@ -939,7 +941,7 @@ struct NVGPUTmaAsyncLoadOpLowering
     }
 
     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
-        op, dest, adaptor.getTensorMapDescriptor(), barrier, coords);
+        op, dest, adaptor.getTensorMapDescriptor(), barrier, coords, Value());
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 2d7a441e950045c..df3b1850e8d343f 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -63,6 +63,8 @@ class PtxBuilder {
 
   // https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints
   char getRegisterType(Type type) {
+    if (type.isInteger(1))
+      return 'b';
     if (type.isInteger(16))
       return 'h';
     if (type.isInteger(32))
@@ -158,6 +160,13 @@ class PtxBuilder {
         asmConstraints[asmConstraints.size() - 1] == ',')
       asmConstraints.pop_back();
 
+    // Add the predicate to the asm string.
+    if (op.getPredicate().has_value() && op.getPredicate().value()) {
+      std::string predicateStr = "@%";
+      predicateStr += std::to_string((asmVals.size() - 1));
+      asmStr = predicateStr + " " + asmStr;
+    }
+
     // asm keywords expects %, but inline assembly uses $. Replace all % with $
     std::replace(asmStr.begin(), asmStr.end(), '%', '$');
 
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 7ffe1ad2bb2b111..228f249db0a0700 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -4,17 +4,30 @@
 // and the generic `convert-to-llvm` pass.
 // RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
 
+// CHECK-LABEL: @init_mbarrier
+llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %count : i32, %pred : i1) {
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.shared.b64 [$0], $1;", "r,r,b" 
+  nvvm.mbarrier.init.shared %barrier, %count, predicate = %pred : !llvm.ptr<3>, i32, i1 
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" 
+  nvvm.mbarrier.init %barrier_gen, %count, predicate = %pred : !llvm.ptr, i32, i1
+  llvm.return
+}
+
 // CHECK-LABEL: @init_mbarrier_arrive_expect_tx
-llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) {
+llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
   //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
   nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
+  //CHECK :  llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b "
+  nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 
   llvm.return
 }
 
 // CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic
-llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32) {
+llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32, %pred : i1) {
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r" 
   nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r,b"
+  nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr, i32, i1 
   llvm.return
 }
 
@@ -51,72 +64,93 @@ func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32)
 }
 
 // CHECK-LABEL: @tma_load_1d
-func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32) {
+func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3} ], [$2];", "r,l,r,r"
   nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3}], [$2];", "l,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0], predicate=%p : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32,i1
   return
 }
 
 // CHECK-LABEL: @tma_load_2d
-func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32) {
+func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4} ], [$2];", "r,l,r,r,r"
   nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4}],...
[truncated]

This PR enhances `BasicPtxBuilder` to support predicates in PTX code generation. The `BasicPtxBuilder` interface was initially introduced for generating PTX code automatically for Ops that aren't supported by LLVM core. Predicates, which are typically not supported in LLVM core, are now supported using the same mechanism.

In PTX programming, instructions can be guarded by predicates as shown below:. Here `@p` is a predicate register and guard the execution of the instruction.

```
@p ptx.code op1, op2, op3
```

This PR introduces the `getPredicate` function in the `BasicPtxBuilder` interface to set an optional predicate. When a predicate is provided, the instruction is generated with predicate and guarded, otherwise, predicate is not genearted. Note that the predicate value must always appear as the last argument on the Op definition.

Additionally, this PR implements predicate usage for the following ops:

- mbarrier.init
- mbarrier.init.shared
- mbarrier.arrive.expect_tx
- mbarrier.arrive.expect_tx.shared
- cp.async.bulk.tensor.shared.cluster.global
- cp.async.bulk.tensor.global.shared.cta

See for more detail in PTX programing model
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-instructions
@grypp grypp merged commit 6338932 into llvm:main Oct 17, 2023
@grypp grypp deleted the add_pred branch October 17, 2023 10:42
grypp added a commit to grypp/llvm-project that referenced this pull request Oct 17, 2023
llvm#67102 introduced predication support in BasicPtxBuilderInterface. The predication is available for any NVVM ops just like PTX.

This PR introduces predicate arguments to the following NVGPU Ops. We pass this argument to the BasicPtxBuilderInterface.

- mbarrier.init
- mbarrier.arrive.expect_tx
- tma.async.load
@dcaballe
Copy link
Contributor

Nobody reviewing these PRs before merging?

@grypp
Copy link
Member Author

grypp commented Oct 18, 2023

I actually presented this work multiple times internally. The logic is very mechanical, it does what PTX instruction expects, so not complex. After the PR sitting there for nearly a month, I finally went ahead and merged it.

I should've reached out to someone for a review or even just a quick glance. My bad

@joker-eph
Copy link
Collaborator

I actually presented this work multiple times internally.

It would be great to do this externally instead if possible?

@grypp
Copy link
Member Author

grypp commented Oct 19, 2023

I actually presented this work multiple times internally.

It would be great to do this externally instead if possible?

Agreed.

Speaking up external presentation, perhaps I can present of the advancements in the NVGPU and NVVM dialects for Hopper at the MLIR Open Design Meeting.

@joker-eph
Copy link
Collaborator

That looks like a great idea! There is already something planned next week, but can you add yourself here for 11/2: https://docs.google.com/document/d/1y2YlcOVMPocQjSFi3X6gYGRjA0onyqr41ilXji10phw/edit#heading=h.cite1kolful9 ?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants