Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[APFloat] Add support for f8E4M3 IEEE 754 type #97179

Merged
merged 1 commit into from
Jul 18, 2024

Conversation

apivovarov
Copy link
Member

@apivovarov apivovarov commented Jun 29, 2024

This PR adds f8E4M3 type to APFloat.

f8E4M3 type follows IEEE 754 convention

f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 17 =6
- Precision specifies the total number of bits used for the significand (mantisa), 
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)

Related PRs:

  • PR-97118 [MLIR] Add f8E4M3 IEEE 754 type
  • jax-ml/ml_dtypes PR-161 Add float8_e4m3

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" llvm:support llvm:adt labels Jun 29, 2024
@llvmbot
Copy link
Member

llvmbot commented Jun 29, 2024

@llvm/pr-subscribers-llvm-adt
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-llvm-support

Author: Alexander Pivovarov (apivovarov)

Changes

This PR adds f8E4M3 type to llvm.

f8E4M3 type follows IEEE 754 convention

f8E4M3 (IEEE 754)
- Exponent: 4
- Mantissa: 3
- Exponent bias: 7
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)

Related PRs:

  • PR-97118 Add f8E4M3 IEEE 754 type to mlir

Full diff: https://github.com/llvm/llvm-project/pull/97179.diff

5 Files Affected:

  • (modified) clang/include/clang/AST/Stmt.h (+3-3)
  • (modified) clang/lib/AST/MicrosoftMangle.cpp (+1)
  • (modified) llvm/include/llvm/ADT/APFloat.h (+6)
  • (modified) llvm/lib/Support/APFloat.cpp (+20)
  • (modified) llvm/unittests/ADT/APFloatTest.cpp (+38)
diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h
index 9cd7a364cd3f1..4edb04502f83e 100644
--- a/clang/include/clang/AST/Stmt.h
+++ b/clang/include/clang/AST/Stmt.h
@@ -460,10 +460,10 @@ class alignas(void *) Stmt {
     unsigned : NumExprBits;
 
     static_assert(
-        llvm::APFloat::S_MaxSemantics < 16,
-        "Too many Semantics enum values to fit in bitfield of size 4");
+        llvm::APFloat::S_MaxSemantics < 32,
+        "Too many Semantics enum values to fit in bitfield of size 5");
     LLVM_PREFERRED_TYPE(llvm::APFloat::Semantics)
-    unsigned Semantics : 4; // Provides semantics for APFloat construction
+    unsigned Semantics : 5; // Provides semantics for APFloat construction
     LLVM_PREFERRED_TYPE(bool)
     unsigned IsExact : 1;
   };
diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp
index 7f1e9ab02ec26..a2e270df276cb 100644
--- a/clang/lib/AST/MicrosoftMangle.cpp
+++ b/clang/lib/AST/MicrosoftMangle.cpp
@@ -946,6 +946,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) {
   case APFloat::S_IEEEquad: Out << 'Y'; break;
   case APFloat::S_PPCDoubleDouble: Out << 'Z'; break;
   case APFloat::S_Float8E5M2:
+  case APFloat::S_Float8E4M3:
   case APFloat::S_Float8E4M3FN:
   case APFloat::S_Float8E5M2FNUZ:
   case APFloat::S_Float8E4M3FNUZ:
diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h
index db2fa480655c6..bff8e6490d1de 100644
--- a/llvm/include/llvm/ADT/APFloat.h
+++ b/llvm/include/llvm/ADT/APFloat.h
@@ -166,6 +166,9 @@ struct APFloatBase {
     // This format's exponent bias is 16, instead of the 15 (2 ** (5 - 1) - 1)
     // that IEEE precedent would imply.
     S_Float8E5M2FNUZ,
+    // 8-bit floating point number following IEEE-754 conventions with bit
+    // layout S1E4M3.
+    S_Float8E4M3,
     // 8-bit floating point number mostly following IEEE-754 conventions with
     // bit layout S1E4M3 as described in https://arxiv.org/abs/2209.05433.
     // Unlike IEEE-754 types, there are no infinity values, and NaN is
@@ -217,6 +220,7 @@ struct APFloatBase {
   static const fltSemantics &PPCDoubleDouble() LLVM_READNONE;
   static const fltSemantics &Float8E5M2() LLVM_READNONE;
   static const fltSemantics &Float8E5M2FNUZ() LLVM_READNONE;
+  static const fltSemantics &Float8E4M3() LLVM_READNONE;
   static const fltSemantics &Float8E4M3FN() LLVM_READNONE;
   static const fltSemantics &Float8E4M3FNUZ() LLVM_READNONE;
   static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
@@ -638,6 +642,7 @@ class IEEEFloat final : public APFloatBase {
   APInt convertPPCDoubleDoubleAPFloatToAPInt() const;
   APInt convertFloat8E5M2APFloatToAPInt() const;
   APInt convertFloat8E5M2FNUZAPFloatToAPInt() const;
+  APInt convertFloat8E4M3APFloatToAPInt() const;
   APInt convertFloat8E4M3FNAPFloatToAPInt() const;
   APInt convertFloat8E4M3FNUZAPFloatToAPInt() const;
   APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
@@ -656,6 +661,7 @@ class IEEEFloat final : public APFloatBase {
   void initFromPPCDoubleDoubleAPInt(const APInt &api);
   void initFromFloat8E5M2APInt(const APInt &api);
   void initFromFloat8E5M2FNUZAPInt(const APInt &api);
+  void initFromFloat8E4M3APInt(const APInt &api);
   void initFromFloat8E4M3FNAPInt(const APInt &api);
   void initFromFloat8E4M3FNUZAPInt(const APInt &api);
   void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);
diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp
index 3017a9b976658..79e0094b243b2 100644
--- a/llvm/lib/Support/APFloat.cpp
+++ b/llvm/lib/Support/APFloat.cpp
@@ -136,6 +136,7 @@ static constexpr fltSemantics semIEEEquad = {16383, -16382, 113, 128};
 static constexpr fltSemantics semFloat8E5M2 = {15, -14, 3, 8};
 static constexpr fltSemantics semFloat8E5M2FNUZ = {
     15, -15, 3, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
+static constexpr fltSemantics semFloat8E4M3 = {7, -6, 4, 8};
 static constexpr fltSemantics semFloat8E4M3FN = {
     8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
 static constexpr fltSemantics semFloat8E4M3FNUZ = {
@@ -208,6 +209,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
     return Float8E5M2();
   case S_Float8E5M2FNUZ:
     return Float8E5M2FNUZ();
+  case S_Float8E4M3:
+    return Float8E4M3();
   case S_Float8E4M3FN:
     return Float8E4M3FN();
   case S_Float8E4M3FNUZ:
@@ -246,6 +249,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
     return S_Float8E5M2;
   else if (&Sem == &llvm::APFloat::Float8E5M2FNUZ())
     return S_Float8E5M2FNUZ;
+  else if (&Sem == &llvm::APFloat::Float8E4M3())
+    return S_Float8E4M3;
   else if (&Sem == &llvm::APFloat::Float8E4M3FN())
     return S_Float8E4M3FN;
   else if (&Sem == &llvm::APFloat::Float8E4M3FNUZ())
@@ -276,6 +281,7 @@ const fltSemantics &APFloatBase::PPCDoubleDouble() {
 }
 const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; }
 const fltSemantics &APFloatBase::Float8E5M2FNUZ() { return semFloat8E5M2FNUZ; }
+const fltSemantics &APFloatBase::Float8E4M3() { return semFloat8E4M3; }
 const fltSemantics &APFloatBase::Float8E4M3FN() { return semFloat8E4M3FN; }
 const fltSemantics &APFloatBase::Float8E4M3FNUZ() { return semFloat8E4M3FNUZ; }
 const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
@@ -3617,6 +3623,11 @@ APInt IEEEFloat::convertFloat8E5M2FNUZAPFloatToAPInt() const {
   return convertIEEEFloatToAPInt<semFloat8E5M2FNUZ>();
 }
 
+APInt IEEEFloat::convertFloat8E4M3APFloatToAPInt() const {
+  assert(partCount() == 1);
+  return convertIEEEFloatToAPInt<semFloat8E4M3>();
+}
+
 APInt IEEEFloat::convertFloat8E4M3FNAPFloatToAPInt() const {
   assert(partCount() == 1);
   return convertIEEEFloatToAPInt<semFloat8E4M3FN>();
@@ -3681,6 +3692,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
   if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2FNUZ)
     return convertFloat8E5M2FNUZAPFloatToAPInt();
 
+  if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3)
+    return convertFloat8E4M3APFloatToAPInt();
+
   if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FN)
     return convertFloat8E4M3FNAPFloatToAPInt();
 
@@ -3902,6 +3916,10 @@ void IEEEFloat::initFromFloat8E5M2FNUZAPInt(const APInt &api) {
   initFromIEEEAPInt<semFloat8E5M2FNUZ>(api);
 }
 
+void IEEEFloat::initFromFloat8E4M3APInt(const APInt &api) {
+  initFromIEEEAPInt<semFloat8E4M3>(api);
+}
+
 void IEEEFloat::initFromFloat8E4M3FNAPInt(const APInt &api) {
   initFromIEEEAPInt<semFloat8E4M3FN>(api);
 }
@@ -3951,6 +3969,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
     return initFromFloat8E5M2APInt(api);
   if (Sem == &semFloat8E5M2FNUZ)
     return initFromFloat8E5M2FNUZAPInt(api);
+  if (Sem == &semFloat8E4M3)
+    return initFromFloat8E4M3APInt(api);
   if (Sem == &semFloat8E4M3FN)
     return initFromFloat8E4M3FNAPInt(api);
   if (Sem == &semFloat8E4M3FNUZ)
diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp
index 86a25f4394e19..132e95bd45f77 100644
--- a/llvm/unittests/ADT/APFloatTest.cpp
+++ b/llvm/unittests/ADT/APFloatTest.cpp
@@ -2133,6 +2133,8 @@ TEST(APFloatTest, getZero) {
       {&APFloat::Float8E5M2(), true, true, {0x80ULL, 0}, 1},
       {&APFloat::Float8E5M2FNUZ(), false, false, {0, 0}, 1},
       {&APFloat::Float8E5M2FNUZ(), true, false, {0, 0}, 1},
+      {&APFloat::Float8E4M3(), false, true, {0, 0}, 1},
+      {&APFloat::Float8E4M3(), true, true, {0x80ULL, 0}, 1},
       {&APFloat::Float8E4M3FN(), false, true, {0, 0}, 1},
       {&APFloat::Float8E4M3FN(), true, true, {0x80ULL, 0}, 1},
       {&APFloat::Float8E4M3FNUZ(), false, false, {0, 0}, 1},
@@ -6846,6 +6848,42 @@ TEST(APFloatTest, Float8E5M2ToFloat) {
   EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
 }
 
+TEST(APFloatTest, Float8E4M3ToFloat) {
+  APFloat PosZero = APFloat::getZero(APFloat::Float8E4M3());
+  APFloat PosZeroToFloat(PosZero.convertToFloat());
+  EXPECT_TRUE(PosZeroToFloat.isPosZero());
+  APFloat NegZero = APFloat::getZero(APFloat::Float8E4M3(), true);
+  APFloat NegZeroToFloat(NegZero.convertToFloat());
+  EXPECT_TRUE(NegZeroToFloat.isNegZero());
+
+  APFloat One(APFloat::Float8E4M3(), "1.0");
+  EXPECT_EQ(1.0F, One.convertToFloat());
+  APFloat Two(APFloat::Float8E4M3(), "2.0");
+  EXPECT_EQ(2.0F, Two.convertToFloat());
+
+  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3(), false);
+  EXPECT_EQ(240.0F, PosLargest.convertToFloat());
+  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3(), true);
+  EXPECT_EQ(-240.0F, NegLargest.convertToFloat());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E4M3(), false);
+  EXPECT_EQ(0x1.p-6, PosSmallest.convertToFloat());
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E4M3(), true);
+  EXPECT_EQ(-0x1.p-6, NegSmallest.convertToFloat());
+
+  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E4M3(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x1.p-9, SmallestDenorm.convertToFloat());
+
+  APFloat PosInf = APFloat::getInf(APFloat::Float8E4M3());
+  EXPECT_EQ(std::numeric_limits<float>::infinity(), PosInf.convertToFloat());
+  APFloat NegInf = APFloat::getInf(APFloat::Float8E4M3(), true);
+  EXPECT_EQ(-std::numeric_limits<float>::infinity(), NegInf.convertToFloat());
+  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3());
+  EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
+}
+
 TEST(APFloatTest, Float8E4M3FNToFloat) {
   APFloat PosZero = APFloat::getZero(APFloat::Float8E4M3FN());
   APFloat PosZeroToFloat(PosZero.convertToFloat());

@apivovarov
Copy link
Member Author

Thorsten, Matt, What you think about adding the following tests?

// E4M3 <-> E5M2
ConvertE4M3ToE5M2
ConvertE5M2ToE4M3

// E4M3 <-> E4M3FN
ConvertE4M3ToE4M3FN
ConvertE4M3FNToE4M3

@tschuett @arsenm

@arsenm
Copy link
Contributor

arsenm commented Jun 30, 2024

Title and description needs rewording. This isn't adding the type "to llvm" which would imply adding the IR type, but only to APFloat

@tschuett
Copy link

For reference the last type added:
#95392

@tschuett tschuett changed the title Add f8E4M3 IEEE 754 type to llvm [APFloat] Add support for f8E4M3 IEEE 754 type Jun 30, 2024
@tschuett tschuett self-requested a review June 30, 2024 07:24
@apivovarov apivovarov force-pushed the f8E4M3_apfloat_only branch from 639ca43 to 2c5a29e Compare June 30, 2024 07:42
@apivovarov
Copy link
Member Author

Changed commit message to "[APFloat] Add support for f8E4M3 IEEE 754 type" to match PR title. And Rebased

@apivovarov apivovarov force-pushed the f8E4M3_apfloat_only branch from 2c5a29e to 42cfec8 Compare July 1, 2024 19:58
@apivovarov
Copy link
Member Author

Thorsten, Matt, looks like this PR was lost. Can you look at this PR one more time? @tschuett @arsenm

@apivovarov apivovarov force-pushed the f8E4M3_apfloat_only branch from 42cfec8 to 9b2e06f Compare July 14, 2024 05:12
@apivovarov
Copy link
Member Author

apivovarov commented Jul 15, 2024

Hi @tschuett, @arsenm
I wanted to kindly request your review on this f8E4M3 PR when you have a moment. All CI checks have passed.

@apivovarov
Copy link
Member Author

@ThomasRaoux Could you please help with reviewing this PR? It seems that this PR was overlooked.

Copy link
Contributor

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me

@apivovarov
Copy link
Member Author

If there are no further questions, comments, or objections, I will merge this PR in a few hours.

@apivovarov apivovarov merged commit f363317 into llvm:main Jul 18, 2024
7 checks passed
Harini0924 pushed a commit to Harini0924/llvm-project that referenced this pull request Jul 22, 2024
This PR adds `f8E4M3` type to APFloat.

`f8E4M3` type  follows IEEE 754 convention

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa), 
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

Related PRs:
- [PR-97118](llvm#97118) Add f8E4M3
IEEE 754 type to mlir
apivovarov added a commit that referenced this pull request Jul 23, 2024
This PR adds `f8E4M3` type to mlir.

`f8E4M3` type  follows IEEE 754 convention

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa), 
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

Related PRs:
- [PR-97179](#97179) [APFloat]
Add support for f8E4M3 IEEE 754 type
sgundapa pushed a commit to sgundapa/upstream_effort that referenced this pull request Jul 23, 2024
This PR adds `f8E4M3` type to APFloat.

`f8E4M3` type  follows IEEE 754 convention

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa), 
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

Related PRs:
- [PR-97118](llvm#97118) Add f8E4M3
IEEE 754 type to mlir
sgundapa pushed a commit to sgundapa/upstream_effort that referenced this pull request Jul 23, 2024
This PR adds `f8E4M3` type to mlir.

`f8E4M3` type  follows IEEE 754 convention

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa), 
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

Related PRs:
- [PR-97179](llvm#97179) [APFloat]
Add support for f8E4M3 IEEE 754 type
banach-space pushed a commit to banach-space/llvm-project that referenced this pull request Aug 7, 2024
This PR adds `f8E4M3` type to APFloat.

`f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa), 
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

Related PRs:
- [PR-97179](llvm#97179) [APFloat]
Add support for f8E4M3 IEEE 754 type
GleasonK pushed a commit to openxla/stablehlo that referenced this pull request Sep 3, 2024
### Summary
This is a proposal to add `Float8E4M3` and `Float8E3M4` floating point
types to StableHLO.
Feedback welcome, see [RFC: Float8E4M3 and
Float8E3M4](https://github.com/apivovarov/stablehlo/blob/rfc_f8E4M3_f8E3M4/rfcs/20240808-f8E4M3_f8E3M4.md)
for more details.

### References and Links
- LLVM [PR-97179](llvm/llvm-project#97179)
[APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118)
[MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698)
[APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230)
[MLIR] Add f8E3M4 IEEE 754 type (Merged)
- [RFC: FP8 in
StableHLO](https://github.com/openxla/stablehlo/blob/main/rfcs/20221031-fp8.md)
- [RFC: Float8E4M3FNUZ and
Float8E5M2FNUZ](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md)
- StableHLO [PR-2482](#2482)
Add f8E4M3 and f8E3M4 types support
- [Amazon EC2 Trn1
Instances](https://aws.amazon.com/ec2/instance-types/trn1/)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add
float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add
float8_e3m4 (Merged)
- XLA [PR-16585](openxla/xla#16585) Add support
for float8_e4m3
GleasonK pushed a commit to openxla/stablehlo that referenced this pull request Sep 4, 2024
This PR adds f8E4M3 and f8E3M4 types support.

f8E4M3 and f8E3M4 types follow IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa), 
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa), 
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179)
[APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118)
[MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698)
[APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230)
[MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](#2486)
[RFC] Add f8E4M3 and f8E3M4 types support
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add
float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add
float8_e3m4 (Merged)
- XLA [PR-16585](openxla/xla#16585) Add support
for float8_e4m3
copybara-service bot pushed a commit to google/tsl that referenced this pull request Sep 30, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Sep 30, 2024
Imported from GitHub PR #16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Sep 30, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to google/tsl that referenced this pull request Oct 1, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Oct 1, 2024
Imported from GitHub PR #16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 1, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Oct 1, 2024
Imported from GitHub PR #16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 1, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Oct 1, 2024
Imported from GitHub PR #16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 1, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Oct 2, 2024
Imported from GitHub PR #16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 2, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to google/tsl that referenced this pull request Oct 2, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Oct 2, 2024
Imported from GitHub PR #16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 2, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 680651037
copybara-service bot pushed a commit to google/tsl that referenced this pull request Oct 2, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

PiperOrigin-RevId: 681551979
copybara-service bot pushed a commit to tensorflow/mlir-hlo that referenced this pull request Oct 2, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

PiperOrigin-RevId: 681551979
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Oct 2, 2024
Imported from GitHub PR #16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723
PiperOrigin-RevId: 681551979
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 2, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

PiperOrigin-RevId: 681551979
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category llvm:adt llvm:support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants