Update BF16 amp list (#39304)
* amp list updated

* tests updated

* gray list updated

* amp list updated

* test updated
arlesniak authored Feb 7, 2022
1 parent ebd1474 commit 0c43ce2
Showing 3 changed files with 40 additions and 17 deletions.
15 changes: 9 additions & 6 deletions python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py
@@ -83,15 +83,18 @@ def _update_list(self):
 bf16_initializer_list = {'fill_constant', 'uniform_random'}

 # always bf16
-bf16_list = {'elementwise_add', 'mul'}
+bf16_list = {
+    'conv2d',
+    'matmul',
+    'matmul_v2',
+    'mul',
+}

 # depends on the prev_op type
 gray_list = {
-    'cast',
-    'fill_constant',
-    'reduce_mean',
-    'reshape2',
-    'scale',
+    'elementwise_add', 'elementwise_sub', 'elementwise_mul', 'elementwise_div',
+    'relu', 'layer_norm', 'slice', 'concat', 'uniform_random', 'reshape2',
+    'transpose2', 'pool2d', 'sigmoid', 'cast', 'scale', 'fill_constant', 'split'
 }

 _, _, _sys_unsupported_bf16_list = core.op_supported_infos(
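
A note on semantics, since these lists drive the rewrite pass: ops in bf16_list are always cast to bfloat16, gray_list ops follow the dtype of their inputs, and the custom_* constructor arguments of AutoMixedPrecisionListsBF16 move ops between lists. A minimal usage sketch (not part of this commit), assuming the bf16_list/fp32_list attributes that the test fixtures below compare against:

# A runnable sketch under the assumptions above, not a snippet from the
# commit itself.
import paddle
import paddle.static.amp as amp

paddle.enable_static()

# Force 'matmul_v2' (now always-bf16 per the list above) back to fp32:
# the constructor removes it from bf16_list and adds it to fp32_list,
# which is exactly what test_amp_lists_4/test_amp_lists_5 assert.
amp_lists = amp.bf16.AutoMixedPrecisionListsBF16(
    custom_fp32_list={'matmul_v2'})
assert 'matmul_v2' not in amp_lists.bf16_list
assert 'matmul_v2' in amp_lists.fp32_list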
16 changes: 8 additions & 8 deletions python/paddle/fluid/contrib/tests/test_bf16_utils.py
@@ -57,20 +57,20 @@ def test_amp_lists_3(self):
         self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'lstm'})

     def test_amp_lists_4(self):
-        # 4. w=None, b={'elementwise_add'}
-        self.bf16_list.remove('elementwise_add')
-        self.fp32_list.add('elementwise_add')
+        # 4. w=None, b={'matmul_v2'}
+        self.bf16_list.remove('matmul_v2')
+        self.fp32_list.add('matmul_v2')

         self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
-            custom_fp32_list={'elementwise_add'})
+            custom_fp32_list={'matmul_v2'})

     def test_amp_lists_5(self):
-        # 5. w=None, b={'elementwise_add'}
-        self.fp32_list.add('elementwise_add')
-        self.bf16_list.remove('elementwise_add')
+        # 5. w=None, b={'matmul_v2'}
+        self.fp32_list.add('matmul_v2')
+        self.bf16_list.remove('matmul_v2')

         self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
-            custom_fp32_list={'elementwise_add'})
+            custom_fp32_list={'matmul_v2'})

     def test_amp_lists_6(self):
         # 6. w=None, b={'lstm'}
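
The pivot op changes because this commit moves 'elementwise_add' out of the always-bf16 list into the gray list, so it can no longer serve as the op the tests demote to fp32; 'matmul_v2', newly in bf16_list, takes its place. The opposite move is also possible. A hedged sketch, assuming custom_bf16_list promotes a gray-list op the same way the graph tests below use it:

# Hedged sketch: promote the now-gray 'elementwise_add' back to always-bf16,
# as test_graph_rewrite/test_graph_cast below do via custom_bf16_list.
import paddle
import paddle.static.amp as amp

paddle.enable_static()

amp_lists = amp.bf16.AutoMixedPrecisionListsBF16(
    custom_bf16_list={'elementwise_add'})
assert 'elementwise_add' in amp_lists.bf16_list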
26 changes: 23 additions & 3 deletions python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py
@@ -19,13 +19,28 @@
 import contextlib
 import unittest
 import numpy as np
+import struct
 import paddle.fluid.layers as layers
 import paddle.static.amp as amp
 from paddle.fluid import core

 paddle.enable_static()


+def convert_uint16_to_float(in_list):
+    if in_list.dtype == np.uint16:
+        in_list = np.asarray(in_list)
+        out = np.vectorize(
+            lambda x: struct.unpack('<f', struct.pack('<I', x << 16))[0],
+            otypes=[np.float32])(in_list.flat)
+        return np.reshape(out, in_list.shape)
+    else:
+        return in_list
+
+
+cutf = convert_uint16_to_float
+
+
 @unittest.skipIf(not core.supports_bfloat16(),
                  "place does not support BF16 evaluation")
 class TestModelCastBF16(unittest.TestCase):
@@ -111,10 +126,13 @@ def _graph_common(self, _amp_fun, startup_prog=None):
                     'tt_bf16': nn_bf16,
                 },
                 fetch_list=[ret_bf16, ret, ret_fp32bf16],
-                amp_fun=lambda prog: amp.bf16.rewrite_program_bf16(prog))
+                amp_fun=_amp_fun,
+                startup_prog=startup_prog)

-        self.assertTrue(np.allclose(static_ret_bf16, static_ret, 1e-2))
-        self.assertTrue(np.allclose(static_ret_bf16, ret_fp32bf16, 1e-2))
+        self.assertTrue(
+            np.allclose(cutf(static_ret_bf16), cutf(static_ret), 1e-2))
+        self.assertTrue(
+            np.allclose(cutf(static_ret_bf16), cutf(ret_fp32bf16), 1e-2))

         with self.static_graph():
             t = layers.data(name='t', shape=[size, size], dtype='float32')
@@ -141,6 +159,7 @@ def test_graph_rewrite(self):
         self._graph_common(lambda prog: amp.bf16.rewrite_program_bf16(
             prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
+                custom_bf16_list={'elementwise_add'},
                 custom_fp32_varnames={'elementwise_add_0.tmp_0'})
         ))

@@ -149,6 +168,7 @@ def test_graph_cast(self):
             prog,
             startup_prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
+                custom_bf16_list={'elementwise_add'},
                 custom_fp32_list={'elementwise_mul'}),
             use_bf16_guard=True
         ), startup_prog=fluid.default_startup_program())
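
The new convert_uint16_to_float helper exists because Paddle hands back bfloat16 tensors as np.uint16 bit patterns: bfloat16 keeps the sign, exponent, and top 7 mantissa bits of an IEEE-754 float32, i.e. its upper 16 bits, so shifting left by 16 and reinterpreting the 32-bit word as a float recovers the value. A standalone sketch of the bit trick (the helper name here is illustrative, not from the commit):

import struct

def bf16_bits_to_float(bits):
    # Place the 16 bfloat16 bits in the upper half of a 32-bit word and
    # unpack that word as a little-endian IEEE-754 float32.
    return struct.unpack('<f', struct.pack('<I', bits << 16))[0]

assert bf16_bits_to_float(0x3F80) == 1.0   # 0x3F800000 == float32 1.0
assert bf16_bits_to_float(0xC000) == -2.0  # 0xC0000000 == float32 -2.0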
