This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix #17164 symbolblock with BatchNorm inside during cast to fp16 (#17212)

* fix symbolblock with bn+fp16

* add unittest

* fix

* remove unused

* fix lint
zhreshold authored and ptrendx committed Jan 6, 2020
1 parent 1612533 commit bd7eedf
Showing 2 changed files with 42 additions and 1 deletion.
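
For context, a minimal sketch of the scenario this change targets (assuming an MXNet 1.x build; the layer names are illustrative, not taken from the issue): a symbolic graph containing BatchNorm is wrapped in a SymbolBlock and the whole block is then cast to float16. Per the comment in the fix below, BatchNorm's gamma/beta and running statistics need to stay float32, so SymbolBlock.cast now restores them after the cast.

```python
import mxnet as mx

# Wrap a symbolic graph containing BatchNorm in a SymbolBlock
# (illustrative names; mirrors the shape of the new unit test).
data = mx.sym.var('data')
conv = mx.sym.Convolution(data, kernel=(3, 3), num_filter=16, name='conv0')
bn = mx.sym.BatchNorm(conv, name='bn0')
block = mx.gluon.nn.SymbolBlock([bn], [data])
block.initialize()

# With this commit, casting to float16 leaves the BatchNorm
# parameters (gamma/beta and the moving/running statistics) in float32.
block.cast('float16')
```
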
25 changes: 24 additions & 1 deletion python/mxnet/gluon/block.py
@@ -16,7 +16,7 @@
# under the License.

# coding: utf-8
# pylint: disable= arguments-differ, too-many-lines
# pylint: disable= arguments-differ, too-many-lines, reimported
"""Base container class for all neural network models."""
__all__ = ['Block', 'HybridBlock', 'SymbolBlock']

@@ -25,6 +25,7 @@
import warnings
import re
from collections import OrderedDict, defaultdict
import numpy as np

from ..base import mx_real_t, MXNetError
from .. import symbol, ndarray, initializer, np_symbol
@@ -1353,6 +1354,28 @@ def _clear_cached_op(self):
    def cast(self, dtype):
        self._clear_cached_op()
        super(SymbolBlock, self).cast(dtype)
        if np.dtype(dtype).name == 'float16':
            # correct BatchNorm types back to float32 due to its special requirement
            out = self._cached_graph[1]
            params_list = out.get_internals().list_inputs()
            for node in params_list:
                if node.endswith('running_var'):
                    prefix = node[:-11]
                    sibs = [prefix + t for t in ('running_mean', 'gamma', 'beta')]
                    is_bn = all(p in params_list for p in sibs)
                    if is_bn:
                        self.params.get(node).cast('float32')
                        for sib in sibs:
                            self.params.get(sib).cast('float32')
                if node.endswith('moving_var'):
                    # another convention used
                    prefix = node[:-10]
                    sibs = [prefix + t for t in ('moving_mean', 'gamma', 'beta')]
                    is_bn = all(p in params_list for p in sibs)
                    if is_bn:
                        self.params.get(node).cast('float32')
                        for sib in sibs:
                            self.params.get(sib).cast('float32')

    def hybrid_forward(self, F, x, *args, **kwargs):
        raise NotImplementedError
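
The detection above relies purely on parameter naming: a BatchNorm node contributes either a `running_var` or a `moving_var` input (depending on which naming convention produced the symbol), and its mean/gamma/beta siblings share the same prefix. A minimal, self-contained sketch of that heuristic (plain Python, hypothetical parameter names, no MXNet required):

```python
def batchnorm_param_groups(params_list):
    """Group parameter names that look like they belong to a BatchNorm node.

    Mirrors the suffix matching in SymbolBlock.cast above: a name ending in
    'running_var' or 'moving_var' is treated as BatchNorm only if the
    matching mean/gamma/beta siblings are also present.
    """
    groups = []
    for var_suffix, mean_suffix in (('running_var', 'running_mean'),
                                    ('moving_var', 'moving_mean')):
        for node in params_list:
            if node.endswith(var_suffix):
                prefix = node[:-len(var_suffix)]
                sibs = [prefix + t for t in (mean_suffix, 'gamma', 'beta')]
                if all(p in params_list for p in sibs):
                    groups.append([node] + sibs)
    return groups

# Example with hypothetical names: only the complete BatchNorm group is picked up.
names = ['conv0_weight', 'bn0_gamma', 'bn0_beta', 'bn0_moving_mean',
         'bn0_moving_var', 'oddly_named_moving_var']
print(batchnorm_param_groups(names))
# [['bn0_moving_var', 'bn0_moving_mean', 'bn0_gamma', 'bn0_beta']]
```
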
18 changes: 18 additions & 0 deletions tests/python/gpu/test_gluon_gpu.py
@@ -594,6 +594,24 @@ def hybrid_forward(self, F, a, b):
    assert_raises(ValueError, lambda: foo_hybrid(mx.nd.ones((10,), ctx=mx.gpu()),
                                                 mx.nd.ones((10,), ctx=mx.cpu())))

@with_seed()
def test_symbol_block_symbolic_bn_fp16_cast():
    with mx.gpu(0):
        net = mx.gluon.nn.HybridSequential()
        sym = mx.sym.var('data')
        conv = mx.sym.Convolution(sym, kernel=(3, 3), num_filter=16)
        bn = mx.sym.BatchNorm(conv, name='bn_test')
        internals = bn.get_internals()
        net.add(mx.gluon.nn.SymbolBlock([internals['bn_test_output']], [mx.sym.var('data')]))
        net.add(mx.gluon.nn.Conv2D(10, kernel_size=1))
        net.initialize()
        x = mx.nd.zeros((1, 3, 32, 32), dtype='float32')
        y = net(x)
        assert np.dtype(y.dtype).name == 'float32'
        net.cast('float16')
        x = x.astype('float16')
        y1 = net(x)
        assert np.dtype(y1.dtype).name == 'float16'

if __name__ == '__main__':
    import nose
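
The new test only asserts on the output dtypes; a quick way to inspect the parameter dtypes directly (assuming the same `net` as in the test above, after `net.cast('float16')`) is to walk the standard Gluon parameter dictionary:

```python
# BatchNorm gamma/beta and moving statistics inside the SymbolBlock should
# report float32, while the other parameters report float16.
for name, param in net.collect_params().items():
    print(name, param.dtype)
```
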
