Skip to content

Commit

Permalink
[AArch64] Implement intrinsics for SME FP8 F1CVT/F2CVT and BF1CVT/BF2…
Browse files Browse the repository at this point in the history
…CVT (#118027)

This patch implements the following intrinsics:

8-bit floating-point convert to half-precision or BFloat16 (in-order).
``` c
  // Variant is also available for: _bf16[_mf8]_x2
  svfloat16x2_t svcvt1_f16[_mf8]_x2_fpm(svmfloat8_t zn, fpm_t fpm) __arm_streaming;
  svfloat16x2_t svcvt2_f16[_mf8]_x2_fpm(svmfloat8_t zn, fpm_t fpm) __arm_streaming;
```

In accordance with ARM-software/acle#323.

Co-authored-by: Marin Lukac [email protected]
Co-authored-by: Caroline Concatto [email protected]
  • Loading branch information
SpencerAbson authored Dec 8, 2024
1 parent 2ab687e commit b0f0676
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 10 deletions.
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/arm_sve.td
Original file line number Diff line number Diff line change
Expand Up @@ -2429,6 +2429,10 @@ let SVETargetGuard = InvalidMode, SMETargetGuard = "sme2,fp8" in {
def FSCALE_X2 : Inst<"svscale[_{d}_x2]", "222.x", "fhd", MergeNone, "aarch64_sme_fp8_scale_x2", [IsStreaming],[]>;
def FSCALE_X4 : Inst<"svscale[_{d}_x4]", "444.x", "fhd", MergeNone, "aarch64_sme_fp8_scale_x4", [IsStreaming],[]>;

// Convert from FP8 to half-precision/BFloat16 multi-vector
def SVF1CVT : Inst<"svcvt1_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvt1_x2", [IsStreaming, SetsFPMR], []>;
def SVF2CVT : Inst<"svcvt2_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvt2_x2", [IsStreaming, SetsFPMR], []>;

// Convert from FP8 to deinterleaved half-precision/BFloat16 multi-vector
def SVF1CVTL : Inst<"svcvtl1_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvtl1_x2", [IsStreaming, SetsFPMR], []>;
def SVF2CVTL : Inst<"svcvtl2_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvtl2_x2", [IsStreaming, SetsFPMR], []>;
Expand Down
64 changes: 64 additions & 0 deletions clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,70 @@
#define SVE_ACLE_FUNC(A1,A2,A3) A1##A2##A3
#endif

// CHECK-LABEL: @test_cvt1_f16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
// CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z16test_cvt1_f16_x2u13__SVMfloat8_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
// CPP-CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
//
svfloat16x2_t test_cvt1_f16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
return SVE_ACLE_FUNC(svcvt1_f16,_mf8,_x2_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvt2_f16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
// CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z16test_cvt2_f16_x2u13__SVMfloat8_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
// CPP-CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
//
svfloat16x2_t test_cvt2_f16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
return SVE_ACLE_FUNC(svcvt2_f16,_mf8,_x2_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvt1_bf16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
// CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z17test_cvt1_bf16_x2u13__SVMfloat8_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
// CPP-CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
//
svbfloat16x2_t test_cvt1_bf16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
return SVE_ACLE_FUNC(svcvt1_bf16,_mf8,_x2_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvt2_bf16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
// CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
//
// CPP-CHECK-LABEL: @_Z17test_cvt2_bf16_x2u13__SVMfloat8_tm(
// CPP-CHECK-NEXT: entry:
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
// CPP-CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
//
svbfloat16x2_t test_cvt2_bf16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
return SVE_ACLE_FUNC(svcvt2_bf16,_mf8,_x2_fpm)(zn, fpmr);
}

// CHECK-LABEL: @test_cvtl1_f16_x2(
// CHECK-NEXT: entry:
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
Expand Down
9 changes: 9 additions & 0 deletions clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_cvt.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,13 @@ void test_features_sme2_fp8(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
svcvtl1_bf16_mf8_x2_fpm(zn, fpmr);
// expected-error@+1 {{'svcvtl2_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
svcvtl2_bf16_mf8_x2_fpm(zn, fpmr);

// expected-error@+1 {{'svcvt1_f16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
svcvt1_f16_mf8_x2_fpm(zn, fpmr);
// expected-error@+1 {{'svcvt2_f16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
svcvt2_f16_mf8_x2_fpm(zn, fpmr);
// expected-error@+1 {{'svcvt1_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
svcvt1_bf16_mf8_x2_fpm(zn, fpmr);
// expected-error@+1 {{'svcvt2_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
svcvt2_bf16_mf8_x2_fpm(zn, fpmr);
}
32 changes: 22 additions & 10 deletions llvm/include/llvm/IR/IntrinsicsAArch64.td
Original file line number Diff line number Diff line change
Expand Up @@ -3812,16 +3812,6 @@ let TargetPrefix = "aarch64" in {
[LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>,
LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>],
[IntrNoMem]>;

class SME2_FP8_CVT_X2_Single_Intrinsic
: DefaultAttrsIntrinsic<[llvm_anyvector_ty, LLVMMatchType<0>],
[llvm_nxv16i8_ty],
[IntrReadMem, IntrInaccessibleMemOnly]>;
//
// CVT from FP8 to deinterleaved half-precision/BFloat16 multi-vector
//
def int_aarch64_sve_fp8_cvtl1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
def int_aarch64_sve_fp8_cvtl2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
}

// SVE2.1 - ZIPQ1, ZIPQ2, UZPQ1, UZPQ2
Expand Down Expand Up @@ -3864,3 +3854,25 @@ def int_aarch64_sve_famin_u : AdvSIMD_Pred2VectorArg_Intrinsic;
// Neon absolute maximum and minimum
def int_aarch64_neon_famax : AdvSIMD_2VectorArg_Intrinsic;
def int_aarch64_neon_famin : AdvSIMD_2VectorArg_Intrinsic;

//
// FP8 Intrinsics
//
let TargetPrefix = "aarch64" in {

class SME2_FP8_CVT_X2_Single_Intrinsic
: DefaultAttrsIntrinsic<[llvm_anyvector_ty, LLVMMatchType<0>],
[llvm_nxv16i8_ty],
[IntrReadMem, IntrInaccessibleMemOnly]>;
//
// CVT from FP8 to half-precision/BFloat16 multi-vector
//
def int_aarch64_sve_fp8_cvt1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
def int_aarch64_sve_fp8_cvt2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;

//
// CVT from FP8 to deinterleaved half-precision/BFloat16 multi-vector
//
def int_aarch64_sve_fp8_cvtl1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
def int_aarch64_sve_fp8_cvtl2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
}
12 changes: 12 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5581,6 +5581,18 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
{AArch64::BF2CVTL_2ZZ_BtoH, AArch64::F2CVTL_2ZZ_BtoH}))
SelectCVTIntrinsicFP8(Node, 2, Opc);
return;
case Intrinsic::aarch64_sve_fp8_cvt1_x2:
if (auto Opc = SelectOpcodeFromVT<SelectTypeKind::FP>(
Node->getValueType(0),
{AArch64::BF1CVT_2ZZ_BtoH, AArch64::F1CVT_2ZZ_BtoH}))
SelectCVTIntrinsicFP8(Node, 2, Opc);
return;
case Intrinsic::aarch64_sve_fp8_cvt2_x2:
if (auto Opc = SelectOpcodeFromVT<SelectTypeKind::FP>(
Node->getValueType(0),
{AArch64::BF2CVT_2ZZ_BtoH, AArch64::F2CVT_2ZZ_BtoH}))
SelectCVTIntrinsicFP8(Node, 2, Opc);
return;
}
} break;
case ISD::INTRINSIC_WO_CHAIN: {
Expand Down
40 changes: 40 additions & 0 deletions llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll
Original file line number Diff line number Diff line change
@@ -1,6 +1,46 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2,+fp8 -verify-machineinstrs -force-streaming < %s | FileCheck %s

; F1CVT / F2CVT

define { <vscale x 8 x half>, <vscale x 8 x half> } @f1cvt(<vscale x 16 x i8> %zm) {
; CHECK-LABEL: f1cvt:
; CHECK: // %bb.0:
; CHECK-NEXT: f1cvt { z0.h, z1.h }, z0.b
; CHECK-NEXT: ret
%res = call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> %zm)
ret { <vscale x 8 x half>, <vscale x 8 x half> } %res
}

define { <vscale x 8 x half>, <vscale x 8 x half> } @f2cvt(<vscale x 16 x i8> %zm) {
; CHECK-LABEL: f2cvt:
; CHECK: // %bb.0:
; CHECK-NEXT: f2cvt { z0.h, z1.h }, z0.b
; CHECK-NEXT: ret
%res = call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> %zm)
ret { <vscale x 8 x half>, <vscale x 8 x half> } %res
}

; BF1CVT / BF2CVT

define { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @bf1cvt(<vscale x 16 x i8> %zm) {
; CHECK-LABEL: bf1cvt:
; CHECK: // %bb.0:
; CHECK-NEXT: bf1cvt { z0.h, z1.h }, z0.b
; CHECK-NEXT: ret
%res = call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> %zm)
ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %res
}

define { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @bf2cvt(<vscale x 16 x i8> %zm) {
; CHECK-LABEL: bf2cvt:
; CHECK: // %bb.0:
; CHECK-NEXT: bf2cvt { z0.h, z1.h }, z0.b
; CHECK-NEXT: ret
%res = call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> %zm)
ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %res
}

; F1CVTL / F2CVTL

define { <vscale x 8 x half>, <vscale x 8 x half> } @f1cvtl(<vscale x 16 x i8> %zm) {
Expand Down

0 comments on commit b0f0676

Please sign in to comment.