From 18a1e7ba9df032a95ce9602a6e7839a55564fe33 Mon Sep 17 00:00:00 2001 From: SpencerAbson Date: Sun, 8 Dec 2024 19:34:01 +0000 Subject: [PATCH] [AArch64] Implement intrinsics for SME FP8 F1CVT/F2CVT and BF1CVT/BF2CVT (#118027) This patch implements the following intrinsics: 8-bit floating-point convert to half-precision or BFloat16 (in-order). ``` c // Variant is also available for: _bf16[_mf8]_x2 svfloat16x2_t svcvt1_f16[_mf8]_x2_fpm(svmfloat8_t zn, fpm_t fpm) __arm_streaming; svfloat16x2_t svcvt2_f16[_mf8]_x2_fpm(svmfloat8_t zn, fpm_t fpm) __arm_streaming; ``` In accordance with https://github.com/ARM-software/acle/pull/323. Co-authored-by: Marin Lukac marian.lukac@arm.com Co-authored-by: Caroline Concatto caroline.concatto@arm.com --- clang/include/clang/Basic/arm_sve.td | 4 ++ .../fp8-intrinsics/acle_sme2_fp8_cvt.c | 64 +++++++++++++++++++ .../acle_sme2_fp8_cvt.c | 9 +++ llvm/include/llvm/IR/IntrinsicsAArch64.td | 32 +++++++--- .../Target/AArch64/AArch64ISelDAGToDAG.cpp | 12 ++++ .../AArch64/sme2-fp8-intrinsics-cvt.ll | 40 ++++++++++++ 6 files changed, 151 insertions(+), 10 deletions(-) diff --git a/clang/include/clang/Basic/arm_sve.td b/clang/include/clang/Basic/arm_sve.td index e551d6e46b8f33f..9b8a8546b072c01 100644 --- a/clang/include/clang/Basic/arm_sve.td +++ b/clang/include/clang/Basic/arm_sve.td @@ -2429,6 +2429,10 @@ let SVETargetGuard = InvalidMode, SMETargetGuard = "sme2,fp8" in { def FSCALE_X2 : Inst<"svscale[_{d}_x2]", "222.x", "fhd", MergeNone, "aarch64_sme_fp8_scale_x2", [IsStreaming],[]>; def FSCALE_X4 : Inst<"svscale[_{d}_x4]", "444.x", "fhd", MergeNone, "aarch64_sme_fp8_scale_x4", [IsStreaming],[]>; + // Convert from FP8 to half-precision/BFloat16 multi-vector + def SVF1CVT : Inst<"svcvt1_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvt1_x2", [IsStreaming, SetsFPMR], []>; + def SVF2CVT : Inst<"svcvt2_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvt2_x2", [IsStreaming, SetsFPMR], []>; + // Convert from FP8 to 
deinterleaved half-precision/BFloat16 multi-vector def SVF1CVTL : Inst<"svcvtl1_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvtl1_x2", [IsStreaming, SetsFPMR], []>; def SVF2CVTL : Inst<"svcvtl2_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvtl2_x2", [IsStreaming, SetsFPMR], []>; diff --git a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c index 5ba76671ff5d5b7..13609f034da336d 100644 --- a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c +++ b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c @@ -16,6 +16,70 @@ #define SVE_ACLE_FUNC(A1,A2,A3) A1##A2##A3 #endif +// CHECK-LABEL: @test_cvt1_f16_x2( +// CHECK-NEXT: entry: +// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]]) +// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]]) +// CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]] +// +// CPP-CHECK-LABEL: @_Z16test_cvt1_f16_x2u13__SVMfloat8_tm( +// CPP-CHECK-NEXT: entry: +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]]) +// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]]) +// CPP-CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]] +// +svfloat16x2_t test_cvt1_f16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming { + return SVE_ACLE_FUNC(svcvt1_f16,_mf8,_x2_fpm)(zn, fpmr); +} + +// CHECK-LABEL: @test_cvt2_f16_x2( +// CHECK-NEXT: entry: +// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]]) +// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]]) +// CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]] +// +// CPP-CHECK-LABEL: @_Z16test_cvt2_f16_x2u13__SVMfloat8_tm( +// CPP-CHECK-NEXT: entry: +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]]) +// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]]) +// CPP-CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]] +// 
+svfloat16x2_t test_cvt2_f16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming { + return SVE_ACLE_FUNC(svcvt2_f16,_mf8,_x2_fpm)(zn, fpmr); +} + +// CHECK-LABEL: @test_cvt1_bf16_x2( +// CHECK-NEXT: entry: +// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]]) +// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]]) +// CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]] +// +// CPP-CHECK-LABEL: @_Z17test_cvt1_bf16_x2u13__SVMfloat8_tm( +// CPP-CHECK-NEXT: entry: +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]]) +// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]]) +// CPP-CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]] +// +svbfloat16x2_t test_cvt1_bf16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming { + return SVE_ACLE_FUNC(svcvt1_bf16,_mf8,_x2_fpm)(zn, fpmr); +} + +// CHECK-LABEL: @test_cvt2_bf16_x2( +// CHECK-NEXT: entry: +// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]]) +// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]]) +// CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]] +// +// CPP-CHECK-LABEL: @_Z17test_cvt2_bf16_x2u13__SVMfloat8_tm( +// CPP-CHECK-NEXT: entry: +// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]]) +// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]]) +// CPP-CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]] +// +svbfloat16x2_t test_cvt2_bf16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming { + return SVE_ACLE_FUNC(svcvt2_bf16,_mf8,_x2_fpm)(zn, fpmr); +} + // CHECK-LABEL: @test_cvtl1_f16_x2( // CHECK-NEXT: entry: // CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]]) diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_cvt.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_cvt.c index 09a80c9dff03eae..af1ef46ea69722f 100644 --- a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_cvt.c +++ 
b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_cvt.c @@ -14,4 +14,13 @@ void test_features_sme2_fp8(svmfloat8_t zn, fpm_t fpmr) __arm_streaming { svcvtl1_bf16_mf8_x2_fpm(zn, fpmr); // expected-error@+1 {{'svcvtl2_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}} svcvtl2_bf16_mf8_x2_fpm(zn, fpmr); + + // expected-error@+1 {{'svcvt1_f16_mf8_x2_fpm' needs target feature sme,sme2,fp8}} + svcvt1_f16_mf8_x2_fpm(zn, fpmr); + // expected-error@+1 {{'svcvt2_f16_mf8_x2_fpm' needs target feature sme,sme2,fp8}} + svcvt2_f16_mf8_x2_fpm(zn, fpmr); + // expected-error@+1 {{'svcvt1_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}} + svcvt1_bf16_mf8_x2_fpm(zn, fpmr); + // expected-error@+1 {{'svcvt2_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}} + svcvt2_bf16_mf8_x2_fpm(zn, fpmr); } \ No newline at end of file diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td index a91616b9556828e..8ca00fc59a25541 100644 --- a/llvm/include/llvm/IR/IntrinsicsAArch64.td +++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td @@ -3812,16 +3812,6 @@ let TargetPrefix = "aarch64" in { [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>, LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>], [IntrNoMem]>; - - class SME2_FP8_CVT_X2_Single_Intrinsic - : DefaultAttrsIntrinsic<[llvm_anyvector_ty, LLVMMatchType<0>], - [llvm_nxv16i8_ty], - [IntrReadMem, IntrInaccessibleMemOnly]>; - // - // CVT from FP8 to deinterleaved half-precision/BFloat16 multi-vector - // - def int_aarch64_sve_fp8_cvtl1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic; - def int_aarch64_sve_fp8_cvtl2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic; } // SVE2.1 - ZIPQ1, ZIPQ2, UZPQ1, UZPQ2 @@ -3864,3 +3854,25 @@ def int_aarch64_sve_famin_u : AdvSIMD_Pred2VectorArg_Intrinsic; // Neon absolute maximum and minimum def int_aarch64_neon_famax : AdvSIMD_2VectorArg_Intrinsic; def int_aarch64_neon_famin : 
AdvSIMD_2VectorArg_Intrinsic; + +// +// FP8 Intrinsics +// +let TargetPrefix = "aarch64" in { + + class SME2_FP8_CVT_X2_Single_Intrinsic + : DefaultAttrsIntrinsic<[llvm_anyvector_ty, LLVMMatchType<0>], + [llvm_nxv16i8_ty], + [IntrReadMem, IntrInaccessibleMemOnly]>; + // + // CVT from FP8 to half-precision/BFloat16 multi-vector + // + def int_aarch64_sve_fp8_cvt1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic; + def int_aarch64_sve_fp8_cvt2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic; + + // + // CVT from FP8 to deinterleaved half-precision/BFloat16 multi-vector + // + def int_aarch64_sve_fp8_cvtl1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic; + def int_aarch64_sve_fp8_cvtl2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic; +} \ No newline at end of file diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp index 5f0c3d2c21f7913..5df61b372203739 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -5581,6 +5581,18 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { {AArch64::BF2CVTL_2ZZ_BtoH, AArch64::F2CVTL_2ZZ_BtoH})) SelectCVTIntrinsicFP8(Node, 2, Opc); return; + case Intrinsic::aarch64_sve_fp8_cvt1_x2: + if (auto Opc = SelectOpcodeFromVT( + Node->getValueType(0), + {AArch64::BF1CVT_2ZZ_BtoH, AArch64::F1CVT_2ZZ_BtoH})) + SelectCVTIntrinsicFP8(Node, 2, Opc); + return; + case Intrinsic::aarch64_sve_fp8_cvt2_x2: + if (auto Opc = SelectOpcodeFromVT( + Node->getValueType(0), + {AArch64::BF2CVT_2ZZ_BtoH, AArch64::F2CVT_2ZZ_BtoH})) + SelectCVTIntrinsicFP8(Node, 2, Opc); + return; } } break; case ISD::INTRINSIC_WO_CHAIN: { diff --git a/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll index 076a3ad34eac3c0..3d3fcb05f6cf072 100644 --- a/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll +++ b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll @@ -1,6 +1,46 @@ ; NOTE: Assertions have been autogenerated by 
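utils/update_llc_test_checks.py UTC_ARGS: --version 2 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2,+fp8 -verify-machineinstrs -force-streaming < %s | FileCheck %s +; F1CVT / F2CVT + +define { <vscale x 8 x half>, <vscale x 8 x half> } @f1cvt(<vscale x 16 x i8> %zm) { +; CHECK-LABEL: f1cvt: +; CHECK: // %bb.0: +; CHECK-NEXT: f1cvt { z0.h, z1.h }, z0.b +; CHECK-NEXT: ret + %res = call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> %zm) + ret { <vscale x 8 x half>, <vscale x 8 x half> } %res +} + +define { <vscale x 8 x half>, <vscale x 8 x half> } @f2cvt(<vscale x 16 x i8> %zm) { +; CHECK-LABEL: f2cvt: +; CHECK: // %bb.0: +; CHECK-NEXT: f2cvt { z0.h, z1.h }, z0.b +; CHECK-NEXT: ret + %res = call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> %zm) + ret { <vscale x 8 x half>, <vscale x 8 x half> } %res +} + +; BF1CVT / BF2CVT + +define { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @bf1cvt(<vscale x 16 x i8> %zm) { +; CHECK-LABEL: bf1cvt: +; CHECK: // %bb.0: +; CHECK-NEXT: bf1cvt { z0.h, z1.h }, z0.b +; CHECK-NEXT: ret + %res = call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> %zm) + ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %res +} + +define { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @bf2cvt(<vscale x 16 x i8> %zm) { +; CHECK-LABEL: bf2cvt: +; CHECK: // %bb.0: +; CHECK-NEXT: bf2cvt { z0.h, z1.h }, z0.b +; CHECK-NEXT: ret + %res = call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> %zm) + ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %res +} + ; F1CVTL / F2CVTL define { <vscale x 8 x half>, <vscale x 8 x half> } @f1cvtl(<vscale x 16 x i8> %zm) {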