Skip to content

Commit

Permalink
[CLANG]Add Neon vectors for mfloat8_t
Browse files Browse the repository at this point in the history
This patch adds these new vector sizes for neon:
mfloat8x16_t and mfloat8x8_t

According to the ARM ACLE PR#323[1].

[1] ARM-software/acle#323
  • Loading branch information
CarolineConcatto committed Jul 31, 2024
1 parent de397fb commit bfd90a7
Show file tree
Hide file tree
Showing 15 changed files with 240 additions and 17 deletions.
14 changes: 14 additions & 0 deletions clang/include/clang/Basic/arm_mfp8.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//===--- arm_mfp8.td - ARM MFP8 compiler interface ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the TableGen definitions from which the ARM MFP8 header
// file will be generated.
//
//===----------------------------------------------------------------------===//

include "arm_neon_incl.td"
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/arm_neon_incl.td
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def OP_UNAVAILABLE : Operation {
// h: half-float
// d: double
// b: bfloat16
// m: mfloat8
//
// Typespec modifiers
// ------------------
Expand All @@ -240,6 +241,7 @@ def OP_UNAVAILABLE : Operation {
// B: change to BFloat16
// P: change to polynomial category.
// p: change polynomial to equivalent integer category. Otherwise nop.
// M: change to MFloat8.
//
// >: double element width (vector size unchanged).
// <: half element width (vector size unchanged).
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Basic/Targets/AArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,9 @@ void AArch64TargetInfo::getTargetDefines(const LangOptions &Opts,
Builder.defineMacro("__ARM_FEATURE_BF16_SCALAR_ARITHMETIC", "1");
}

if (HasMFloat8) {
Builder.defineMacro("__ARM_FEATURE_FP8", "1");
}
if ((FPU & SveMode) && HasBFloat16) {
Builder.defineMacro("__ARM_FEATURE_SVE_BF16", "1");
}
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/Basic/Targets/ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,8 @@ bool ARMTargetInfo::hasBFloat16Type() const {
return HasBFloat16 || (FPU && !SoftFloat);
}

bool ARMTargetInfo::hasMFloat8Type() const { return true; }

bool ARMTargetInfo::isValidCPUName(StringRef Name) const {
return Name == "generic" ||
llvm::ARM::parseCPUArch(Name) != llvm::ARM::ArchKind::INVALID;
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/Basic/Targets/ARM.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ class LLVM_LIBRARY_VISIBILITY ARMTargetInfo : public TargetInfo {

bool hasBFloat16Type() const override;

bool hasMFloat8Type() const override;

bool isValidCPUName(StringRef Name) const override;
void fillValidCPUList(SmallVectorImpl<StringRef> &Values) const override;

Expand Down
2 changes: 2 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6230,6 +6230,8 @@ static llvm::FixedVectorType *GetNeonType(CodeGenFunction *CGF,
case NeonTypeFlags::Int8:
case NeonTypeFlags::Poly8:
return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
case NeonTypeFlags::MFloat8:
return llvm::FixedVectorType::get(CGF->MFloat8Ty, V1Ty ? 1 : (8 << IsQuad));
case NeonTypeFlags::Int16:
case NeonTypeFlags::Poly16:
return llvm::FixedVectorType::get(CGF->Int16Ty, V1Ty ? 1 : (4 << IsQuad));
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Headers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ if(ARM IN_LIST LLVM_TARGETS_TO_BUILD OR AArch64 IN_LIST LLVM_TARGETS_TO_BUILD)
clang_generate_header(-gen-arm-sme-header arm_sme.td arm_sme.h)
# Generate arm_bf16.h
clang_generate_header(-gen-arm-bf16 arm_bf16.td arm_bf16.h)
# Generate arm_mfp8.h
clang_generate_header(-gen-arm-mfp8 arm_mfp8.td arm_mfp8.h)
# Generate arm_mve.h
clang_generate_header(-gen-arm-mve-header arm_mve.td arm_mve.h)
# Generate arm_cde.h
Expand All @@ -414,6 +416,7 @@ if(ARM IN_LIST LLVM_TARGETS_TO_BUILD OR AArch64 IN_LIST LLVM_TARGETS_TO_BUILD)
"${CMAKE_CURRENT_BINARY_DIR}/arm_sme.h"
"${CMAKE_CURRENT_BINARY_DIR}/arm_bf16.h"
"${CMAKE_CURRENT_BINARY_DIR}/arm_vector_types.h"
"${CMAKE_CURRENT_BINARY_DIR}/arm_mfp8.h"
)
endif()
if(RISCV IN_LIST LLVM_TARGETS_TO_BUILD)
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/Sema/SemaARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,8 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
return Context.DoubleTy;
case NeonTypeFlags::BFloat16:
return Context.BFloat16Ty;
case NeonTypeFlags::MFloat8:
return Context.MFloat8Ty;
}
llvm_unreachable("Invalid NeonTypeFlag!");
}
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/Sema/SemaExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10215,6 +10215,11 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
const VectorType *RHSVecType = RHSType->getAs<VectorType>();
assert(LHSVecType || RHSVecType);

// Any operation with MFloat8 type is only possible with C intrinsics
if ((LHSVecType && LHSVecType->getElementType()->isMFloat8Type()) ||
(RHSVecType && RHSVecType->getElementType()->isMFloat8Type()))
return InvalidOperands(Loc, LHS, RHS);

// AltiVec-style "vector bool op vector bool" combinations are allowed
// for some operators but not others.
if (!AllowBothBool && LHSVecType &&
Expand Down
85 changes: 71 additions & 14 deletions clang/test/CodeGen/arm-mfp8.c
Original file line number Diff line number Diff line change
@@ -1,20 +1,34 @@
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 4
// RUN: %clang_cc1 -emit-llvm -triple aarch64-arm-none-eabi -target-feature -fp8 -o - %s | FileCheck %s
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
// RUN: %clang_cc1 -emit-llvm -triple aarch64-arm-none-eabi -target-feature -fp8 -target-feature +neon -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-C
// RUN: %clang_cc1 -emit-llvm -triple aarch64-arm-none-eabi -target-feature -fp8 -target-feature +neon -o - -x c++ %s | FileCheck %s --check-prefixes=CHECK,CHECK-CXX

// REQUIRES: aarch64-registered-target

// CHECK-LABEL: define dso_local i8 @func1n(
// CHECK-SAME: i8 noundef [[MFP8:%.*]]) #[[ATTR0:[0-9]+]] {
// CHECK-NEXT: entry:
// CHECK-NEXT: [[MFP8_ADDR:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[F1N:%.*]] = alloca [10 x i8], align 1
// CHECK-NEXT: store i8 [[MFP8]], ptr [[MFP8_ADDR]], align 1
// CHECK-NEXT: [[TMP0:%.*]] = load i8, ptr [[MFP8_ADDR]], align 1
// CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [10 x i8], ptr [[F1N]], i64 0, i64 2
// CHECK-NEXT: store i8 [[TMP0]], ptr [[ARRAYIDX]], align 1
// CHECK-NEXT: [[ARRAYIDX1:%.*]] = getelementptr inbounds [10 x i8], ptr [[F1N]], i64 0, i64 2
// CHECK-NEXT: [[TMP1:%.*]] = load i8, ptr [[ARRAYIDX1]], align 1
// CHECK-NEXT: ret i8 [[TMP1]]
// CHECK-C-LABEL: define dso_local i8 @func1n(
// CHECK-C-SAME: i8 noundef [[MFP8:%.*]]) #[[ATTR0:[0-9]+]] {
// CHECK-C-NEXT: [[ENTRY:.*:]]
// CHECK-C-NEXT: [[MFP8_ADDR:%.*]] = alloca i8, align 1
// CHECK-C-NEXT: [[F1N:%.*]] = alloca [10 x i8], align 1
// CHECK-C-NEXT: store i8 [[MFP8]], ptr [[MFP8_ADDR]], align 1
// CHECK-C-NEXT: [[TMP0:%.*]] = load i8, ptr [[MFP8_ADDR]], align 1
// CHECK-C-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [10 x i8], ptr [[F1N]], i64 0, i64 2
// CHECK-C-NEXT: store i8 [[TMP0]], ptr [[ARRAYIDX]], align 1
// CHECK-C-NEXT: [[ARRAYIDX1:%.*]] = getelementptr inbounds [10 x i8], ptr [[F1N]], i64 0, i64 2
// CHECK-C-NEXT: [[TMP1:%.*]] = load i8, ptr [[ARRAYIDX1]], align 1
// CHECK-C-NEXT: ret i8 [[TMP1]]
//
// CHECK-CXX-LABEL: define dso_local noundef i8 @_Z6func1nw(
// CHECK-CXX-SAME: i8 noundef [[MFP8:%.*]]) #[[ATTR0:[0-9]+]] {
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
// CHECK-CXX-NEXT: [[MFP8_ADDR:%.*]] = alloca i8, align 1
// CHECK-CXX-NEXT: [[F1N:%.*]] = alloca [10 x i8], align 1
// CHECK-CXX-NEXT: store i8 [[MFP8]], ptr [[MFP8_ADDR]], align 1
// CHECK-CXX-NEXT: [[TMP0:%.*]] = load i8, ptr [[MFP8_ADDR]], align 1
// CHECK-CXX-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [10 x i8], ptr [[F1N]], i64 0, i64 2
// CHECK-CXX-NEXT: store i8 [[TMP0]], ptr [[ARRAYIDX]], align 1
// CHECK-CXX-NEXT: [[ARRAYIDX1:%.*]] = getelementptr inbounds [10 x i8], ptr [[F1N]], i64 0, i64 2
// CHECK-CXX-NEXT: [[TMP1:%.*]] = load i8, ptr [[ARRAYIDX1]], align 1
// CHECK-CXX-NEXT: ret i8 [[TMP1]]
//
__mfp8 func1n(__mfp8 mfp8) {
__mfp8 f1n[10];
Expand All @@ -23,4 +37,47 @@ __mfp8 func1n(__mfp8 mfp8) {
}


#include <arm_neon.h>

// CHECK-C-LABEL: define dso_local <16 x i8> @test_ret_mfloat8x16_t(
// CHECK-C-SAME: <16 x i8> noundef [[V:%.*]]) #[[ATTR0]] {
// CHECK-C-NEXT: [[ENTRY:.*:]]
// CHECK-C-NEXT: [[V_ADDR:%.*]] = alloca <16 x i8>, align 16
// CHECK-C-NEXT: store <16 x i8> [[V]], ptr [[V_ADDR]], align 16
// CHECK-C-NEXT: [[TMP0:%.*]] = load <16 x i8>, ptr [[V_ADDR]], align 16
// CHECK-C-NEXT: ret <16 x i8> [[TMP0]]
//
// CHECK-CXX-LABEL: define dso_local noundef <16 x i8> @_Z21test_ret_mfloat8x16_t16__MFloat8_tx16_t(
// CHECK-CXX-SAME: <16 x i8> noundef [[V:%.*]]) #[[ATTR0]] {
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
// CHECK-CXX-NEXT: [[V_ADDR:%.*]] = alloca <16 x i8>, align 16
// CHECK-CXX-NEXT: store <16 x i8> [[V]], ptr [[V_ADDR]], align 16
// CHECK-CXX-NEXT: [[TMP0:%.*]] = load <16 x i8>, ptr [[V_ADDR]], align 16
// CHECK-CXX-NEXT: ret <16 x i8> [[TMP0]]
//
mfloat8x16_t test_ret_mfloat8x16_t(mfloat8x16_t v) {
return v;
}

// CHECK-C-LABEL: define dso_local <8 x i8> @test_ret_mfloat8x8_t(
// CHECK-C-SAME: <8 x i8> noundef [[V:%.*]]) #[[ATTR0]] {
// CHECK-C-NEXT: [[ENTRY:.*:]]
// CHECK-C-NEXT: [[V_ADDR:%.*]] = alloca <8 x i8>, align 8
// CHECK-C-NEXT: store <8 x i8> [[V]], ptr [[V_ADDR]], align 8
// CHECK-C-NEXT: [[TMP0:%.*]] = load <8 x i8>, ptr [[V_ADDR]], align 8
// CHECK-C-NEXT: ret <8 x i8> [[TMP0]]
//
// CHECK-CXX-LABEL: define dso_local noundef <8 x i8> @_Z20test_ret_mfloat8x8_t15__MFloat8_tx8_t(
// CHECK-CXX-SAME: <8 x i8> noundef [[V:%.*]]) #[[ATTR0]] {
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
// CHECK-CXX-NEXT: [[V_ADDR:%.*]] = alloca <8 x i8>, align 8
// CHECK-CXX-NEXT: store <8 x i8> [[V]], ptr [[V_ADDR]], align 8
// CHECK-CXX-NEXT: [[TMP0:%.*]] = load <8 x i8>, ptr [[V_ADDR]], align 8
// CHECK-CXX-NEXT: ret <8 x i8> [[TMP0]]
//
mfloat8x8_t test_ret_mfloat8x8_t(mfloat8x8_t v) {
return v;
}

//// NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
// CHECK: {{.*}}
53 changes: 53 additions & 0 deletions clang/test/Sema/arm-fpm8.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// RUN: %clang_cc1 -fsyntax-only -verify=scalar,neon -triple aarch64-arm-none-eabi \
// RUN: -target-feature -fp8 -target-feature +neon %s

// REQUIRES: aarch64-registered-target
__fpm8 test_static_cast_from_char(char in) {
return static_cast<__fpm8>(in); // scalar-error {{static_cast from 'char' to '__fpm8' is not allowed}}
}

char test_static_cast_to_char(__fpm8 in) {
return static_cast<char>(in); // scalar-error {{static_cast from '__fpm8' to 'char' is not allowed}}
}
void test(bool b) {
__fpm8 fpm8;

fpm8 + fpm8; // scalar-error {{invalid operands to binary expression ('__fpm8' and '__fpm8')}}
fpm8 - fpm8; // scalar-error {{invalid operands to binary expression ('__fpm8' and '__fpm8')}}
fpm8 * fpm8; // scalar-error {{invalid operands to binary expression ('__fpm8' and '__fpm8')}}
fpm8 / fpm8; // scalar-error {{invalid operands to binary expression ('__fpm8' and '__fpm8')}}
++fpm8; // scalar-error {{cannot increment value of type '__fpm8'}}
--fpm8; // scalar-error {{cannot decrement value of type '__fpm8'}}

char u8;

fpm8 + u8; // scalar-error {{invalid operands to binary expression ('__fpm8' and 'char')}}
u8 + fpm8; // scalar-error {{invalid operands to binary expression ('char' and '__fpm8')}}
fpm8 - u8; // scalar-error {{invalid operands to binary expression ('__fpm8' and 'char')}}
u8 - fpm8; // scalar-error {{invalid operands to binary expression ('char' and '__fpm8')}}
fpm8 * u8; // scalar-error {{invalid operands to binary expression ('__fpm8' and 'char')}}
u8 * fpm8; // scalar-error {{invalid operands to binary expression ('char' and '__fpm8')}}
fpm8 / u8; // scalar-error {{invalid operands to binary expression ('__fpm8' and 'char')}}
u8 / fpm8; // scalar-error {{invalid operands to binary expression ('char' and '__fpm8')}}
fpm8 = u8; // scalar-error {{assigning to '__fpm8' from incompatible type 'char'}}
u8 = fpm8; // scalar-error {{assigning to 'char' from incompatible type '__fpm8'}}
fpm8 + (b ? u8 : fpm8); // scalar-error {{incompatible operand types ('char' and '__fpm8')}}
}

#include <arm_neon.h>

void test_vector(fpm8x8_t a, fpm8x8_t b, uint8x8_t c) {
a + b; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'fpm8x8_t')}}
a - b; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'fpm8x8_t')}}
a * b; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'fpm8x8_t')}}
a / b; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'fpm8x8_t')}}

a + c; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
a - c; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
a * c; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
a / c; // neon-error {{invalid operands to binary expression ('fpm8x8_t' (vector of 8 'fpm8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
c + b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'fpm8x8_t' (vector of 8 'fpm8_t' values))}}
c - b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'fpm8x8_t' (vector of 8 'fpm8_t' values))}}
c * b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'fpm8x8_t' (vector of 8 'fpm8_t' values))}}
c / b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'fpm8x8_t' (vector of 8 'fpm8_t' values))}}
}
20 changes: 19 additions & 1 deletion clang/test/Sema/arm-mfp8.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: %clang_cc1 -fsyntax-only -verify=scalar -triple aarch64-arm-none-eabi -target-feature -fp8 %s
// RUN: %clang_cc1 -fsyntax-only -verify=scalar,neon -triple aarch64-arm-none-eabi \
// RUN: -target-feature -fp8 -target-feature +neon %s

// REQUIRES: aarch64-registered-target
__mfp8 test_static_cast_from_char(char in) {
Expand Down Expand Up @@ -33,3 +34,20 @@ void test(bool b) {
mfp8 + (b ? u8 : mfp8); // scalar-error {{incompatible operand types ('char' and '__mfp8')}}
}

#include <arm_neon.h>

void test_vector(mfloat8x8_t a, mfloat8x8_t b, uint8x8_t c) {
a + b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x8_t')}}
a - b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x8_t')}}
a * b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x8_t')}}
a / b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x8_t')}}

a + c; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
a - c; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
a * c; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
a / c; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
c + b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x8_t' (vector of 8 'mfloat8_t' values))}}
c - b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x8_t' (vector of 8 'mfloat8_t' values))}}
c * b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x8_t' (vector of 8 'mfloat8_t' values))}}
c / b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x8_t' (vector of 8 'mfloat8_t' values))}}
}
Loading

0 comments on commit bfd90a7

Please sign in to comment.