Skip to content

Commit

Permalink
[NVPTX] Don't use stack memory when bitcasting to/from v2i8 (#113928)
Browse files Browse the repository at this point in the history
`v2i8` is an unsupported type, so we hit the default legalization rules
which perform the bitcast in stack memory and is very inefficient on
GPU.

This adds a custom lowering where we pack `v2i8` into `i16` and from
there use another bitcast node to reach the final desired type. And also
the inverse unpacking `i16` into `v2i8`.
  • Loading branch information
peterbell10 authored Nov 1, 2024
1 parent 58f525a commit b74e588
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 0 deletions.
62 changes: 62 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,13 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
return VectorInfo;
}

static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT,
SDValue Value) {
if (Value->getValueType(0) == VT)
return Value;
return DAG.getNode(ISD::BITCAST, DL, VT, Value);
}

// NVPTXTargetLowering Constructor.
NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
const NVPTXSubtarget &STI)
Expand Down Expand Up @@ -551,6 +558,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);

// Custom conversions to/from v2i8.
setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);

// Only logical ops can be done on v4i8 directly, others must be done
// elementwise.
setOperationAction(
Expand Down Expand Up @@ -2309,6 +2320,30 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
}

SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
// Handle bitcasting from v2i8 without hitting the default promotion
// strategy which goes through stack memory.
EVT FromVT = Op->getOperand(0)->getValueType(0);
if (FromVT != MVT::v2i8) {
return Op;
}

// Pack vector elements into i16 and bitcast to final type
SDLoc DL(Op);
SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
Op->getOperand(0), DAG.getIntPtrConstant(1, DL));
SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
SDValue AsInt = DAG.getNode(
ISD::OR, DL, MVT::i16,
{Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
EVT ToVT = Op->getValueType(0);
return MaybeBitcast(DAG, DL, ToVT, AsInt);
}

// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
// would get lowered as two constant loads and vector-packing move.
// Instead we want just a constant move:
Expand Down Expand Up @@ -2817,6 +2852,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return Op;
case ISD::BUILD_VECTOR:
return LowerBUILD_VECTOR(Op, DAG);
case ISD::BITCAST:
return LowerBITCAST(Op, DAG);
case ISD::EXTRACT_SUBVECTOR:
return Op;
case ISD::EXTRACT_VECTOR_ELT:
Expand Down Expand Up @@ -6127,6 +6164,28 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return SDValue();
}

static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
SmallVectorImpl<SDValue> &Results) {
// Handle bitcasting to v2i8 without hitting the default promotion
// strategy which goes through stack memory.
SDValue Op(Node, 0);
EVT ToVT = Op->getValueType(0);
if (ToVT != MVT::v2i8) {
return;
}

// Bitcast to i16 and unpack elements into a vector
SDLoc DL(Node);
SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0));
SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
SDValue Vec1 =
DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8}));
Results.push_back(
DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
}

/// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
SmallVectorImpl<SDValue> &Results) {
Expand Down Expand Up @@ -6412,6 +6471,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
switch (N->getOpcode()) {
default:
report_fatal_error("Unhandled custom legalization");
case ISD::BITCAST:
ReplaceBITCAST(N, DAG, Results);
return;
case ISD::LOAD:
ReplaceLoadVector(N, DAG, Results);
return;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,8 @@ class NVPTXTargetLowering : public TargetLowering {
const NVPTXSubtarget &STI; // cache the subtarget here
SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const;

SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
Expand Down
33 changes: 33 additions & 0 deletions llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx80 -asm-verbose=false \
; RUN: -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
; RUN: | FileCheck %s
; RUN: %if ptxas %{ \
; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -asm-verbose=false \
; RUN: -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
; RUN: | %ptxas-verify -arch=sm_90 \
; RUN: %}

target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"

; CHECK-LABEL: test_bitcast_2xi8_i16(
; CHECK: ld.param.u32 %r1, [test_bitcast_2xi8_i16_param_0];
; CHECK: mov.b32 {%rs1, %rs2}, %r1;
; CHECK: shl.b16 %rs3, %rs2, 8;
; CHECK: and.b16 %rs4, %rs1, 255;
; CHECK: or.b16 %rs5, %rs4, %rs3;
; CHECK: cvt.u32.u16 %r2, %rs5;
; CHECK: st.param.b32 [func_retval0], %r2;
define i16 @test_bitcast_2xi8_i16(<2 x i8> %a) {
%res = bitcast <2 x i8> %a to i16
ret i16 %res
}

; CHECK-LABEL: test_bitcast_i16_2xi8(
; CHECK: ld.param.u16 %rs1, [test_bitcast_i16_2xi8_param_0];
; CHECK: shr.u16 %rs2, %rs1, 8;
; CHECK: mov.b32 %r1, {%rs1, %rs2};
; CHECK: st.param.b32 [func_retval0], %r1;
define <2 x i8> @test_bitcast_i16_2xi8(i16 %a) {
%res = bitcast i16 %a to <2 x i8>
ret <2 x i8> %res
}

0 comments on commit b74e588

Please sign in to comment.