Skip to content

Commit

Permalink
[X86] Fix miscompile in combineShiftRightArithmetic
Browse files Browse the repository at this point in the history
When folding (ashr (shl, x, c1), c2) we need to treat c1 and c2
as unsigned to find out if the combined shift should be a left
or right shift.
Also do an early out during pre-legalization in case c1 and c2
has different types, as that otherwise complicated the comparison
of c1 and c2 a bit.

(cherry picked from commit 3e6e54e)
  • Loading branch information
AZero13 authored and tstellar committed Apr 24, 2024
1 parent 76cbd41 commit 111ae45
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 18 deletions.
29 changes: 16 additions & 13 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47035,10 +47035,13 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
if (SDValue V = combineShiftToPMULH(N, DAG, Subtarget))
return V;

// fold (ashr (shl, a, [56,48,32,24,16]), SarConst)
// into (shl, (sext (a), [56,48,32,24,16] - SarConst)) or
// into (lshr, (sext (a), SarConst - [56,48,32,24,16]))
// depending on sign of (SarConst - [56,48,32,24,16])
// fold (SRA (SHL X, ShlConst), SraConst)
// into (SHL (sext_in_reg X), ShlConst - SraConst)
// or (sext_in_reg X)
// or (SRA (sext_in_reg X), SraConst - ShlConst)
// depending on relation between SraConst and ShlConst.
// We only do this if (Size - ShlConst) is equal to 8, 16 or 32. That allows
// us to do the sext_in_reg from corresponding bit.

// sexts in X86 are MOVs. The MOVs have the same code size
// as above SHIFTs (only SHIFT on 1 has lower code size).
Expand All @@ -47054,29 +47057,29 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
SDValue N00 = N0.getOperand(0);
SDValue N01 = N0.getOperand(1);
APInt ShlConst = N01->getAsAPIntVal();
APInt SarConst = N1->getAsAPIntVal();
APInt SraConst = N1->getAsAPIntVal();
EVT CVT = N1.getValueType();

if (SarConst.isNegative())
if (CVT != N01.getValueType())
return SDValue();
if (SraConst.isNegative())
return SDValue();

for (MVT SVT : { MVT::i8, MVT::i16, MVT::i32 }) {
unsigned ShiftSize = SVT.getSizeInBits();
// skipping types without corresponding sext/zext and
// ShlConst that is not one of [56,48,32,24,16]
// Only deal with (Size - ShlConst) being equal to 8, 16 or 32.
if (ShiftSize >= Size || ShlConst != Size - ShiftSize)
continue;
SDLoc DL(N);
SDValue NN =
DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N00, DAG.getValueType(SVT));
SarConst = SarConst - (Size - ShiftSize);
if (SarConst == 0)
if (SraConst.eq(ShlConst))
return NN;
if (SarConst.isNegative())
if (SraConst.ult(ShlConst))
return DAG.getNode(ISD::SHL, DL, VT, NN,
DAG.getConstant(-SarConst, DL, CVT));
DAG.getConstant(ShlConst - SraConst, DL, CVT));
return DAG.getNode(ISD::SRA, DL, VT, NN,
DAG.getConstant(SarConst, DL, CVT));
DAG.getConstant(SraConst - ShlConst, DL, CVT));
}
return SDValue();
}
Expand Down
10 changes: 5 additions & 5 deletions llvm/test/CodeGen/X86/sar_fold.ll
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ define void @shl144sar2(ptr %p) #0 {
; CHECK: # %bb.0:
; CHECK-NEXT: movl {{[0-9]+}}(%esp), %eax
; CHECK-NEXT: movswl (%eax), %ecx
; CHECK-NEXT: sarl $31, %ecx
; CHECK-NEXT: shll $14, %ecx
; CHECK-NEXT: movl %ecx, 16(%eax)
; CHECK-NEXT: movl %ecx, 8(%eax)
; CHECK-NEXT: movl %ecx, 12(%eax)
; CHECK-NEXT: movl %ecx, 4(%eax)
; CHECK-NEXT: movl %ecx, (%eax)
; CHECK-NEXT: movl $0, 8(%eax)
; CHECK-NEXT: movl $0, 12(%eax)
; CHECK-NEXT: movl $0, 4(%eax)
; CHECK-NEXT: movl $0, (%eax)
; CHECK-NEXT: retl
%a = load i160, ptr %p
%1 = shl i160 %a, 144
Expand Down

0 comments on commit 111ae45

Please sign in to comment.