Skip to content

Commit

Permalink
[RISCV][GISel] Add ISel supports for SHXADD from Zba extension (#67863)
Browse files Browse the repository at this point in the history
This patch consists of porting SDISel patterns of SHXADD instructions to
GISel.
Note that `non_imm12`, a predicate that was implemented with `PatLeaf`,
is now turned into a `PatFrag` of `<op>_with_non_imm12` where `op` is
the operator that uses the `non_imm12` operand, as GISel doesn't have
an equivalent of `PatLeaf` at this moment.
  • Loading branch information
mshockwave authored Oct 18, 2023
1 parent 85b8958 commit 5f5faf4
Show file tree
Hide file tree
Showing 5 changed files with 474 additions and 31 deletions.
109 changes: 109 additions & 0 deletions llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -68,6 +69,12 @@ class RISCVInstructionSelector : public InstructionSelector {
ComplexRendererFns selectShiftMask(MachineOperand &Root) const;
ComplexRendererFns selectAddrRegImm(MachineOperand &Root) const;

ComplexRendererFns selectSHXADDOp(MachineOperand &Root, unsigned ShAmt) const;
template <unsigned ShAmt>
ComplexRendererFns selectSHXADDOp(MachineOperand &Root) const {
return selectSHXADDOp(Root, ShAmt);
}

// Custom renderers for tablegen
void renderNegImm(MachineInstrBuilder &MIB, const MachineInstr &MI,
int OpIdx) const;
Expand Down Expand Up @@ -122,6 +129,108 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
return {{[=](MachineInstrBuilder &MIB) { MIB.add(Root); }}};
}

// Complex-pattern matcher for the first operand of SHXADD (shift amount
// ShAmt in {1,2,3}). Recognizes shift-and-mask forms of the root register
// that can be rewritten as a single SRLI/SRLIW feeding the SHXADD, and
// renders that shifted value as the operand. Returns std::nullopt when the
// root does not match any supported form, so the pattern falls through to
// other selection rules.
InstructionSelector::ComplexRendererFns
RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
                                         unsigned ShAmt) const {
  using namespace llvm::MIPatternMatch;
  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
  MachineRegisterInfo &MRI = MF.getRegInfo();

  if (!Root.isReg())
    return std::nullopt;
  Register RootReg = Root.getReg();

  const unsigned XLen = STI.getXLen();
  APInt Mask, C2;
  Register RegY;
  // LeftShift stays unset if neither shift-then-mask form matches.
  std::optional<bool> LeftShift;
  // (and (shl y, c2), mask)
  if (mi_match(RootReg, MRI,
               m_GAnd(m_GShl(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
    LeftShift = true;
  // (and (lshr y, c2), mask)
  else if (mi_match(RootReg, MRI,
                    m_GAnd(m_GLShr(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
    LeftShift = false;

  if (LeftShift.has_value()) {
    // Drop the mask bits that the shift already guarantees to be zero, so
    // the shifted-mask test below reflects only the bits that matter.
    if (*LeftShift)
      Mask &= maskTrailingZeros<uint64_t>(C2.getLimitedValue());
    else
      Mask &= maskTrailingOnes<uint64_t>(XLen - C2.getLimitedValue());

    if (Mask.isShiftedMask()) {
      unsigned Leading = XLen - Mask.getActiveBits();
      unsigned Trailing = Mask.countr_zero();
      // Given (and (shl y, c2), mask) in which mask has no leading zeros and
      // c3 trailing zeros. We can use an SRLI by c3 - c2 followed by a SHXADD.
      if (*LeftShift && Leading == 0 && C2.ult(Trailing) && Trailing == ShAmt) {
        Register DstReg =
            MRI.createGenericVirtualRegister(MRI.getType(RootReg));
        return {{[=](MachineInstrBuilder &MIB) {
          // Emit the SRLI right before the use and feed its result in.
          MachineIRBuilder(*MIB.getInstr())
              .buildInstr(RISCV::SRLI, {DstReg}, {RegY})
              .addImm(Trailing - C2.getLimitedValue());
          MIB.addReg(DstReg);
        }}};
      }

      // Given (and (lshr y, c2), mask) in which mask has c2 leading zeros and
      // c3 trailing zeros. We can use an SRLI by c2 + c3 followed by a SHXADD.
      if (!*LeftShift && Leading == C2 && Trailing == ShAmt) {
        Register DstReg =
            MRI.createGenericVirtualRegister(MRI.getType(RootReg));
        return {{[=](MachineInstrBuilder &MIB) {
          MachineIRBuilder(*MIB.getInstr())
              .buildInstr(RISCV::SRLI, {DstReg}, {RegY})
              .addImm(Leading + Trailing);
          MIB.addReg(DstReg);
        }}};
      }
    }
  }

  // Reset and try the mask-then-shift forms next.
  LeftShift.reset();

  // (shl (and y, mask), c2)
  if (mi_match(RootReg, MRI,
               m_GShl(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
                      m_ICst(C2))))
    LeftShift = true;
  // (lshr (and y, mask), c2)
  else if (mi_match(RootReg, MRI,
                    m_GLShr(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
                            m_ICst(C2))))
    LeftShift = false;

  if (LeftShift.has_value() && Mask.isShiftedMask()) {
    unsigned Leading = XLen - Mask.getActiveBits();
    unsigned Trailing = Mask.countr_zero();

    // Given (shl (and y, mask), c2) in which mask has 32 leading zeros and
    // c3 trailing zeros. If c2 + c3 == ShAmt, we can emit SRLIW + SHXADD.
    bool Cond = *LeftShift && Leading == 32 && Trailing > 0 &&
                (Trailing + C2.getLimitedValue()) == ShAmt;
    if (!Cond)
      // Given (lshr (and y, mask), c2) in which mask has 32 leading zeros and
      // c3 trailing zeros. If c3 - c2 == ShAmt, we can emit SRLIW + SHXADD.
      Cond = !*LeftShift && Leading == 32 && C2.ult(Trailing) &&
             (Trailing - C2.getLimitedValue()) == ShAmt;

    if (Cond) {
      Register DstReg = MRI.createGenericVirtualRegister(MRI.getType(RootReg));
      return {{[=](MachineInstrBuilder &MIB) {
        MachineIRBuilder(*MIB.getInstr())
            .buildInstr(RISCV::SRLIW, {DstReg}, {RegY})
            .addImm(Trailing);
        MIB.addReg(DstReg);
      }}};
    }
  }

  return std::nullopt;
}

InstructionSelector::ComplexRendererFns
RISCVInstructionSelector::selectAddrRegImm(MachineOperand &Root) const {
// TODO: Need to get the immediate from a G_PTR_ADD. Should this be done in
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/RISCV/RISCVGISel.td
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ def ShiftMaskGI :
GIComplexOperandMatcher<s32, "selectShiftMask">,
GIComplexPatternEquiv<shiftMaskXLen>;

// Tie the SDISel sh[1-3]add_op ComplexPatterns to the GISel
// selectSHXADDOp<N> complex operand matchers so that patterns written
// against sh[1-3]add_op can be imported into GlobalISel.
def gi_sh1add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<1>">,
                   GIComplexPatternEquiv<sh1add_op>;
def gi_sh2add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<2>">,
                   GIComplexPatternEquiv<sh2add_op>;
def gi_sh3add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<3>">,
                   GIComplexPatternEquiv<sh3add_op>;

// FIXME: Canonicalize (sub X, C) -> (add X, -C) earlier.
def : Pat<(XLenVT (sub GPR:$rs1, simm12Plus1:$imm)),
(ADDI GPR:$rs1, (NegImm simm12Plus1:$imm))>;
Expand Down
85 changes: 54 additions & 31 deletions llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,39 @@ def SimmShiftRightBy3XForm : SDNodeXForm<imm, [{
}]>;

// Pattern to exclude simm12 immediates from matching.
// Matches any operand that is either not a constant, or a constant that
// does not fit in a signed 12-bit immediate (and so cannot use ADDI).
// Note: this will be removed once the GISel complex patterns for
// SHXADD_UW is landed.
def non_imm12 : PatLeaf<(XLenVT GPR:$a), [{
  auto *C = dyn_cast<ConstantSDNode>(N);
  return !C || !isInt<12>(C->getSExtValue());
}]>;

// GISel currently doesn't support PatFrag for leaf nodes, so `non_imm12`
// cannot be directly supported in GISel. To reuse patterns between the two
// ISels, we instead create PatFrag on operators that use `non_imm12`.
// The fragment matches (binop x, y) only when y is not a simm12 constant;
// the SDISel predicate and the GISel predicate below encode the same check.
class binop_with_non_imm12<SDPatternOperator binop>
  : PatFrag<(ops node:$x, node:$y), (binop node:$x, node:$y), [{
  auto *C = dyn_cast<ConstantSDNode>(Operands[1]);
  return !C || !isInt<12>(C->getSExtValue());
}]> {
  // Give the predicate access to the captured operands (Operands[]).
  let PredicateCodeUsesOperands = 1;
  let GISelPredicateCode = [{
    const MachineOperand &ImmOp = *Operands[1];
    const MachineFunction &MF = *MI.getParent()->getParent();
    const MachineRegisterInfo &MRI = MF.getRegInfo();

    // In GISel constants live in registers; look through copies to find
    // a constant definition and reject it if it fits in simm12.
    if (ImmOp.isReg() && ImmOp.getReg())
      if (auto Val = getIConstantVRegValWithLookThrough(ImmOp.getReg(), MRI)) {
        // We do NOT want immediates that fit in 12 bits.
        return !isInt<12>(Val->Value.getSExtValue());
      }

    // Not a constant at all: accept.
    return true;
  }];
}
def add_non_imm12 : binop_with_non_imm12<add>;
def or_is_add_non_imm12 : binop_with_non_imm12<or_is_add>;

def Shifted32OnesMask : PatLeaf<(imm), [{
uint64_t Imm = N->getZExtValue();
if (!isShiftedMask_64(Imm))
Expand Down Expand Up @@ -647,20 +675,17 @@ let Predicates = [HasStdExtZbb, IsRV64] in
def : Pat<(i64 (and GPR:$rs, 0xFFFF)), (ZEXT_H_RV64 GPR:$rs)>;

let Predicates = [HasStdExtZba] in {
def : Pat<(add (shl GPR:$rs1, (XLenVT 1)), non_imm12:$rs2),
(SH1ADD GPR:$rs1, GPR:$rs2)>;
def : Pat<(add (shl GPR:$rs1, (XLenVT 2)), non_imm12:$rs2),
(SH2ADD GPR:$rs1, GPR:$rs2)>;
def : Pat<(add (shl GPR:$rs1, (XLenVT 3)), non_imm12:$rs2),
(SH3ADD GPR:$rs1, GPR:$rs2)>;

// More complex cases use a ComplexPattern.
def : Pat<(add sh1add_op:$rs1, non_imm12:$rs2),
(SH1ADD sh1add_op:$rs1, GPR:$rs2)>;
def : Pat<(add sh2add_op:$rs1, non_imm12:$rs2),
(SH2ADD sh2add_op:$rs1, GPR:$rs2)>;
def : Pat<(add sh3add_op:$rs1, non_imm12:$rs2),
(SH3ADD sh3add_op:$rs1, GPR:$rs2)>;
// Select SH1ADD/SH2ADD/SH3ADD for (shl x, i) + y when y is not a simm12
// constant (otherwise an ADDI-based form is preferable).
foreach i = {1,2,3} in {
  defvar shxadd = !cast<Instruction>("SH"#i#"ADD");
  def : Pat<(XLenVT (add_non_imm12 (shl GPR:$rs1, (XLenVT i)), GPR:$rs2)),
            (shxadd GPR:$rs1, GPR:$rs2)>;

  defvar pat = !cast<ComplexPattern>("sh"#i#"add_op");
  // More complex cases use a ComplexPattern.
  def : Pat<(XLenVT (add_non_imm12 pat:$rs1, GPR:$rs2)),
            (shxadd pat:$rs1, GPR:$rs2)>;
}

def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
(SH1ADD (SH1ADD GPR:$rs1, GPR:$rs1), GPR:$rs2)>;
Expand Down Expand Up @@ -730,26 +755,24 @@ def : Pat<(i64 (shl (and GPR:$rs1, 0xFFFFFFFF), uimm5:$shamt)),
def : Pat<(i64 (and GPR:$rs1, Shifted32OnesMask:$mask)),
(SLLI_UW (SRLI GPR:$rs1, Shifted32OnesMask:$mask),
Shifted32OnesMask:$mask)>;

def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0xFFFFFFFF), GPR:$rs2)),
(ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (and GPR:$rs, 0xFFFFFFFF)), (ADD_UW GPR:$rs, (XLenVT X0))>;

def : Pat<(i64 (or_is_add (and GPR:$rs1, 0xFFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (or_is_add_non_imm12 (and GPR:$rs1, 0xFFFFFFFF), GPR:$rs2)),
(ADD_UW GPR:$rs1, GPR:$rs2)>;

def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 1)), non_imm12:$rs2)),
(SH1ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 2)), non_imm12:$rs2)),
(SH2ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 3)), non_imm12:$rs2)),
(SH3ADD_UW GPR:$rs1, GPR:$rs2)>;
// Select SH1ADD_UW/SH2ADD_UW/SH3ADD_UW for (zext32(x) << i) + y when y is
// not a simm12 constant; the _UW form zero-extends rs1 before shifting.
foreach i = {1,2,3} in {
  defvar shxadd_uw = !cast<Instruction>("SH"#i#"ADD_UW");
  def : Pat<(i64 (add_non_imm12 (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 i)), (XLenVT GPR:$rs2))),
            (shxadd_uw GPR:$rs1, GPR:$rs2)>;
}

def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), (XLenVT GPR:$rs2))),
(SH1ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), (XLenVT GPR:$rs2))),
(SH2ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), (XLenVT GPR:$rs2))),
(SH3ADD_UW GPR:$rs1, GPR:$rs2)>;

// More complex cases use a ComplexPattern.
Expand All @@ -760,19 +783,19 @@ def : Pat<(i64 (add sh2add_uw_op:$rs1, non_imm12:$rs2)),
def : Pat<(i64 (add sh3add_uw_op:$rs1, non_imm12:$rs2)),
(SH3ADD_UW sh3add_uw_op:$rs1, GPR:$rs2)>;

def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFE), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0xFFFFFFFE), (XLenVT GPR:$rs2))),
(SH1ADD (SRLIW GPR:$rs1, 1), GPR:$rs2)>;
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFC), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0xFFFFFFFC), (XLenVT GPR:$rs2))),
(SH2ADD (SRLIW GPR:$rs1, 2), GPR:$rs2)>;
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFF8), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0xFFFFFFF8), (XLenVT GPR:$rs2))),
(SH3ADD (SRLIW GPR:$rs1, 3), GPR:$rs2)>;

// Use SRLI to clear the LSBs and SHXADD_UW to mask and shift.
def : Pat<(i64 (add (and GPR:$rs1, 0x1FFFFFFFE), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0x1FFFFFFFE), (XLenVT GPR:$rs2))),
(SH1ADD_UW (SRLI GPR:$rs1, 1), GPR:$rs2)>;
def : Pat<(i64 (add (and GPR:$rs1, 0x3FFFFFFFC), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0x3FFFFFFFC), (XLenVT GPR:$rs2))),
(SH2ADD_UW (SRLI GPR:$rs1, 2), GPR:$rs2)>;
def : Pat<(i64 (add (and GPR:$rs1, 0x7FFFFFFF8), non_imm12:$rs2)),
def : Pat<(i64 (add_non_imm12 (and GPR:$rs1, 0x7FFFFFFF8), (XLenVT GPR:$rs2))),
(SH3ADD_UW (SRLI GPR:$rs1, 3), GPR:$rs2)>;

def : Pat<(i64 (mul (and_oneuse GPR:$r, 0xFFFFFFFF), C3LeftShiftUW:$i)),
Expand Down
Loading

0 comments on commit 5f5faf4

Please sign in to comment.