Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add bfloat16 floating-point format support based on AMP #17265

Merged
merged 61 commits into from
Feb 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
350da26
Add Bfloat16
ZhennanQin May 20, 2019
b39d3e1
mshadow support bf16
ElaineBao Oct 31, 2019
9c44b66
rebase bf16 mkldnn1.0
ElaineBao Nov 5, 2019
da7118c
support bf16 gemm
xinyu-intel Nov 12, 2019
c220bfc
resolve fp32 ip bwd bug
xinyu-intel Nov 18, 2019
f1055e5
add other bf16 ops
ElaineBao Nov 20, 2019
b1f8b94
change func name from fp16 to lp16 (low precision 16), to include bf16
ElaineBao Nov 20, 2019
ac3edaa
add amp_cast bf16 support for ndarray
ElaineBao Nov 21, 2019
16dd430
fix executor copy_params
ElaineBao Nov 27, 2019
99e31e9
add test case for bf16
ElaineBao Nov 27, 2019
77a6a6f
remove numpy dtype hook for bf16
ElaineBao Nov 28, 2019
22b70af
add bf16 type support
ElaineBao Dec 9, 2019
e9dd678
rebase to mxnet master
ElaineBao Dec 9, 2019
97372fd
add single conv test
ElaineBao Dec 10, 2019
c4aec63
fix symbolic inference
ElaineBao Dec 10, 2019
7a7ab3a
add dtype check when copy
ElaineBao Dec 11, 2019
e3fced6
add single conv and bn test
ElaineBao Dec 11, 2019
9b68d07
skip fp16 amp_cast test in cpu
ElaineBao Dec 16, 2019
89eaba1
Fix resnet50 first convolution
ZhennanQin Dec 17, 2019
2cb8969
Skip first convolution for bfloat16
ZhennanQin Dec 20, 2019
db28e3a
support bf16 fallback compute
rongzha1 Dec 23, 2019
e6d2b69
recover origin test
rongzha1 Dec 23, 2019
704ac96
fix bf16 bn test, enhance assert_almost_equal_with_err
ElaineBao Dec 23, 2019
f931fcb
add some bf16 unittests
wuxun-zhang Dec 19, 2019
1103bb7
using assert_almost_equal_with_err for fallback bn test
rongzha1 Dec 25, 2019
98f8b07
add relu6 bf16 support
ElaineBao Dec 29, 2019
bdb2483
fix lint
rongzha1 Jan 3, 2020
b757cc4
fix subgraph conv with data=0
ElaineBao Jan 6, 2020
b27dec5
mkldnn doesn't support 0 dim tensor
rongzha1 Jan 7, 2020
f37b3fe
rm dtype check when copy
rongzha1 Jan 7, 2020
b3e5deb
using bf16 tvm
rongzha1 Jan 8, 2020
d88bc3e
rm bf16 mnist demo
rongzha1 Jan 10, 2020
bf6c727
use official tvm
rongzha1 Jan 10, 2020
a8eab98
change function name; fix lint error
rongzha1 Jan 14, 2020
f46fcdc
fix clang check error:conditional expression is ambiguous; 'float' ca…
rongzha1 Jan 14, 2020
e39e3a6
nvcc compiler build pass
rongzha1 Jan 15, 2020
b554998
fix gpu amp cast symbol error
rongzha1 Jan 16, 2020
7db4f61
fix mnist training error
rongzha1 Jan 17, 2020
c04fe99
fix cpp test: Engine.VarVersion error
rongzha1 Jan 17, 2020
b3306c9
workaround cpp failed test mkldnn fc bwd
rongzha1 Jan 19, 2020
06a32fe
to fix mkldnn test_mkldnn_ndarray_slice error
rongzha1 Jan 20, 2020
1b090b5
1. move some code from to np_broadcast_reduce_op_value.cc to np_broad…
rongzha1 Jan 21, 2020
44f133b
use official dlpack
rongzha1 Jan 21, 2020
04a1402
rename np_broadcast_reduce_op_value_part2.cc and add some description
rongzha1 Jan 21, 2020
afe1c09
1. update dlpack url in .gitmodule
rongzha1 Jan 23, 2020
1027985
fix remaining NodePtr due to tvm update
rongzha1 Feb 10, 2020
0d83536
mv some code from mxnet_op.h to mxnet_op_kernel_assign.h to avoid WIN…
rongzha1 Feb 10, 2020
592e67f
fix WIN CPU build fail:compiler is out of heap space in pass 2
rongzha1 Feb 10, 2020
e934607
fix WIN build fail
rongzha1 Feb 11, 2020
ce35638
fix lint
rongzha1 Feb 11, 2020
4dfd91b
add print for test bf16_concat
rongzha1 Feb 12, 2020
405e2aa
fix bf16 test fail
rongzha1 Feb 12, 2020
e0246ae
disable bf16 concat test
rongzha1 Feb 12, 2020
043d315
tmp skip to root cause edge test halt
rongzha1 Feb 13, 2020
052bf79
fix bf16_bn test error
rongzha1 Feb 13, 2020
1880d95
enable test_bulk
rongzha1 Feb 14, 2020
5edd735
tmp rm bf16 to locate edge error
rongzha1 Feb 14, 2020
3bad0fe
Revert "tmp rm bf16 to locate edge error"
rongzha1 Feb 15, 2020
4a43b1d
add Apache license header
rongzha1 Feb 15, 2020
9a37a0a
trigger CI
rongzha1 Feb 15, 2020
df090c1
add robust for test bf16 bn
rongzha1 Feb 15, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/dlpack
161 changes: 160 additions & 1 deletion 3rdparty/mshadow/mshadow/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,32 @@ extern "C" {

#include "./half.h"
#include "./half2.h"
#include "./bfloat.h"
#define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP) \
MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t a, mshadow::bfloat::bf16_t b) { \
return float(a) OP float(b); /* NOLINT(*) */ \
} \
MSHADOW_XINLINE RTYPE operator OP(mshadow::bfloat::bf16_t a, mshadow::half::half_t b) { \
return float(a) OP float(b); /* NOLINT(*) */ \
}

/*! \brief overloaded + operator between half_t and bf16_t */
MSHADOW_HALF_BF_OPERATOR(float, +)
/*! \brief overloaded - operator between half_t and bf16_t */
MSHADOW_HALF_BF_OPERATOR(float, -)
/*! \brief overloaded * operator between half_t and bf16_t */
MSHADOW_HALF_BF_OPERATOR(float, *)
/*! \brief overloaded / operator between half_t and bf16_t */
MSHADOW_HALF_BF_OPERATOR(float, /)
/*! \brief overloaded > operator between half_t and bf16_t */
MSHADOW_HALF_BF_OPERATOR(bool, >)
/*! \brief overloaded < operator between half_t and bf16_t */
MSHADOW_HALF_BF_OPERATOR(bool, <)
/*! \brief overloaded >= operator between half_t and bf16_t */
MSHADOW_HALF_BF_OPERATOR(bool, >=)
/*! \brief overloaded <= operator between half_t and bf16_t */
MSHADOW_HALF_BF_OPERATOR(bool, <=)

#include "./logging.h"
/*! \brief namespace for mshadow */
namespace mshadow {
Expand Down Expand Up @@ -312,6 +338,11 @@ enum TypeFlag {
kInt8 = 5,
kInt64 = 6,
kBool = 7,
kInt16 = 8,
kUint16 = 9,
kUint32 = 10,
kUint64 = 11,
Comment on lines +341 to +344
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why adding those additional types here? No operator supports them anyway, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to align the definition with DLPack. Otherwise we have to preserve those numbers. Even through we don't use them currently, it's no harm to add them.

kBfloat16 = 12
};

template<typename DType>
Expand Down Expand Up @@ -365,6 +396,11 @@ struct DataType<half::half2_t> {
static const int kLanes = 2;
};
template<>
struct DataType<bfloat::bf16_t> {
static const int kFlag = kBfloat16;
static const int kLanes = 1;
};
template<>
struct DataType<uint8_t> {
static const int kFlag = kUint8;
static const int kLanes = 1;
Expand Down Expand Up @@ -688,6 +724,11 @@ template<>
MSHADOW_XINLINE half::half_t MinValue<half::half_t>(void) {
return MSHADOW_HALF_MIN;
}
/*! \brief minimum value of bf16 */
template<>
MSHADOW_XINLINE bfloat::bf16_t MinValue<bfloat::bf16_t>(void) {
return MSHADOW_BF16_MIN;
}
/*! \brief minimum value of uint8_t */
template<>
MSHADOW_XINLINE uint8_t MinValue<uint8_t>(void) {
Expand Down Expand Up @@ -765,6 +806,11 @@ template<>
MSHADOW_XINLINE half::half_t MaxValue<half::half_t>(void) {
return MSHADOW_HALF_MAX;
}
/*! \brief maximum value of bf16 */
template<>
MSHADOW_XINLINE bfloat::bf16_t MaxValue<bfloat::bf16_t>(void) {
return MSHADOW_BF16_MAX;
}
/*! \brief maximum value of uint8_t */
template<>
MSHADOW_XINLINE uint8_t MaxValue<uint8_t>(void) {
Expand Down Expand Up @@ -998,6 +1044,7 @@ struct minimum {
};
} // namespace red

#ifndef __NVCC__
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like it - can we make a similar thing as in the case of fp16 for CPUs that did not support F16C instructions (i.e. code that runs but may be slower than the code for hardware natively supporting bfloat16)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can implement atomicAdd (which seems to be the problem you are facing)with atomicCAS like this: https://github.com/apache/incubator-mxnet/blob/master/src/common/cuda_utils.h#L702-L721

Copy link
Contributor

@ZhennanQin ZhennanQin Jan 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have enough background / knowledge to enable Bfloat16 on GPU side. So probably we can't make the change you proposed. Alternately, any code refactoring on GPU side is welcome. you may change this as you want in following PR.

#define MSHADOW_TYPE_SWITCH(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: \
Expand All @@ -1018,6 +1065,12 @@ struct minimum {
{__VA_ARGS__} \
} \
break; \
case mshadow::kBfloat16: \
{ \
typedef mshadow::bfloat::bf16_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
Expand Down Expand Up @@ -1045,6 +1098,55 @@ struct minimum {
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}
#else
#define MSHADOW_TYPE_SWITCH(type, DType, ...) \
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt8: \
{ \
typedef int8_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt64: \
{ \
typedef int64_t DType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}
#endif

#define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...) \
switch (type) { \
Expand Down Expand Up @@ -1147,6 +1249,7 @@ struct minimum {
LOG(FATAL) << "Unknown type enum " << type; \
}

#ifndef __NVCC__
#define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \
switch (type$) { \
case mshadow::kFloat32: \
Expand All @@ -1170,6 +1273,13 @@ struct minimum {
{__VA_ARGS__} \
} \
break; \
case mshadow::kBfloat16: \
{ \
typedef mshadow::bfloat::bf16_t DType$; \
typedef float DLargeType$; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
LOG(FATAL) << "This operation only support " \
"floating point types not uint8"; \
Expand All @@ -1189,7 +1299,50 @@ struct minimum {
default: \
LOG(FATAL) << "Unknown type enum " << type$; \
}

#else
#define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \
switch (type$) { \
case mshadow::kFloat32: \
{ \
typedef float DType$; \
typedef float DLargeType$; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType$; \
typedef double DLargeType$; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType$; \
typedef float DLargeType$; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
LOG(FATAL) << "This operation only support " \
"floating point types not uint8"; \
break; \
case mshadow::kInt8: \
LOG(FATAL) << "This operation only support " \
"floating point types not int8"; \
break; \
case mshadow::kInt32: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int32";\
break; \
case mshadow::kInt64: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int64";\
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type$; \
}
#endif
#define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \
switch (layout) { \
case mshadow::kNCHW: \
Expand Down Expand Up @@ -1256,6 +1409,12 @@ struct minimum {
{__VA_ARGS__} \
} \
break; \
case mshadow::kBfloat16: \
{ \
typedef mshadow::bfloat::bf16_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
Expand Down
Loading