-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Add bfloat16 floating-point format support based on AMP #17265
Changes from all commits
350da26
b39d3e1
9c44b66
da7118c
c220bfc
f1055e5
b1f8b94
ac3edaa
16dd430
99e31e9
77a6a6f
22b70af
e9dd678
97372fd
c4aec63
7a7ab3a
e3fced6
9b68d07
89eaba1
2cb8969
db28e3a
e6d2b69
704ac96
f931fcb
1103bb7
98f8b07
bdb2483
b757cc4
b27dec5
f37b3fe
b3e5deb
d88bc3e
bf6c727
a8eab98
f46fcdc
e39e3a6
b554998
7db4f61
c04fe99
b3306c9
06a32fe
1b090b5
44f133b
04a1402
afe1c09
1027985
0d83536
592e67f
e934607
ce35638
4dfd91b
405e2aa
e0246ae
043d315
052bf79
1880d95
5edd735
3bad0fe
4a43b1d
9a37a0a
df090c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -277,6 +277,32 @@ extern "C" { | |
|
||
#include "./half.h" | ||
#include "./half2.h" | ||
#include "./bfloat.h" | ||
/*!
 * \brief Generates the binary operator OP for mixed half_t / bf16_t operands.
 *  Two overloads are emitted, one for each operand order. Each overload
 *  converts both operands to float, applies OP to the two floats and returns
 *  the outcome as RTYPE.
 * \param RTYPE return type of the generated overloads
 * \param OP operator token to overload
 */
#define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP)                                    \
  MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t lhs,                 \
                                    mshadow::bfloat::bf16_t rhs) {             \
    return static_cast<float>(lhs) OP static_cast<float>(rhs);                 \
  }                                                                            \
  MSHADOW_XINLINE RTYPE operator OP(mshadow::bfloat::bf16_t lhs,               \
                                    mshadow::half::half_t rhs) {               \
    return static_cast<float>(lhs) OP static_cast<float>(rhs);                 \
  }
|
||
/*! \brief overloaded + operator between half_t and bf16_t (both operand orders); evaluated in float, returns float */ | ||
MSHADOW_HALF_BF_OPERATOR(float, +) | ||
/*! \brief overloaded - operator between half_t and bf16_t (both operand orders); evaluated in float, returns float */ | ||
MSHADOW_HALF_BF_OPERATOR(float, -) | ||
/*! \brief overloaded * operator between half_t and bf16_t (both operand orders); evaluated in float, returns float */ | ||
MSHADOW_HALF_BF_OPERATOR(float, *) | ||
/*! \brief overloaded / operator between half_t and bf16_t (both operand orders); evaluated in float, returns float */ | ||
MSHADOW_HALF_BF_OPERATOR(float, /) | ||
/*! \brief overloaded > operator between half_t and bf16_t (both operand orders); compared as floats, returns bool */ | ||
MSHADOW_HALF_BF_OPERATOR(bool, >) | ||
/*! \brief overloaded < operator between half_t and bf16_t (both operand orders); compared as floats, returns bool */ | ||
MSHADOW_HALF_BF_OPERATOR(bool, <) | ||
/*! \brief overloaded >= operator between half_t and bf16_t (both operand orders); compared as floats, returns bool */ | ||
MSHADOW_HALF_BF_OPERATOR(bool, >=) | ||
/*! \brief overloaded <= operator between half_t and bf16_t (both operand orders); compared as floats, returns bool */ | ||
MSHADOW_HALF_BF_OPERATOR(bool, <=) | ||
|
||
#include "./logging.h" | ||
/*! \brief namespace for mshadow */ | ||
namespace mshadow { | ||
|
@@ -312,6 +338,11 @@ enum TypeFlag { | |
kInt8 = 5, | ||
kInt64 = 6, | ||
kBool = 7, | ||
kInt16 = 8, | ||
kUint16 = 9, | ||
kUint32 = 10, | ||
kUint64 = 11, | ||
kBfloat16 = 12 | ||
}; | ||
|
||
template<typename DType> | ||
|
@@ -365,6 +396,11 @@ struct DataType<half::half2_t> { | |
static const int kLanes = 2; | ||
}; | ||
template<> | ||
struct DataType<bfloat::bf16_t> { | ||
static const int kFlag = kBfloat16; | ||
static const int kLanes = 1; | ||
}; | ||
template<> | ||
struct DataType<uint8_t> { | ||
static const int kFlag = kUint8; | ||
static const int kLanes = 1; | ||
|
@@ -688,6 +724,11 @@ template<> | |
MSHADOW_XINLINE half::half_t MinValue<half::half_t>(void) { | ||
return MSHADOW_HALF_MIN; | ||
} | ||
/*! \brief minimum value of bf16 */ | ||
template<> | ||
MSHADOW_XINLINE bfloat::bf16_t MinValue<bfloat::bf16_t>(void) { | ||
return MSHADOW_BF16_MIN; | ||
} | ||
/*! \brief minimum value of uint8_t */ | ||
template<> | ||
MSHADOW_XINLINE uint8_t MinValue<uint8_t>(void) { | ||
|
@@ -765,6 +806,11 @@ template<> | |
MSHADOW_XINLINE half::half_t MaxValue<half::half_t>(void) { | ||
return MSHADOW_HALF_MAX; | ||
} | ||
/*! \brief maximum value of bf16 */ | ||
template<> | ||
MSHADOW_XINLINE bfloat::bf16_t MaxValue<bfloat::bf16_t>(void) { | ||
return MSHADOW_BF16_MAX; | ||
} | ||
/*! \brief maximum value of uint8_t */ | ||
template<> | ||
MSHADOW_XINLINE uint8_t MaxValue<uint8_t>(void) { | ||
|
@@ -998,6 +1044,7 @@ struct minimum { | |
}; | ||
} // namespace red | ||
|
||
#ifndef __NVCC__ | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't like it - can we do something similar to the fp16 case for CPUs that do not support F16C instructions (i.e. code that runs but may be slower than the code for hardware natively supporting bfloat16)? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You can implement There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We don't have enough background / knowledge to enable bfloat16 on the GPU side, so we probably can't make the change you proposed. Alternatively, any code refactoring on the GPU side is welcome; you may change this as you see fit in a follow-up PR. |
||
#define MSHADOW_TYPE_SWITCH(type, DType, ...) \ | ||
switch (type) { \ | ||
case mshadow::kFloat32: \ | ||
|
@@ -1018,6 +1065,12 @@ struct minimum { | |
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kBfloat16: \ | ||
{ \ | ||
typedef mshadow::bfloat::bf16_t DType; \ | ||
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kUint8: \ | ||
{ \ | ||
typedef uint8_t DType; \ | ||
|
@@ -1045,6 +1098,55 @@ struct minimum { | |
default: \ | ||
LOG(FATAL) << "Unknown type enum " << type; \ | ||
} | ||
#else | ||
/*!
 * \brief MSHADOW_TYPE_SWITCH variant defined in the #else arm of the
 *  enclosing `#ifndef __NVCC__` conditional.
 *  Switches on the runtime type flag; for every handled flag it typedefs
 *  DType to the matching C++ type and runs the statements given as
 *  __VA_ARGS__ inside their own scope.
 *  Handled flags: kFloat32, kFloat64, kFloat16, kUint8, kInt8, kInt32, kInt64.
 *  Any other value reaches the default branch and is reported through
 *  LOG(FATAL) as an unknown type enum.
 *  NOTE(review): this variant has no kBfloat16 case, so a bfloat16 flag also
 *  lands in the default branch.
 * \param type runtime type flag to switch on
 * \param DType name of the typedef made visible to the body
 * \param ... statements to execute with DType defined
 */
#define MSHADOW_TYPE_SWITCH(type, DType, ...) \ | ||
switch (type) { \ | ||
case mshadow::kFloat32: \ | ||
{ \ | ||
typedef float DType; \ | ||
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kFloat64: \ | ||
{ \ | ||
typedef double DType; \ | ||
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kFloat16: \ | ||
{ \ | ||
typedef mshadow::half::half_t DType; \ | ||
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kUint8: \ | ||
{ \ | ||
typedef uint8_t DType; \ | ||
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kInt8: \ | ||
{ \ | ||
typedef int8_t DType; \ | ||
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kInt32: \ | ||
{ \ | ||
typedef int32_t DType; \ | ||
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kInt64: \ | ||
{ \ | ||
typedef int64_t DType; \ | ||
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
default: \ | ||
LOG(FATAL) << "Unknown type enum " << type; \ | ||
} | ||
#endif | ||
|
||
#define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...) \ | ||
switch (type) { \ | ||
|
@@ -1147,6 +1249,7 @@ struct minimum { | |
LOG(FATAL) << "Unknown type enum " << type; \ | ||
} | ||
|
||
#ifndef __NVCC__ | ||
#define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \ | ||
switch (type$) { \ | ||
case mshadow::kFloat32: \ | ||
|
@@ -1170,6 +1273,13 @@ struct minimum { | |
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kBfloat16: \ | ||
{ \ | ||
typedef mshadow::bfloat::bf16_t DType$; \ | ||
typedef float DLargeType$; \ | ||
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kUint8: \ | ||
LOG(FATAL) << "This operation only support " \ | ||
"floating point types not uint8"; \ | ||
|
@@ -1189,7 +1299,50 @@ struct minimum { | |
default: \ | ||
LOG(FATAL) << "Unknown type enum " << type$; \ | ||
} | ||
|
||
#else | ||
/*!
 * \brief MSHADOW_REAL_TYPE_SWITCH_EX variant defined in the #else arm of the
 *  enclosing `#ifndef __NVCC__` conditional.
 *  Switches on the runtime type flag; for kFloat32, kFloat64 and kFloat16 it
 *  typedefs both DType$ and DLargeType$ and then runs the statements given as
 *  __VA_ARGS__ inside their own scope. DLargeType$ is float for kFloat32 and
 *  kFloat16, and double for kFloat64.
 *  kUint8, kInt8, kInt32 and kInt64 are rejected through LOG(FATAL) with a
 *  "floating point types" message; any other value is reported through
 *  LOG(FATAL) as an unknown type enum.
 *  NOTE(review): this variant has no kBfloat16 case, so a bfloat16 flag lands
 *  in the default branch.
 * \param type$ runtime type flag to switch on
 * \param DType$ name of the element-type typedef made visible to the body
 * \param DLargeType$ name of the second typedef made visible to the body
 * \param ... statements to execute with both typedefs defined
 */
#define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \ | ||
switch (type$) { \ | ||
case mshadow::kFloat32: \ | ||
{ \ | ||
typedef float DType$; \ | ||
typedef float DLargeType$; \ | ||
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kFloat64: \ | ||
{ \ | ||
typedef double DType$; \ | ||
typedef double DLargeType$; \ | ||
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kFloat16: \ | ||
{ \ | ||
typedef mshadow::half::half_t DType$; \ | ||
typedef float DLargeType$; \ | ||
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kUint8: \ | ||
LOG(FATAL) << "This operation only support " \ | ||
"floating point types not uint8"; \ | ||
break; \ | ||
case mshadow::kInt8: \ | ||
LOG(FATAL) << "This operation only support " \ | ||
"floating point types not int8"; \ | ||
break; \ | ||
case mshadow::kInt32: \ | ||
LOG(FATAL) << "This operation only support " \ | ||
"floating point types, not int32";\ | ||
break; \ | ||
case mshadow::kInt64: \ | ||
LOG(FATAL) << "This operation only support " \ | ||
"floating point types, not int64";\ | ||
break; \ | ||
default: \ | ||
LOG(FATAL) << "Unknown type enum " << type$; \ | ||
} | ||
#endif | ||
#define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \ | ||
switch (layout) { \ | ||
case mshadow::kNCHW: \ | ||
|
@@ -1256,6 +1409,12 @@ struct minimum { | |
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kBfloat16: \ | ||
{ \ | ||
typedef mshadow::bfloat::bf16_t DType; \ | ||
{__VA_ARGS__} \ | ||
} \ | ||
break; \ | ||
case mshadow::kUint8: \ | ||
{ \ | ||
typedef uint8_t DType; \ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why add those additional types here? No operator supports them anyway, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is to align the definition with DLPack. Otherwise we have to preserve those numbers. Even though we don't use them currently, there is no harm in adding them.