Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVPTX] Don't use stack memory when bitcasting to/from v2i8 #113928

Merged
merged 6 commits into from
Nov 1, 2024

Conversation

peterbell10
Copy link
Contributor

@peterbell10 peterbell10 commented Oct 28, 2024

v2i8 is an unsupported type, so we hit the default legalization rules which perform the bitcast in stack memory and is very inefficient on GPU.

This adds a custom lowering where we pack v2i8 into i16 and from there use another bitcast node to reach the final desired type. And also the inverse unpacking i16 into v2i8.

`v2i8` is and unsupported type, so we hit the default legalization rules
which perform the bitcast in stack memory and is very inefficient on GPU.

This adds a custom lowering where we pack `v2i8` into `i16` and from there use
another bitcast node to reach the final desired type. And also the
inverse unpacking `i16` into `v2i8`.
@peterbell10 peterbell10 changed the title [NVPTX] Don't use stack memory when bitcasting to/from 2xi8 [NVPTX] Don't use stack memory when bitcasting to/from v2i8 Oct 28, 2024
Copy link

github-actions bot commented Oct 28, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@peterbell10 peterbell10 marked this pull request as ready for review October 29, 2024 00:00
@llvmbot
Copy link
Member

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-backend-nvptx

Author: None (peterbell10)

Changes

v2i8 is an unsupported type, so we hit the default legalization rules which perform the bitcast in stack memory and is very inefficient on GPU.

This adds a custom lowering where we pack v2i8 into i16 and from there use another bitcast node to reach the final desired type. And also the inverse unpacking i16 into v2i8.


Full diff: https://github.com/llvm/llvm-project/pull/113928.diff

3 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+50)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+2)
  • (added) llvm/test/CodeGen/NVPTX/i8x2-instructions.ll (+36)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index a95cba586b8fc3..050fbcfbcd8165 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -551,6 +551,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
+
+  // Custom conversions to/from v2i8.
+  setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
+
   // Only logical ops can be done on v4i8 directly, others must be done
   // elementwise.
   setOperationAction(
@@ -2311,6 +2315,47 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
+  // Handle bitcasting to/from v2i8 without hitting the default promotion
+  // strategy which goes through stack memory.
+  SDNode *Node = Op.getNode();
+  SDLoc dl(Node);
+
+  auto maybeBitcast = [&](EVT vt, SDValue val) {
+    if (val->getValueType(0) == vt) {
+      return val;
+    }
+    return DAG.getNode(ISD::BITCAST, dl, vt, val);
+  };
+
+  EVT VT = Op->getValueType(0);
+  EVT fromVT = Op->getOperand(0)->getValueType(0);
+
+  if (VT == MVT::v2i8) {
+    // Bitcast to i16 and unpack elements into a vector
+    SDValue reg = maybeBitcast(MVT::i16, Op->getOperand(0));
+    SDValue v0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, reg);
+    SDValue C8 = DAG.getConstant(8, dl, MVT::i16);
+    SDValue v1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8,
+                             DAG.getNode(ISD::SRL, dl, MVT::i16, {reg, C8}));
+    return DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v2i8, {v0, v1});
+  } else if (fromVT == MVT::v2i8) {
+    // Pack vector elements into i16 and bitcast to final type
+    SDValue v0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8,
+                             Op->getOperand(0), DAG.getIntPtrConstant(0, dl));
+    SDValue v1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8,
+                             Op->getOperand(0), DAG.getIntPtrConstant(1, dl));
+    SDValue E0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, v0);
+    SDValue E1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, v1);
+    SDValue C8 = DAG.getConstant(8, dl, MVT::i16);
+    SDValue reg =
+        DAG.getNode(ISD::OR, dl, MVT::i16,
+                    {E0, DAG.getNode(ISD::SHL, dl, MVT::i16, {E1, C8})});
+    return maybeBitcast(VT, reg);
+  }
+  return Op;
+}
+
 // We can init constant f16x2/v2i16/v4i8 with a single .b32 move.  Normally it
 // would get lowered as two constant loads and vector-packing move.
 // Instead we want just a constant move:
@@ -2818,6 +2863,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return Op;
   case ISD::BUILD_VECTOR:
     return LowerBUILD_VECTOR(Op, DAG);
+  case ISD::BITCAST:
+    return LowerBITCAST(Op, DAG);
   case ISD::EXTRACT_SUBVECTOR:
     return Op;
   case ISD::EXTRACT_VECTOR_ELT:
@@ -6413,6 +6460,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
   switch (N->getOpcode()) {
   default:
     report_fatal_error("Unhandled custom legalization");
+  case ISD::BITCAST:
+    Results.push_back(LowerBITCAST(SDValue(N, 0), DAG));
+    return;
   case ISD::LOAD:
     ReplaceLoadVector(N, DAG, Results);
     return;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 824a659671967a..13153f4830b695 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -616,6 +616,8 @@ class NVPTXTargetLowering : public TargetLowering {
   const NVPTXSubtarget &STI; // cache the subtarget here
   SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const;
 
+  SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
new file mode 100644
index 00000000000000..2f5d8cfed2b7b7
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -0,0 +1,36 @@
+; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx80 -asm-verbose=false \
+; RUN:          -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
+; RUN: | FileCheck  %s
+; RUN: %if ptxas %{                                                           \
+; RUN:   llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -asm-verbose=false \
+; RUN:          -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
+; RUN:   | %ptxas-verify -arch=sm_90                                          \
+; RUN: %}
+
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+
+; CHECK-LABEL: test_trunc_2xi8(
+; CHECK:      ld.param.u32 [[R1:%r[0-9]+]], [test_trunc_2xi8_param_0];
+; CHECK:      mov.b32 {[[RS1:%rs[0-9]+]], [[RS2:%rs[0-9]+]]}, [[R1]];
+; CHECK:      shl.b16 	[[RS3:%rs[0-9]+]], [[RS2]], 8;
+; CHECK:      and.b16  [[RS4:%rs[0-9]+]], [[RS1]], 255;
+; CHECK:      or.b16   [[RS5:%rs[0-9]+]], [[RS4]], [[RS3]]
+; CHECK:      cvt.u32.u16  [[R2:%r[0-9]]], [[RS5]]
+; CHECK:      st.param.b32  [func_retval0], [[R2]];
+define i16 @test_trunc_2xi8(<2 x i16> %a) #0 {
+  %trunc = trunc <2 x i16> %a to <2 x i8>
+  %res = bitcast <2 x i8> %trunc to i16
+  ret i16 %res
+}
+
+; CHECK-LABEL: test_zext_2xi8(
+; CHECK:      ld.param.u16  [[RS1:%rs[0-9]+]], [test_zext_2xi8_param_0];
+; CHECK:      shr.u16 	[[RS2:%rs[0-9]+]], [[RS1]], 8;
+; CHECK:      mov.b32  [[R1:%r[0-9]+]], {[[RS1]], [[RS2]]}
+; CHECK:      and.b32  [[R2:%r[0-9]+]], [[R1]], 16711935;
+; CHECK:      st.param.b32  [func_retval0], [[R2]];
+define <2 x i16> @test_zext_2xi8(i16 %a) #0 {
+  %vec = bitcast i16 %a to <2 x i8>
+  %ext = zext <2 x i8> %vec to <2 x i16>
+  ret <2 x i16> %ext
+}

Copy link
Contributor

@justinfargnoli justinfargnoli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, thanks for this PR. The current lowering is definitely not what we want!


we hit the default legalization rules which perform the bitcast in stack memory

Do you know where this default legalization rule is implemented?


Note: mov vector-to-scalar (pack) or scalar-to-vector (unpack) doesn't support .b8.

@@ -0,0 +1,36 @@
; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx80 -asm-verbose=false \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're already checking for most of the PTX that are generated for each function, I'd recommend auto-generating the CHECK statements.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't used that before but I gave it a shot and it didn't generate any checks at all for some reason, perhaps I was doing something wrong. Not sure.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to delete all the existing CHECK statements first.

If that doesn't work (and the rest of the MR looks good), I'll just submit a follow-up patch that auto-generates the test.

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp Show resolved Hide resolved
llvm/test/CodeGen/NVPTX/i8x2-instructions.ll Outdated Show resolved Hide resolved
@peterbell10
Copy link
Contributor Author

we hit the default legalization rules which perform the bitcast in stack memory

Do you know where this default legalization rule is implemented?

For bitcast to v2i8 we hit

return DAG.getNode(ISD::ANY_EXTEND, dl, NOutVT,
CreateStackStoreLoad(InOp, OutVT));

And for bitcasts from v2i8 we hit

return CreateStackStoreLoad(N->getOperand(0), N->getValueType(0));

Copy link
Contributor

@justinfargnoli justinfargnoli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, pending the resolution of all conversations!

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp Outdated Show resolved Hide resolved
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp Outdated Show resolved Hide resolved
@ThomasRaoux ThomasRaoux merged commit b74e588 into llvm:main Nov 1, 2024
8 checks passed
@peterbell10 peterbell10 deleted the nvptx-v2i8-no-stack branch November 1, 2024 15:03
peterbell10 added a commit to triton-lang/triton that referenced this pull request Nov 1, 2024
Update commit to include llvm/llvm-project#113928
ThomasRaoux pushed a commit to triton-lang/triton that referenced this pull request Nov 1, 2024
smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
)

`v2i8` is an unsupported type, so we hit the default legalization rules
which perform the bitcast in stack memory and is very inefficient on
GPU.

This adds a custom lowering where we pack `v2i8` into `i16` and from
there use another bitcast node to reach the final desired type. And also
the inverse unpacking `i16` into `v2i8`.
smallp-o-p pushed a commit to smallp-o-p/llvm-project that referenced this pull request Nov 3, 2024
)

`v2i8` is an unsupported type, so we hit the default legalization rules
which perform the bitcast in stack memory and is very inefficient on
GPU.

This adds a custom lowering where we pack `v2i8` into `i16` and from
there use another bitcast node to reach the final desired type. And also
the inverse unpacking `i16` into `v2i8`.
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
)

`v2i8` is an unsupported type, so we hit the default legalization rules
which perform the bitcast in stack memory and is very inefficient on
GPU.

This adds a custom lowering where we pack `v2i8` into `i16` and from
there use another bitcast node to reach the final desired type. And also
the inverse unpacking `i16` into `v2i8`.
chsigg pushed a commit to openxla/triton that referenced this pull request Nov 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants