Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV][GISel] Add ISel supports for SHXADD from Zba extension #67863

Merged
merged 9 commits into from
Oct 18, 2023

Conversation

mshockwave
Copy link
Member

@mshockwave mshockwave commented Sep 29, 2023

This patch consists of porting SDISel patterns of SHXADD instructions to GISel.
Note that non_imm12, a predicate that was implemented with PatLeaf, is now turned into a PatFrag of <op>_with_non_imm12 where op is the operator that uses the non_imm12 operand, as GISel doesn't have equivalence of PatLeaf at this moment.

I'll put ISel supports for SHXADD_UW in a separate patch since the current one is getting a little big.

This patch constitue of porting (SDISel) patterns of SHXADD
instructions.
Note that `non_imm12`, a predicate that was implemented with `PatLeaf`,
is now turned into a ComplexPattern to facilitate code reusing on
patterns that use it between SDISel and GISel.
@llvmbot
Copy link
Member

llvmbot commented Sep 29, 2023

@llvm/pr-subscribers-llvm-globalisel

@llvm/pr-subscribers-backend-risc-v

Changes

This patch constitue of porting (SDISel) patterns of SHXADD instructions.
Note that non_imm12, a predicate that was implemented with PatLeaf, is now turned into a ComplexPattern to facilitate code reusing on patterns that use it between SDISel and GISel, as GISel doesn't have equivalence of PatLeaf at this moment.

I'll put ISel supports for SHXADD_UW in a separate patch since the current one is getting a little big.


Patch is 23.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67863.diff

7 Files Affected:

  • (modified) llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp (+130)
  • (modified) llvm/lib/Target/RISCV/RISCVGISel.td (+10)
  • (modified) llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp (+9)
  • (modified) llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h (+2)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoZb.td (+24-27)
  • (added) llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv32.mir (+152)
  • (added) llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv64.mir (+152)
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
index 4f97a0d84f686f9..3a98e84546f376f 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
@@ -17,6 +17,7 @@
 #include "RISCVTargetMachine.h"
 #include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
 #include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
+#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 #include "llvm/IR/IntrinsicsRISCV.h"
 #include "llvm/Support/Debug.h"
@@ -55,6 +56,14 @@ class RISCVInstructionSelector : public InstructionSelector {
 
   ComplexRendererFns selectShiftMask(MachineOperand &Root) const;
 
+  ComplexRendererFns selectNonImm12(MachineOperand &Root) const;
+
+  ComplexRendererFns selectSHXADDOp(MachineOperand &Root, unsigned ShAmt) const;
+  template <unsigned ShAmt>
+  ComplexRendererFns selectSHXADDOp(MachineOperand &Root) const {
+    return selectSHXADDOp(Root, ShAmt);
+  }
+
   // Custom renderers for tablegen
   void renderNegImm(MachineInstrBuilder &MIB, const MachineInstr &MI,
                     int OpIdx) const;
@@ -105,6 +114,127 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
   return {{[=](MachineInstrBuilder &MIB) { MIB.add(Root); }}};
 }
 
+// This complex pattern actually serves as a perdicate that is effectively
+// `!isInt<12>(Imm)`.
+InstructionSelector::ComplexRendererFns
+RISCVInstructionSelector::selectNonImm12(MachineOperand &Root) const {
+  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  if (Root.isReg() && Root.getReg())
+    if (auto Val = getIConstantVRegValWithLookThrough(Root.getReg(), MRI)) {
+      // We do NOT want immediates that fit in 12 bits.
+      if (isInt<12>(Val->Value.getSExtValue()))
+        return std::nullopt;
+    }
+
+  return {{[=](MachineInstrBuilder &MIB) { MIB.add(Root); }}};
+}
+
+InstructionSelector::ComplexRendererFns
+RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
+                                         unsigned ShAmt) const {
+  using namespace llvm::MIPatternMatch;
+  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  if (!Root.isReg())
+    return std::nullopt;
+  Register RootReg = Root.getReg();
+
+  const unsigned XLen = STI.getXLen();
+  APInt Mask, C2;
+  Register RegY;
+  std::optional<bool> LeftShift;
+  // (and (shl y, c2), mask)
+  if (mi_match(RootReg, MRI,
+               m_GAnd(m_GShl(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
+    LeftShift = true;
+  // (and (lshr y, c2), mask)
+  else if (mi_match(RootReg, MRI,
+                    m_GAnd(m_GLShr(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
+    LeftShift = false;
+
+  if (LeftShift.has_value()) {
+    if (*LeftShift)
+      Mask &= maskTrailingZeros<uint64_t>(C2.getLimitedValue());
+    else
+      Mask &= maskTrailingOnes<uint64_t>(XLen - C2.getLimitedValue());
+
+    if (Mask.isShiftedMask()) {
+      unsigned Leading = XLen - Mask.getActiveBits();
+      unsigned Trailing = Mask.countr_zero();
+      // Given (and (shl y, c2), mask) in which mask has no leading zeros and c3
+      // trailing zeros. We can use an SRLI by c3 - c2 followed by a SHXADD.
+      if (*LeftShift && Leading == 0 && C2.ult(Trailing) && Trailing == ShAmt) {
+        Register DstReg =
+            MRI.createGenericVirtualRegister(MRI.getType(RootReg));
+        return {{[=](MachineInstrBuilder &MIB) {
+          MachineIRBuilder(*MIB.getInstr())
+              .buildInstr(RISCV::SRLI, {DstReg}, {RegY})
+              .addImm(Trailing - C2.getLimitedValue());
+          MIB.addReg(DstReg);
+        }}};
+      }
+
+      // Given (and (lshr y, c2), mask) in which mask has c2 leading zeros and c3
+      // trailing zeros. We can use an SRLI by c2 + c3 followed by a SHXADD.
+      if (!*LeftShift && Leading == C2 && Trailing == ShAmt) {
+        Register DstReg =
+            MRI.createGenericVirtualRegister(MRI.getType(RootReg));
+        return {{[=](MachineInstrBuilder &MIB) {
+          MachineIRBuilder(*MIB.getInstr())
+              .buildInstr(RISCV::SRLI, {DstReg}, {RegY})
+              .addImm(Leading + Trailing);
+          MIB.addReg(DstReg);
+        }}};
+      }
+    }
+  }
+
+  LeftShift.reset();
+
+  // (shl (and y, mask), c2)
+  if (mi_match(RootReg, MRI,
+               m_GShl(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
+                      m_ICst(C2))))
+    LeftShift = true;
+  // (lshr (and y, mask), c2)
+  else if (mi_match(RootReg, MRI,
+                    m_GLShr(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
+                            m_ICst(C2))))
+    LeftShift = false;
+
+  if (LeftShift.has_value())
+    if (Mask.isShiftedMask()) {
+      unsigned Leading = XLen - Mask.getActiveBits();
+      unsigned Trailing = Mask.countr_zero();
+
+      // Given (shl (and y, mask), c2) in which mask has 32 leading zeros and
+      // c3 trailing zeros. If c1 + c3 == ShAmt, we can emit SRLIW + SHXADD.
+      bool Cond = *LeftShift && Leading == 32 && Trailing > 0 &&
+                  (Trailing + C2.getLimitedValue()) == ShAmt;
+      if (!Cond)
+        // Given (lshr (and y, mask), c2) in which mask has 32 leading zeros and
+        // c3 trailing zeros. If c3 - c1 == ShAmt, we can emit SRLIW + SHXADD.
+        Cond = !*LeftShift && Leading == 32 && C2.ult(Trailing) &&
+               (Trailing - C2.getLimitedValue()) == ShAmt;
+
+      if (Cond) {
+        Register DstReg =
+            MRI.createGenericVirtualRegister(MRI.getType(RootReg));
+        return {{[=](MachineInstrBuilder &MIB) {
+          MachineIRBuilder(*MIB.getInstr())
+              .buildInstr(RISCV::SRLIW, {DstReg}, {RegY})
+              .addImm(Trailing);
+          MIB.addReg(DstReg);
+        }}};
+      }
+    }
+
+  return std::nullopt;
+}
+
 // Tablegen doesn't allow us to write SRLIW/SRAIW/SLLIW patterns because the
 // immediate Operand has type XLenVT. GlobalISel wants it to be i32.
 bool RISCVInstructionSelector::earlySelectShift(
diff --git a/llvm/lib/Target/RISCV/RISCVGISel.td b/llvm/lib/Target/RISCV/RISCVGISel.td
index 8059b517f26ba3c..2d6a293c2cca148 100644
--- a/llvm/lib/Target/RISCV/RISCVGISel.td
+++ b/llvm/lib/Target/RISCV/RISCVGISel.td
@@ -31,6 +31,16 @@ def ShiftMaskGI :
     GIComplexOperandMatcher<s32, "selectShiftMask">,
     GIComplexPatternEquiv<shiftMaskXLen>;
 
+def gi_non_imm12 : GIComplexOperandMatcher<s32, "selectNonImm12">,
+                   GIComplexPatternEquiv<non_imm12>;
+
+def gi_sh1add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<1>">,
+                   GIComplexPatternEquiv<sh1add_op>;
+def gi_sh2add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<2>">,
+                   GIComplexPatternEquiv<sh2add_op>;
+def gi_sh3add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<3>">,
+                   GIComplexPatternEquiv<sh3add_op>;
+
 // FIXME: Canonicalize (sub X, C) -> (add X, -C) earlier.
 def : Pat<(XLenVT (sub GPR:$rs1, simm12Plus1:$imm)),
           (ADDI GPR:$rs1, (NegImm simm12Plus1:$imm))>;
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 70b9041852f91f8..de04f4c12e5e8e2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -2443,6 +2443,15 @@ bool RISCVDAGToDAGISel::SelectAddrRegImm(SDValue Addr, SDValue &Base,
   return true;
 }
 
+bool RISCVDAGToDAGISel::selectNonImm12(SDValue N, SDValue &Opnd) {
+  auto *C = dyn_cast<ConstantSDNode>(N);
+  if (!C || !isInt<12>(C->getSExtValue())) {
+    Opnd = N;
+    return true;
+  }
+  return false;
+}
+
 bool RISCVDAGToDAGISel::selectShiftMask(SDValue N, unsigned ShiftWidth,
                                         SDValue &ShAmt) {
   ShAmt = N;
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
index c220b2d57c2e50f..d3d095a370683df 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -83,6 +83,8 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
   bool trySignedBitfieldExtract(SDNode *Node);
   bool tryIndexedLoad(SDNode *Node);
 
+  bool selectNonImm12(SDValue N, SDValue &Opnd);
+
   bool selectShiftMask(SDValue N, unsigned ShiftWidth, SDValue &ShAmt);
   bool selectShiftMaskXLen(SDValue N, SDValue &ShAmt) {
     return selectShiftMask(N, Subtarget->getXLen(), ShAmt);
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
index a21c3d132636bea..c20c3176bb27dbc 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -235,10 +235,7 @@ def SimmShiftRightBy3XForm : SDNodeXForm<imm, [{
 }]>;
 
 // Pattern to exclude simm12 immediates from matching.
-def non_imm12 : PatLeaf<(XLenVT GPR:$a), [{
-  auto *C = dyn_cast<ConstantSDNode>(N);
-  return !C || !isInt<12>(C->getSExtValue());
-}]>;
+def non_imm12 : ComplexPattern<XLenVT, 1, "selectNonImm12", [], [], 0>;
 
 def Shifted32OnesMask : PatLeaf<(imm), [{
   uint64_t Imm = N->getZExtValue();
@@ -651,19 +648,19 @@ let Predicates = [HasStdExtZbb, IsRV64] in
 def : Pat<(i64 (and GPR:$rs, 0xFFFF)), (ZEXT_H_RV64 GPR:$rs)>;
 
 let Predicates = [HasStdExtZba] in {
-def : Pat<(add (shl GPR:$rs1, (XLenVT 1)), non_imm12:$rs2),
+def : Pat<(add (shl GPR:$rs1, (XLenVT 1)), (non_imm12 (XLenVT GPR:$rs2))),
           (SH1ADD GPR:$rs1, GPR:$rs2)>;
-def : Pat<(add (shl GPR:$rs1, (XLenVT 2)), non_imm12:$rs2),
+def : Pat<(add (shl GPR:$rs1, (XLenVT 2)), (non_imm12 (XLenVT GPR:$rs2))),
           (SH2ADD GPR:$rs1, GPR:$rs2)>;
-def : Pat<(add (shl GPR:$rs1, (XLenVT 3)), non_imm12:$rs2),
+def : Pat<(add (shl GPR:$rs1, (XLenVT 3)), (non_imm12 (XLenVT GPR:$rs2))),
           (SH3ADD GPR:$rs1, GPR:$rs2)>;
 
 // More complex cases use a ComplexPattern.
-def : Pat<(add sh1add_op:$rs1, non_imm12:$rs2),
+def : Pat<(add sh1add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))),
           (SH1ADD sh1add_op:$rs1, GPR:$rs2)>;
-def : Pat<(add sh2add_op:$rs1, non_imm12:$rs2),
+def : Pat<(add sh2add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))),
           (SH2ADD sh2add_op:$rs1, GPR:$rs2)>;
-def : Pat<(add sh3add_op:$rs1, non_imm12:$rs2),
+def : Pat<(add sh3add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))),
           (SH3ADD sh3add_op:$rs1, GPR:$rs2)>;
 
 def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
@@ -735,48 +732,48 @@ def : Pat<(i64 (and GPR:$rs1, Shifted32OnesMask:$mask)),
           (SLLI_UW (SRLI GPR:$rs1, Shifted32OnesMask:$mask),
                    Shifted32OnesMask:$mask)>;
 
-def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFF), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
           (ADD_UW GPR:$rs1, GPR:$rs2)>;
 def : Pat<(i64 (and GPR:$rs, 0xFFFFFFFF)), (ADD_UW GPR:$rs, (XLenVT X0))>;
 
-def : Pat<(i64 (or_is_add (and GPR:$rs1, 0xFFFFFFFF), non_imm12:$rs2)),
+def : Pat<(i64 (or_is_add (and GPR:$rs1, 0xFFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
           (ADD_UW GPR:$rs1, GPR:$rs2)>;
 
-def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 1)), non_imm12:$rs2)),
+def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 1)), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH1ADD_UW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 2)), non_imm12:$rs2)),
+def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 2)), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH2ADD_UW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 3)), non_imm12:$rs2)),
+def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 3)), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH3ADD_UW GPR:$rs1, GPR:$rs2)>;
 
-def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), non_imm12:$rs2)),
+def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH1ADD_UW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), non_imm12:$rs2)),
+def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH2ADD_UW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), non_imm12:$rs2)),
+def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH3ADD_UW GPR:$rs1, GPR:$rs2)>;
 
 // More complex cases use a ComplexPattern.
-def : Pat<(i64 (add sh1add_uw_op:$rs1, non_imm12:$rs2)),
+def : Pat<(i64 (add sh1add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))),
           (SH1ADD_UW sh1add_uw_op:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add sh2add_uw_op:$rs1, non_imm12:$rs2)),
+def : Pat<(i64 (add sh2add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))),
           (SH2ADD_UW sh2add_uw_op:$rs1, GPR:$rs2)>;
-def : Pat<(i64 (add sh3add_uw_op:$rs1, non_imm12:$rs2)),
+def : Pat<(i64 (add sh3add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))),
           (SH3ADD_UW sh3add_uw_op:$rs1, GPR:$rs2)>;
 
-def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFE), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFE), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH1ADD (SRLIW GPR:$rs1, 1), GPR:$rs2)>;
-def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFC), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFC), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH2ADD (SRLIW GPR:$rs1, 2), GPR:$rs2)>;
-def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFF8), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFF8), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH3ADD (SRLIW GPR:$rs1, 3), GPR:$rs2)>;
 
 // Use SRLI to clear the LSBs and SHXADD_UW to mask and shift.
-def : Pat<(i64 (add (and GPR:$rs1, 0x1FFFFFFFE), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0x1FFFFFFFE), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH1ADD_UW (SRLI GPR:$rs1, 1), GPR:$rs2)>;
-def : Pat<(i64 (add (and GPR:$rs1, 0x3FFFFFFFC), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0x3FFFFFFFC), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH2ADD_UW (SRLI GPR:$rs1, 2), GPR:$rs2)>;
-def : Pat<(i64 (add (and GPR:$rs1, 0x7FFFFFFF8), non_imm12:$rs2)),
+def : Pat<(i64 (add (and GPR:$rs1, 0x7FFFFFFF8), (non_imm12 (XLenVT GPR:$rs2)))),
           (SH3ADD_UW (SRLI GPR:$rs1, 3), GPR:$rs2)>;
 
 def : Pat<(i64 (mul (and_oneuse GPR:$r, 0xFFFFFFFF), C3LeftShiftUW:$i)),
diff --git a/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv32.mir b/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv32.mir
new file mode 100644
index 000000000000000..f90de3ea55a1bb7
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv32.mir
@@ -0,0 +1,152 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 3
+# RUN: llc -mtriple=riscv32 -mattr='+zba' -run-pass=instruction-select -simplify-mir -verify-machineinstrs %s -o - \
+# RUN: | FileCheck %s
+
+---
+name:            sh1add
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: sh1add
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SH1ADD:%[0-9]+]]:gpr = SH1ADD [[COPY]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH1ADD]]
+    %0:gprb(s32) = COPY $x10
+    %1:gprb(s32) = COPY $x11
+    %2:gprb(s32) = G_CONSTANT i32 1
+    %3:gprb(s32) = G_SHL %0, %2
+    %4:gprb(s32) = G_ADD %3, %1
+    $x10 = COPY %4(s32)
+...
+---
+name:            sh2add
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: sh2add
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SH2ADD:%[0-9]+]]:gpr = SH2ADD [[COPY]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH2ADD]]
+    %0:gprb(s32) = COPY $x10
+    %1:gprb(s32) = COPY $x11
+    %2:gprb(s32) = G_CONSTANT i32 2
+    %3:gprb(s32) = G_SHL %0, %2
+    %4:gprb(s32) = G_ADD %3, %1
+    $x10 = COPY %4(s32)
+...
+---
+name:            sh3add
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: sh3add
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SH3ADD:%[0-9]+]]:gpr = SH3ADD [[COPY]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH3ADD]]
+    %0:gprb(s32) = COPY $x10
+    %1:gprb(s32) = COPY $x11
+    %2:gprb(s32) = G_CONSTANT i32 3
+    %3:gprb(s32) = G_SHL %0, %2
+    %4:gprb(s32) = G_ADD %3, %1
+    $x10 = COPY %4(s32)
+...
+---
+name:            no_sh1add
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: no_sh1add
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[SLLI:%[0-9]+]]:gpr = SLLI [[COPY]], 1
+    ; CHECK-NEXT: [[ADDI:%[0-9]+]]:gpr = ADDI [[SLLI]], 37
+    ; CHECK-NEXT: $x10 = COPY [[ADDI]]
+    %0:gprb(s32) = COPY $x10
+    %1:gprb(s32) = G_CONSTANT i32 37
+    %2:gprb(s32) = G_CONSTANT i32 1
+    %3:gprb(s32) = G_SHL %0, %2
+    %4:gprb(s32) = G_ADD %3, %1
+    $x10 = COPY %4(s32)
+...
+---
+name:            shXadd_complex_shl_and
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: shXadd_complex_shl_and
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SRLI:%[0-9]+]]:gpr = SRLI [[COPY]], 1
+    ; CHECK-NEXT: [[SH2ADD:%[0-9]+]]:gpr = SH2ADD [[SRLI]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH2ADD]]
+    %0:gprb(s32) = COPY $x10
+    %1:gprb(s32) = COPY $x11
+
+    %2:gprb(s32) = G_CONSTANT i32 1
+    %3:gprb(s32) = G_SHL %0, %2
+    %4:gprb(s32) = G_CONSTANT i32 4294967292
+    %5:gprb(s32) = G_AND %3, %4
+
+    %6:gprb(s32) = G_ADD %5, %1
+    $x10 = COPY %6(s32)
+...
+---
+name:            shXadd_complex_lshr_and
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: shXadd_complex_lshr_and
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SRLI:%[0-9]+]]:gpr = SRLI [[COPY]], 29
+    ; CHECK-NEXT: [[SH2ADD:%[0-9]+]]:gpr = SH2ADD [[SRLI]], [[COPY1]]
+    ; CHECK-NEXT: $x10 = COPY [[SH2ADD]]
+    %0:gprb(s32) = COPY $x10
+    %1:gprb(s32) = COPY $x11
+
+    %2:gprb(s32) = G_CONSTANT i32 27
+    %3:gprb(s32) = G_LSHR %0, %2
+    %4:gprb(s32) = G_CONSTANT i32 60
+    %5:gprb(s32) = G_AND %3, %4
+
+    %6:gprb(s32) = G_ADD %5, %1
+    $x10 = COPY %6(s32)
+...
diff --git a/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv64.mir b/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv64.mir
new file mode 100644
index 000000000000000..092a3305b3453d2
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/GlobalISel/instruction-select/zba-rv64.mir
@@ -0,0 +1,152 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 3
+# RUN: llc -mtriple=riscv64 -mattr='+zba' -run-pass=instruction-select -simplify-mir -verify-machineinstrs %s -o - \
+# RUN: | FileCheck %s
+
+---
+name:            sh1add
+legalized:       true
+regBankSelected: true
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $x10, $x11
+
+    ; CHECK-LABEL: name: sh1add
+    ; CHECK: liveins: $x10, $x11
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:gpr = COPY $x10
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:gpr = COPY $x11
+    ; CHECK-NEXT: [[SH1ADD:%[0-9]+]]:gpr = SH1ADD [[COPY]], [[COPY1]]
+ ...
[truncated]

if (Mask.isShiftedMask()) {
unsigned Leading = XLen - Mask.getActiveBits();
unsigned Trailing = Mask.countr_zero();
// Given (and (shl y, c2), mask) in which mask has no leading zeros and c3
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm copying the comments from RISCVISelDAGToDAG.cpp to here for better readability.

@github-actions
Copy link

github-actions bot commented Sep 29, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

auto *C = dyn_cast<ConstantSDNode>(N);
return !C || !isInt<12>(C->getSExtValue());
}]>;
def non_imm12 : ComplexPattern<XLenVT, 1, "selectNonImm12", [], [], 0>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add let GISelPredicateCode here to do this without changing to a complex pattern?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After some digging, I think the answer is no and I'm sad about it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, using GISelPredicateCode was my first approach until I found that the GlobalISelEmitter TG backend doesn't pick that up for leaf nodes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would add GISelPredicateCode to something like this work topperc@01205c1

Copy link
Member Author

@mshockwave mshockwave Oct 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would add GISelPredicateCode to something like this work topperc@01205c1

This works, though I'm a little concerned that this might create too many boilerplate code in the future, since there needs to be a Predicate TG record for every opcode that goes with non_imm12 (even we abstract the real predicate logics into a function). What do you think?

Also, interestingly GISelPredicateCode doesn't dance well with PredicateCodeUsesOperands: it SEGFAULT llvm-tblgen in our case, despite the fact that there are tests for this exact combination (in test/TableGen/GlobalISelEmitterCustomPredicate.td). I can't find an obvious fix for llvm-tblgen but writing a non-PredicateCodeUsesOperands predicate code works so I'm not too bothered by this crash (for now).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used PredicateCodeUsesOperands so that I could know which operand wasn't the shl since add is commutable and tblgen will generate both patterns.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would add GISelPredicateCode to something like this work topperc@01205c1

Done: it's no longer using ComplexPattern of non_imm12 but PatFrag of <op>_with_non_imm12 instead

But since there is a bug in llvm-tblgen that crashes itself whenever a
ComplexPattern failed to be imported with `PredicateUsesOperands` +
`GISelPredicateCode`, we preserve the original `non_imm12` (PatLeaf) and
leave all `SHXADD_UW` patterns untouched.
@mshockwave
Copy link
Member Author

mshockwave commented Oct 3, 2023

Note that this patch depends on #68125 so the buildbot failures here are expected.

LeftShift = false;

if (LeftShift.has_value())
if (Mask.isShiftedMask()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we merge this condition with the previous if?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this comment addressed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is addressed now

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@mshockwave mshockwave merged commit 5f5faf4 into llvm:main Oct 18, 2023
2 checks passed
@mshockwave mshockwave deleted the gisel-riscv-zba-shift branch October 18, 2023 22:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants