diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index e4d5e5c71b7e18..e9401d4b93c371 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -23,6 +23,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/CodeGen/Analysis.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineMemOperand.h"
 #include "llvm/CodeGen/MachineValueType.h"
@@ -672,7 +673,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
 
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
-                       ISD::SREM, ISD::UREM});
+                       ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5252,6 +5253,56 @@ static SDValue PerformSETCCCombine(SDNode *N,
                          CCNode.getValue(1));
 }
 
+/// Combine for ISD::EXTRACT_VECTOR_ELT with a constant, non-zero index on a
+/// vector that is 16, 32 or 64 bits wide: bitcast the vector to an integer of
+/// the same width, shift right (SRA) by Index * EltBits, truncate to the
+/// element width and, for non-integer elements, bitcast back to the element
+/// type. Returns an empty SDValue when the combine does not apply.
+static SDValue PerformEXTRACTCombine(SDNode *N,
+                                     TargetLowering::DAGCombinerInfo &DCI) {
+  SDValue Vector = N->getOperand(0);
+  EVT VectorVT = Vector.getValueType();
+  if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
+      IsPTXVectorType(VectorVT.getSimpleVT()))
+    return SDValue(); // Native vector loads already combine nicely w/
+                      // extract_vector_elt.
+  // Don't mess with singletons or v2*16 types, we already handle them OK.
+  if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT))
+    return SDValue();
+
+  uint64_t VectorBits = VectorVT.getSizeInBits();
+  // We only handle the types we can extract in-register.
+  if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
+    return SDValue();
+
+  ConstantSDNode *Index = dyn_cast<ConstantSDNode>(N->getOperand(1));
+  // Index == 0 is handled by generic DAG combiner.
+  if (!Index || Index->getZExtValue() == 0)
+    return SDValue();
+  // An out-of-range index selects no element; bail out rather than emit a
+  // shift by VectorBits or more.
+  if (Index->getZExtValue() >= VectorVT.getVectorNumElements())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  MVT IVT = MVT::getIntegerVT(VectorBits);
+  EVT EltVT = VectorVT.getVectorElementType();
+  EVT EltIVT = EltVT.changeTypeToInteger();
+  uint64_t EltBits = EltVT.getScalarSizeInBits();
+
+  SDValue Result = DCI.DAG.getNode(
+      ISD::TRUNCATE, DL, EltIVT,
+      DCI.DAG.getNode(
+          ISD::SRA, DL, IVT, DCI.DAG.getNode(ISD::BITCAST, DL, IVT, Vector),
+          DCI.DAG.getConstant(Index->getZExtValue() * EltBits, DL, IVT)));
+
+  // If element has non-integer type, bitcast it back to the expected type.
+  if (EltVT != EltIVT)
+    Result = DCI.DAG.getNode(ISD::BITCAST, DL, EltVT, Result);
+  return Result;
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5275,6 +5326,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     case NVPTXISD::StoreRetvalV2:
     case NVPTXISD::StoreRetvalV4:
       return PerformStoreRetvalCombine(N);
+    case ISD::EXTRACT_VECTOR_ELT:
+      return PerformEXTRACTCombine(N, DCI);
   }
   return SDValue();
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 3e48c0f9d2c6ab..ad10d7938ef12e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1713,34 +1713,56 @@ def FUNSHFRCLAMP :
 // BFE - bit-field extract
 //
 
-// Template for BFE instructions. Takes four args,
-// [dest (reg), src (reg), start (reg or imm), end (reg or imm)].
+// Template for BFE/BFI instructions.
+// Args: [dest (reg), src (reg), start (reg or imm), end (reg or imm)].
 // Start may be an imm only if end is also an imm. FIXME: Is this a
 // restriction in PTX?
 //
 // dest and src may be int32 or int64, but start and end are always int32.
-multiclass BFE<string TyStr, RegisterClass RC> {
+multiclass BFX<string Instr, RegisterClass RC> {
   def rrr
     : NVPTXInst<(outs RC:$d),
                 (ins RC:$a, Int32Regs:$b, Int32Regs:$c),
-                !strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
+                !strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
   def rri
     : NVPTXInst<(outs RC:$d),
                 (ins RC:$a, Int32Regs:$b, i32imm:$c),
-                !strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
+                !strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
   def rii
     : NVPTXInst<(outs RC:$d),
                 (ins RC:$a, i32imm:$b, i32imm:$c),
-                !strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
+                !strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
 }
 
 let hasSideEffects = false in {
-  defm BFE_S32 : BFE<"s32", Int32Regs>;
-  defm BFE_U32 : BFE<"u32", Int32Regs>;
-  defm BFE_S64 : BFE<"s64", Int64Regs>;
-  defm BFE_U64 : BFE<"u64", Int64Regs>;
+  defm BFE_S32 : BFX<"bfe.s32", Int32Regs>;
+  defm BFE_U32 : BFX<"bfe.u32", Int32Regs>;
+  defm BFE_S64 : BFX<"bfe.s64", Int64Regs>;
+  defm BFE_U64 : BFX<"bfe.u64", Int64Regs>;
+  // NOTE(review): PTX bfi is documented as bfi.{b32,b64} with 4 sources.
+  defm BFI_S32 : BFX<"bfi.s32", Int32Regs>;
+  defm BFI_U32 : BFX<"bfi.u32", Int32Regs>;
+  defm BFI_S64 : BFX<"bfi.s64", Int64Regs>;
+  defm BFI_U64 : BFX<"bfi.u64", Int64Regs>;
 }
 
+// Common byte extraction patterns
+def : Pat<(i16 (sext_inreg (trunc Int32Regs:$s), i8)),
+          (CVT_s8_s32 Int32Regs:$s, CvtNONE)>;
+def : Pat<(i16 (sext_inreg (trunc (srl (i32 Int32Regs:$s), (i32 imm:$o))), i8)),
+          (CVT_s8_s32 (BFE_S32rii Int32Regs:$s, imm:$o, 8), CvtNONE)>;
+def : Pat<(sext_inreg (srl (i32 Int32Regs:$s), (i32 imm:$o)), i8),
+          (BFE_S32rii Int32Regs:$s, imm:$o, 8)>;
+def : Pat<(i16 (sra (i16 (trunc Int32Regs:$s)), (i32 8))),
+          (CVT_s8_s32 (BFE_S32rii Int32Regs:$s, 8, 8), CvtNONE)>;
+
+def : Pat<(sext_inreg (srl (i64 Int64Regs:$s), (i32 imm:$o)), i8),
+          (BFE_S64rii Int64Regs:$s, imm:$o, 8)>;
+def : Pat<(i16 (sext_inreg (trunc Int64Regs:$s), i8)),
+          (CVT_s8_s64 Int64Regs:$s, CvtNONE)>;
+def : Pat<(i16 (sext_inreg (trunc (srl (i64 Int64Regs:$s), (i32 imm:$o))), i8)),
+          (CVT_s8_s64 (BFE_S64rii Int64Regs:$s, imm:$o, 8), CvtNONE)>;
+
 //-----------------------------------
 // Comparison instructions (setp, set)
 //-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/extractelement.ll b/llvm/test/CodeGen/NVPTX/extractelement.ll
new file mode 100644
index 00000000000000..da07f973501c85
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/extractelement.ll
@@ -0,0 +1,89 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_35 -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_35 | %ptxas-verify %}
+
+
+; CHECK-LABEL: test_v2i8
+; CHECK-DAG: ld.param.u16 [[A:%rs[0-9]+]], [test_v2i8_param_0];
+; CHECK-DAG: cvt.s16.s8 [[E0:%rs[0-9]+]], [[A]];
+; CHECK-DAG: shr.s16 [[E1:%rs[0-9]+]], [[A]], 8;
+define i16 @test_v2i8(i16 %a) {
+  %v = bitcast i16 %a to <2 x i8>
+  %r0 = extractelement <2 x i8> %v, i64 0
+  %r1 = extractelement <2 x i8> %v, i64 1
+  %r0i = sext i8 %r0 to i16
+  %r1i = sext i8 %r1 to i16
+  %r01 = add i16 %r0i, %r1i
+  ret i16 %r01
+}
+
+; CHECK-LABEL: test_v4i8
+; CHECK: ld.param.u32 [[R:%r[0-9]+]], [test_v4i8_param_0];
+; CHECK-DAG: cvt.s8.s32 [[E0:%rs[0-9]+]], [[R]];
+; CHECK-DAG: bfe.s32 [[R1:%r[0-9]+]], [[R]], 8, 8;
+; CHECK-DAG: cvt.s8.s32 [[E1:%rs[0-9]+]], [[R1]];
+; CHECK-DAG: bfe.s32 [[R2:%r[0-9]+]], [[R]], 16, 8;
+; CHECK-DAG: cvt.s8.s32 [[E2:%rs[0-9]+]], [[R2]];
+; CHECK-DAG: bfe.s32 [[R3:%r[0-9]+]], [[R]], 24, 8;
+; CHECK-DAG: cvt.s8.s32 [[E3:%rs[0-9]+]], [[R3]];
+define i16 @test_v4i8(i32 %a) {
+  %v = bitcast i32 %a to <4 x i8>
+  %r0 = extractelement <4 x i8> %v, i64 0
+  %r1 = extractelement <4 x i8> %v, i64 1
+  %r2 = extractelement <4 x i8> %v, i64 2
+  %r3 = extractelement <4 x i8> %v, i64 3
+  %r0i = sext i8 %r0 to i16
+  %r1i = sext i8 %r1 to i16
+  %r2i = sext i8 %r2 to i16
+  %r3i = sext i8 %r3 to i16
+  %r01 = add i16 %r0i, %r1i
+  %r23 = add i16 %r2i, %r3i
+  %r = add i16 %r01, %r23
+  ret i16 %r
+}
+
+; CHECK-LABEL: test_v8i8
+; CHECK: ld.param.u64 [[R:%rd[0-9]+]], [test_v8i8_param_0];
+; CHECK-DAG: cvt.s8.s64 [[E0:%rs[0-9]+]], [[R]];
+; Element 1 is still extracted by trunc + shr 8 rather than bfe; not sure why.
+; CHECK-DAG: cvt.u16.u64 [[R01:%rs[0-9]+]], [[R]];
+; CHECK-DAG: shr.s16 [[E1:%rs[0-9]+]], [[R01]], 8;
+; CHECK-DAG: bfe.s64 [[RD2:%rd[0-9]+]], [[R]], 16, 8;
+; CHECK-DAG: cvt.s8.s64 [[E2:%rs[0-9]+]], [[RD2]];
+; CHECK-DAG: bfe.s64 [[RD3:%rd[0-9]+]], [[R]], 24, 8;
+; CHECK-DAG: cvt.s8.s64 [[E3:%rs[0-9]+]], [[RD3]];
+; CHECK-DAG: bfe.s64 [[RD4:%rd[0-9]+]], [[R]], 32, 8;
+; CHECK-DAG: cvt.s8.s64 [[E4:%rs[0-9]+]], [[RD4]];
+; CHECK-DAG: bfe.s64 [[RD5:%rd[0-9]+]], [[R]], 40, 8;
+; CHECK-DAG: cvt.s8.s64 [[E5:%rs[0-9]+]], [[RD5]];
+; CHECK-DAG: bfe.s64 [[RD6:%rd[0-9]+]], [[R]], 48, 8;
+; CHECK-DAG: cvt.s8.s64 [[E6:%rs[0-9]+]], [[RD6]];
+; CHECK-DAG: bfe.s64 [[RD7:%rd[0-9]+]], [[R]], 56, 8;
+; CHECK-DAG: cvt.s8.s64 [[E7:%rs[0-9]+]], [[RD7]];
+
+define i16 @test_v8i8(i64 %a) {
+  %v = bitcast i64 %a to <8 x i8>
+  %r0 = extractelement <8 x i8> %v, i64 0
+  %r1 = extractelement <8 x i8> %v, i64 1
+  %r2 = extractelement <8 x i8> %v, i64 2
+  %r3 = extractelement <8 x i8> %v, i64 3
+  %r4 = extractelement <8 x i8> %v, i64 4
+  %r5 = extractelement <8 x i8> %v, i64 5
+  %r6 = extractelement <8 x i8> %v, i64 6
+  %r7 = extractelement <8 x i8> %v, i64 7
+  %r0i = sext i8 %r0 to i16
+  %r1i = sext i8 %r1 to i16
+  %r2i = sext i8 %r2 to i16
+  %r3i = sext i8 %r3 to i16
+  %r4i = sext i8 %r4 to i16
+  %r5i = sext i8 %r5 to i16
+  %r6i = sext i8 %r6 to i16
+  %r7i = sext i8 %r7 to i16
+  %r01 = add i16 %r0i, %r1i
+  %r23 = add i16 %r2i, %r3i
+  %r45 = add i16 %r4i, %r5i
+  %r67 = add i16 %r6i, %r7i
+  %r0123 = add i16 %r01, %r23
+  %r4567 = add i16 %r45, %r67
+  %r = add i16 %r0123, %r4567
+  ret i16 %r
+}