Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV][GISel] Add ISel supports for SHXADD from Zba extension #67863

Merged
merged 9 commits into from
Oct 18, 2023
130 changes: 130 additions & 0 deletions llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "RISCVTargetMachine.h"
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
#include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -55,6 +56,14 @@ class RISCVInstructionSelector : public InstructionSelector {

ComplexRendererFns selectShiftMask(MachineOperand &Root) const;

ComplexRendererFns selectNonImm12(MachineOperand &Root) const;

ComplexRendererFns selectSHXADDOp(MachineOperand &Root, unsigned ShAmt) const;
template <unsigned ShAmt>
ComplexRendererFns selectSHXADDOp(MachineOperand &Root) const {
return selectSHXADDOp(Root, ShAmt);
}

// Custom renderers for tablegen
void renderNegImm(MachineInstrBuilder &MIB, const MachineInstr &MI,
int OpIdx) const;
Expand Down Expand Up @@ -105,6 +114,127 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const {
return {{[=](MachineInstrBuilder &MIB) { MIB.add(Root); }}};
}

// This complex pattern actually serves as a perdicate that is effectively
mshockwave marked this conversation as resolved.
Show resolved Hide resolved
// `!isInt<12>(Imm)`.
InstructionSelector::ComplexRendererFns
RISCVInstructionSelector::selectNonImm12(MachineOperand &Root) const {
MachineFunction &MF = *Root.getParent()->getParent()->getParent();
MachineRegisterInfo &MRI = MF.getRegInfo();

if (Root.isReg() && Root.getReg())
if (auto Val = getIConstantVRegValWithLookThrough(Root.getReg(), MRI)) {
// We do NOT want immediates that fit in 12 bits.
if (isInt<12>(Val->Value.getSExtValue()))
return std::nullopt;
}

return {{[=](MachineInstrBuilder &MIB) { MIB.add(Root); }}};
}

// Select the shifted operand of an SHXADD (sh1add/sh2add/sh3add): match
// shift+mask forms of the root register that can be rewritten as a single
// SRLI/SRLIW feeding an SHXADD with shift amount ShAmt. On success, the
// renderer emits the SRLI/SRLIW into a fresh vreg and adds that vreg as the
// operand; otherwise returns std::nullopt and the pattern does not match.
InstructionSelector::ComplexRendererFns
RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root,
                                         unsigned ShAmt) const {
  using namespace llvm::MIPatternMatch;
  MachineFunction &MF = *Root.getParent()->getParent()->getParent();
  MachineRegisterInfo &MRI = MF.getRegInfo();

  if (!Root.isReg())
    return std::nullopt;
  Register RootReg = Root.getReg();

  const unsigned XLen = STI.getXLen();
  APInt Mask, C2;
  Register RegY;
  std::optional<bool> LeftShift;
  // (and (shl y, c2), mask)
  if (mi_match(RootReg, MRI,
               m_GAnd(m_GShl(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
    LeftShift = true;
  // (and (lshr y, c2), mask)
  else if (mi_match(RootReg, MRI,
                    m_GAnd(m_GLShr(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask))))
    LeftShift = false;

  if (LeftShift.has_value()) {
    // Drop the mask bits that the shift already guarantees are zero, so that
    // isShiftedMask below sees only the bits the AND actually constrains.
    if (*LeftShift)
      Mask &= maskTrailingZeros<uint64_t>(C2.getLimitedValue());
    else
      Mask &= maskTrailingOnes<uint64_t>(XLen - C2.getLimitedValue());

    if (Mask.isShiftedMask()) {
      unsigned Leading = XLen - Mask.getActiveBits();
      unsigned Trailing = Mask.countr_zero();
      // Given (and (shl y, c2), mask) in which mask has no leading zeros and
      // c3 trailing zeros. We can use an SRLI by c3 - c2 followed by a SHXADD.
      if (*LeftShift && Leading == 0 && C2.ult(Trailing) && Trailing == ShAmt) {
        Register DstReg =
            MRI.createGenericVirtualRegister(MRI.getType(RootReg));
        return {{[=](MachineInstrBuilder &MIB) {
          MachineIRBuilder(*MIB.getInstr())
              .buildInstr(RISCV::SRLI, {DstReg}, {RegY})
              .addImm(Trailing - C2.getLimitedValue());
          MIB.addReg(DstReg);
        }}};
      }

      // Given (and (lshr y, c2), mask) in which mask has c2 leading zeros and
      // c3 trailing zeros. We can use an SRLI by c2 + c3 followed by a SHXADD.
      if (!*LeftShift && Leading == C2 && Trailing == ShAmt) {
        Register DstReg =
            MRI.createGenericVirtualRegister(MRI.getType(RootReg));
        return {{[=](MachineInstrBuilder &MIB) {
          MachineIRBuilder(*MIB.getInstr())
              .buildInstr(RISCV::SRLI, {DstReg}, {RegY})
              .addImm(Leading + Trailing);
          MIB.addReg(DstReg);
        }}};
      }
    }
  }

  LeftShift.reset();

  // (shl (and y, mask), c2)
  if (mi_match(RootReg, MRI,
               m_GShl(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
                      m_ICst(C2))))
    LeftShift = true;
  // (lshr (and y, mask), c2)
  else if (mi_match(RootReg, MRI,
                    m_GLShr(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))),
                            m_ICst(C2))))
    LeftShift = false;

  if (LeftShift.has_value() && Mask.isShiftedMask()) {
    unsigned Leading = XLen - Mask.getActiveBits();
    unsigned Trailing = Mask.countr_zero();

    // Given (shl (and y, mask), c2) in which mask has 32 leading zeros and
    // c3 trailing zeros. If c2 + c3 == ShAmt, we can emit SRLIW + SHXADD.
    bool Cond = *LeftShift && Leading == 32 && Trailing > 0 &&
                (Trailing + C2.getLimitedValue()) == ShAmt;
    if (!Cond)
      // Given (lshr (and y, mask), c2) in which mask has 32 leading zeros and
      // c3 trailing zeros. If c3 - c2 == ShAmt, we can emit SRLIW + SHXADD.
      Cond = !*LeftShift && Leading == 32 && C2.ult(Trailing) &&
             (Trailing - C2.getLimitedValue()) == ShAmt;

    if (Cond) {
      Register DstReg = MRI.createGenericVirtualRegister(MRI.getType(RootReg));
      return {{[=](MachineInstrBuilder &MIB) {
        MachineIRBuilder(*MIB.getInstr())
            .buildInstr(RISCV::SRLIW, {DstReg}, {RegY})
            .addImm(Trailing);
        MIB.addReg(DstReg);
      }}};
    }
  }

  return std::nullopt;
}

// Tablegen doesn't allow us to write SRLIW/SRAIW/SLLIW patterns because the
// immediate Operand has type XLenVT. GlobalISel wants it to be i32.
bool RISCVInstructionSelector::earlySelectShift(
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/RISCV/RISCVGISel.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ def ShiftMaskGI :
GIComplexOperandMatcher<s32, "selectShiftMask">,
GIComplexPatternEquiv<shiftMaskXLen>;

def gi_non_imm12 : GIComplexOperandMatcher<s32, "selectNonImm12">,
GIComplexPatternEquiv<non_imm12>;

def gi_sh1add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<1>">,
GIComplexPatternEquiv<sh1add_op>;
def gi_sh2add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<2>">,
GIComplexPatternEquiv<sh2add_op>;
def gi_sh3add_op : GIComplexOperandMatcher<s32, "selectSHXADDOp<3>">,
GIComplexPatternEquiv<sh3add_op>;

// FIXME: Canonicalize (sub X, C) -> (add X, -C) earlier.
def : Pat<(XLenVT (sub GPR:$rs1, simm12Plus1:$imm)),
(ADDI GPR:$rs1, (NegImm simm12Plus1:$imm))>;
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2443,6 +2443,15 @@ bool RISCVDAGToDAGISel::SelectAddrRegImm(SDValue Addr, SDValue &Base,
return true;
}

// Complex-pattern predicate that matches any operand EXCEPT a constant
// fitting in a signed 12-bit immediate (effectively `!isInt<12>(Imm)`).
// On a match, the operand itself is returned unchanged in Opnd.
bool RISCVDAGToDAGISel::selectNonImm12(SDValue N, SDValue &Opnd) {
  if (auto *C = dyn_cast<ConstantSDNode>(N))
    if (isInt<12>(C->getSExtValue()))
      return false; // simm12 constants are rejected.
  Opnd = N;
  return true;
}

bool RISCVDAGToDAGISel::selectShiftMask(SDValue N, unsigned ShiftWidth,
SDValue &ShAmt) {
ShAmt = N;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
bool trySignedBitfieldExtract(SDNode *Node);
bool tryIndexedLoad(SDNode *Node);

bool selectNonImm12(SDValue N, SDValue &Opnd);

bool selectShiftMask(SDValue N, unsigned ShiftWidth, SDValue &ShAmt);
bool selectShiftMaskXLen(SDValue N, SDValue &ShAmt) {
return selectShiftMask(N, Subtarget->getXLen(), ShAmt);
Expand Down
51 changes: 24 additions & 27 deletions llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,7 @@ def SimmShiftRightBy3XForm : SDNodeXForm<imm, [{
}]>;

// Pattern to exclude simm12 immediates from matching.
def non_imm12 : PatLeaf<(XLenVT GPR:$a), [{
auto *C = dyn_cast<ConstantSDNode>(N);
return !C || !isInt<12>(C->getSExtValue());
}]>;
def non_imm12 : ComplexPattern<XLenVT, 1, "selectNonImm12", [], [], 0>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add let GISelPredicateCode here to do this without changing to a complex pattern?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After some digging, I think the answer is no and I'm sad about it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, using GISelPredicateCode was my first approach until I found that the GlobalISelEmitter TG backend doesn't pick that up for leaf nodes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would add GISelPredicateCode to something like this work topperc@01205c1

Copy link
Member Author

@mshockwave mshockwave Oct 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would add GISelPredicateCode to something like this work topperc@01205c1

This works, though I'm a little concerned that this might create too many boilerplate code in the future, since there needs to be a Predicate TG record for every opcode that goes with non_imm12 (even we abstract the real predicate logics into a function). What do you think?

Also, interestingly GISelPredicateCode doesn't dance well with PredicateCodeUsesOperands: it SEGFAULT llvm-tblgen in our case, despite the fact that there are tests for this exact combination (in test/TableGen/GlobalISelEmitterCustomPredicate.td). I can't find an obvious fix for llvm-tblgen but writing a non-PredicateCodeUsesOperands predicate code works so I'm not too bothered by this crash (for now).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used PredicateCodeUsesOperands so that I could know which operand wasn't the shl since add is commutable and tblgen will generate both patterns.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would add GISelPredicateCode to something like this work topperc@01205c1

Done: it's no longer using ComplexPattern of non_imm12 but PatFrag of <op>_with_non_imm12 instead


def Shifted32OnesMask : PatLeaf<(imm), [{
uint64_t Imm = N->getZExtValue();
Expand Down Expand Up @@ -651,19 +648,19 @@ let Predicates = [HasStdExtZbb, IsRV64] in
def : Pat<(i64 (and GPR:$rs, 0xFFFF)), (ZEXT_H_RV64 GPR:$rs)>;

let Predicates = [HasStdExtZba] in {
def : Pat<(add (shl GPR:$rs1, (XLenVT 1)), non_imm12:$rs2),
def : Pat<(add (shl GPR:$rs1, (XLenVT 1)), (non_imm12 (XLenVT GPR:$rs2))),
(SH1ADD GPR:$rs1, GPR:$rs2)>;
def : Pat<(add (shl GPR:$rs1, (XLenVT 2)), non_imm12:$rs2),
def : Pat<(add (shl GPR:$rs1, (XLenVT 2)), (non_imm12 (XLenVT GPR:$rs2))),
(SH2ADD GPR:$rs1, GPR:$rs2)>;
def : Pat<(add (shl GPR:$rs1, (XLenVT 3)), non_imm12:$rs2),
def : Pat<(add (shl GPR:$rs1, (XLenVT 3)), (non_imm12 (XLenVT GPR:$rs2))),
(SH3ADD GPR:$rs1, GPR:$rs2)>;

// More complex cases use a ComplexPattern.
def : Pat<(add sh1add_op:$rs1, non_imm12:$rs2),
def : Pat<(add sh1add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))),
(SH1ADD sh1add_op:$rs1, GPR:$rs2)>;
def : Pat<(add sh2add_op:$rs1, non_imm12:$rs2),
def : Pat<(add sh2add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))),
(SH2ADD sh2add_op:$rs1, GPR:$rs2)>;
def : Pat<(add sh3add_op:$rs1, non_imm12:$rs2),
def : Pat<(add sh3add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))),
(SH3ADD sh3add_op:$rs1, GPR:$rs2)>;

def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
Expand Down Expand Up @@ -735,48 +732,48 @@ def : Pat<(i64 (and GPR:$rs1, Shifted32OnesMask:$mask)),
(SLLI_UW (SRLI GPR:$rs1, Shifted32OnesMask:$mask),
Shifted32OnesMask:$mask)>;

def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
(ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (and GPR:$rs, 0xFFFFFFFF)), (ADD_UW GPR:$rs, (XLenVT X0))>;

def : Pat<(i64 (or_is_add (and GPR:$rs1, 0xFFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (or_is_add (and GPR:$rs1, 0xFFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
(ADD_UW GPR:$rs1, GPR:$rs2)>;

def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 1)), non_imm12:$rs2)),
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 1)), (non_imm12 (XLenVT GPR:$rs2)))),
(SH1ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 2)), non_imm12:$rs2)),
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 2)), (non_imm12 (XLenVT GPR:$rs2)))),
(SH2ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 3)), non_imm12:$rs2)),
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 3)), (non_imm12 (XLenVT GPR:$rs2)))),
(SH3ADD_UW GPR:$rs1, GPR:$rs2)>;

def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
(SH1ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
(SH2ADD_UW GPR:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), non_imm12:$rs2)),
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))),
(SH3ADD_UW GPR:$rs1, GPR:$rs2)>;

// More complex cases use a ComplexPattern.
def : Pat<(i64 (add sh1add_uw_op:$rs1, non_imm12:$rs2)),
def : Pat<(i64 (add sh1add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))),
(SH1ADD_UW sh1add_uw_op:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add sh2add_uw_op:$rs1, non_imm12:$rs2)),
def : Pat<(i64 (add sh2add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))),
(SH2ADD_UW sh2add_uw_op:$rs1, GPR:$rs2)>;
def : Pat<(i64 (add sh3add_uw_op:$rs1, non_imm12:$rs2)),
def : Pat<(i64 (add sh3add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))),
(SH3ADD_UW sh3add_uw_op:$rs1, GPR:$rs2)>;

def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFE), non_imm12:$rs2)),
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFE), (non_imm12 (XLenVT GPR:$rs2)))),
(SH1ADD (SRLIW GPR:$rs1, 1), GPR:$rs2)>;
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFC), non_imm12:$rs2)),
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFC), (non_imm12 (XLenVT GPR:$rs2)))),
(SH2ADD (SRLIW GPR:$rs1, 2), GPR:$rs2)>;
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFF8), non_imm12:$rs2)),
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFF8), (non_imm12 (XLenVT GPR:$rs2)))),
(SH3ADD (SRLIW GPR:$rs1, 3), GPR:$rs2)>;

// Use SRLI to clear the LSBs and SHXADD_UW to mask and shift.
def : Pat<(i64 (add (and GPR:$rs1, 0x1FFFFFFFE), non_imm12:$rs2)),
def : Pat<(i64 (add (and GPR:$rs1, 0x1FFFFFFFE), (non_imm12 (XLenVT GPR:$rs2)))),
(SH1ADD_UW (SRLI GPR:$rs1, 1), GPR:$rs2)>;
def : Pat<(i64 (add (and GPR:$rs1, 0x3FFFFFFFC), non_imm12:$rs2)),
def : Pat<(i64 (add (and GPR:$rs1, 0x3FFFFFFFC), (non_imm12 (XLenVT GPR:$rs2)))),
(SH2ADD_UW (SRLI GPR:$rs1, 2), GPR:$rs2)>;
def : Pat<(i64 (add (and GPR:$rs1, 0x7FFFFFFF8), non_imm12:$rs2)),
def : Pat<(i64 (add (and GPR:$rs1, 0x7FFFFFFF8), (non_imm12 (XLenVT GPR:$rs2)))),
(SH3ADD_UW (SRLI GPR:$rs1, 3), GPR:$rs2)>;

def : Pat<(i64 (mul (and_oneuse GPR:$r, 0xFFFFFFFF), C3LeftShiftUW:$i)),
Expand Down
Loading