
[WIP] AMP support for numpy ops
Vladimir Cherepanov committed Sep 23, 2020
1 parent 2697573 commit 1a31939
Showing 7 changed files with 449 additions and 287 deletions.
128 changes: 83 additions & 45 deletions python/mxnet/contrib/amp/amp.py
@@ -27,8 +27,10 @@
import ctypes
import logging
import contextlib
import sys
import numpy as np

from mxnet import numpy
from ... import symbol
from ...context import gpu
from ...symbol import Symbol
@@ -41,28 +43,28 @@
from ...base import c_str_array, SymbolHandle, check_call, _LIB, mx_uint, c_array_buf
from ... import optimizer as opt
from .loss_scaler import LossScaler
from ...operator import get_all_registered_operators_grouped

bfloat16 = np.dtype([('bfloat16', np.uint16)])

def _cast_symbol_NDArray(s, dtype):
float_types_gpu = (np.float16, np.float32)
float_types_cpu = (bfloat16, np.float32)
if isinstance(s, Symbol):
return symbol.amp_cast(s, dtype=dtype)
elif isinstance(s, NDArray):
if (s.dtype != dtype and s.dtype in float_types_gpu and s.context.device_type != 'cpu'):
return ndarray.amp_cast(s, dtype=dtype)
elif (s.dtype != dtype and s.dtype in float_types_cpu and s.context.device_type == 'cpu'):
return ndarray.amp_cast(s, dtype=dtype)
else:
return s
else:
return s
float_types_gpu = (np.float16, np.float32)
float_types_cpu = (bfloat16, np.float32)

def _get_fun_to_wrap(name, module, submodule_dict):
def _cast_symbol_NDArray(s, dtype, is_numpy_module=False):
if isinstance(s, Symbol):
amp_cast = symbol.numpy._internal.amp_cast if is_numpy_module else symbol.amp_cast
return amp_cast(s, dtype=dtype)
if isinstance(s, NDArray):
amp_cast = ndarray.numpy._internal.amp_cast if is_numpy_module else ndarray.amp_cast
if s.dtype != dtype and (s.dtype in float_types_gpu and s.context.device_type != 'cpu' or
s.dtype in float_types_cpu and s.context.device_type == 'cpu'):
return amp_cast(s, dtype=dtype)
return s

def _get_nd_fun_to_wrap(name, module, submodule_dict):
module_internal = getattr(module, "_internal")
prefix = base._get_op_name_prefix(name)
if len(prefix) > 0:
if prefix:
if prefix != '_random_' or name.endswith('_like'):
func_name = name[len(prefix):]
cur_module = submodule_dict[prefix]
@@ -77,8 +79,23 @@ def _get_fun_to_wrap(name, module, submodule_dict):
cur_module = module
return func_name, cur_module

def _wrap_symbol_functions(module, target_dtype, target_precision_ops=None,
conditional_fp32_ops=None, fp32_ops=None):
def _get_np_fun_to_wrap(name, ns_prefix):
if name.startswith('_np_'):
return name[4:], sys.modules[f'{ns_prefix}.numpy._op']
if name.startswith('_npi_'):
return name[5:], sys.modules[f'{ns_prefix}.numpy._internal']
if name.startswith('_npx_'):
return name[5:], sys.modules[f'{ns_prefix}.numpy_extension._op']
assert False
return None # for pylint

def _wrap_module_functions(module, is_numpy_module, target_dtype, get_aliases, get_cond_aliases,
get_fun_to_wrap, target_precision_ops=None, conditional_fp32_ops=None,
fp32_ops=None):

nd_mod = ndarray.numpy._internal if is_numpy_module else ndarray
sy_mod = symbol.numpy._internal if is_numpy_module else symbol

def _ndarray_wrapper(f, target_dtype, fp32_param=None, cond_arg=None):
def _new_fun(*args, **kwargs):
if cond_arg is not None:
@@ -91,20 +108,22 @@ def _new_fun(*args, **kwargs):
if fp32_param[i]:
new_args.append(x)
else:
new_args.append(_cast_symbol_NDArray(x, target_dtype))
new_args.append(_cast_symbol_NDArray(x, target_dtype, is_numpy_module))
else:
new_args = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype), args))
new_args = list(map(
lambda x: _cast_symbol_NDArray(x, target_dtype, is_numpy_module), args))
args = tuple(new_args)
if fp32_param:
new_kwargs = {}
for k, v in kwargs.items():
if k in fp32_param:
new_kwargs[k] = v
else:
new_kwargs[k] = _cast_symbol_NDArray(v, target_dtype)
new_kwargs[k] = _cast_symbol_NDArray(v, target_dtype, is_numpy_module)
kwargs = new_kwargs
else:
kwargs = {k: _cast_symbol_NDArray(v, target_dtype) for k, v in kwargs.items()}
kwargs = {k: _cast_symbol_NDArray(v, target_dtype, is_numpy_module)
for k, v in kwargs.items()}
return f(*args, **kwargs)
_new_fun.__name__ = f.__name__
_new_fun.__module__ = f.__module__
@@ -126,10 +145,10 @@ def _new_fun(*args, **kwargs):
if (x.name in aux) or fp32_param[i]:
new_inputs.append(x)
else:
new_inputs.append(_cast_symbol_NDArray(x, target_dtype))
new_inputs.append(_cast_symbol_NDArray(x, target_dtype, is_numpy_module))
inputs = new_inputs
else:
inputs = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype)
inputs = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype, is_numpy_module)
if x.name not in aux else x, inputs))
atomic_sym = sym._gen_atomic_symbol()
wrapped_sym = atomic_sym(*inputs)
@@ -162,11 +181,11 @@ def _new_fun(*args, **kwargs):
widest_type = np.float32
for arr, index, arg in symbols:
if arg.dtype != widest_type and arg.dtype == target_dtype:
arr[index] = ndarray.amp_cast(arg, dtype=widest_type)
arr[index] = nd_mod.amp_cast(arg, dtype=widest_type)
else:
# Symbol case
sym_to_check = list(map(lambda x: x[2], symbols))
casted_syms = symbol.amp_multicast(*sym_to_check, num_outputs=len(sym_to_check))
casted_syms = sy_mod.amp_multicast(*sym_to_check, num_outputs=len(sym_to_check))
symbols = list(map(lambda x_y: (x_y[0][0], x_y[0][1], x_y[1]),
zip(symbols, casted_syms)))
for arr, index, arg in symbols:
@@ -180,16 +199,12 @@ def _new_fun(*args, **kwargs):

_wrapper = _symbol_wrapper if module in (symbol, Symbol, symbol_contrib) else _ndarray_wrapper

submodule_dict = {}
for op_name_prefix in base._OP_NAME_PREFIX_LIST:
submodule_dict[op_name_prefix] =\
getattr(module, op_name_prefix[1:-1])
fp32_param_list = list_lp16_use_fp32_params(target_dtype)
wrap_list = target_precision_ops if target_precision_ops is not None \
else list_lp16_ops(target_dtype)
for fun_name in wrap_list:
for fun_name in get_aliases(wrap_list):
try:
fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
fun_name, cur_module = get_fun_to_wrap(fun_name, module)
f_to_wrap = getattr(cur_module, fun_name)
fp32_param = fp32_param_list[fun_name] if (fp32_param_list and fun_name in fp32_param_list) else None
setattr(cur_module, fun_name, _wrapper(f_to_wrap, target_dtype, fp32_param=fp32_param))
@@ -199,9 +214,9 @@ def _new_fun(*args, **kwargs):
raise

wrap_list = fp32_ops if fp32_ops is not None else list_fp32_ops(target_dtype)
for fun_name in wrap_list:
for fun_name in get_aliases(wrap_list):
try:
fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
fun_name, cur_module = get_fun_to_wrap(fun_name, module)
f_to_wrap = getattr(cur_module, fun_name)
setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32))
if cur_module == module:
@@ -211,9 +226,9 @@ def _new_fun(*args, **kwargs):

wrap_list = conditional_fp32_ops if conditional_fp32_ops is not None \
else list_conditional_fp32_ops(target_dtype)
for fun_name, arg, arg_values in wrap_list:
for fun_name, arg, arg_values in get_cond_aliases(wrap_list):
try:
fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
fun_name, cur_module = get_fun_to_wrap(fun_name, module)
f_to_wrap = getattr(cur_module, fun_name)
setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32, cond_arg=(arg, arg_values)))
if cur_module == module:
@@ -222,9 +237,9 @@ def _new_fun(*args, **kwargs):
raise


for fun_name in list_widest_type_cast(target_dtype):
for fun_name in get_aliases(list_widest_type_cast(target_dtype)):
try:
fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
fun_name, cur_module = get_fun_to_wrap(fun_name, module)
f_to_wrap = getattr(cur_module, fun_name)
setattr(cur_module, fun_name, _symbol_widest_wrapper(f_to_wrap))
if cur_module == module:
@@ -310,13 +325,37 @@ def init(target_dtype='float16', target_precision_ops=None,
target_dtype = bfloat16
else:
target_dtype = np.dtype(target_dtype)
_wrap_symbol_functions(symbol, target_dtype, target_precision_ops,
conditional_fp32_ops, fp32_ops)
_wrap_symbol_functions(ndarray, target_dtype, target_precision_ops,
conditional_fp32_ops, fp32_ops)

ops = get_all_registered_operators_grouped()
get_aliases_nd = lambda l: [a for op in l for a in ops[op] if not base._is_np_op(a)]
get_aliases_np = lambda l: [a for op in l for a in ops[op] if base._is_np_op(a)]
get_aliases_np_pub = lambda l: [a for op in l for a in ops[op]
if a.startswith(('_np_', '_npx_'))]
get_cond_aliases_nd = lambda l: [(a, *rest) for op, *rest in l for a in ops[op]
if not base._is_np_op(a)]
get_cond_aliases_np = lambda l: [(a, *rest) for op, *rest in l for a in ops[op]
if base._is_np_op(a)]
get_cond_aliases_np_pub = lambda l: [(a, *rest) for op, *rest in l for a in ops[op]
if a.startswith(('_np_', '_npx_'))]
sy_submodules = {p:getattr(symbol, p[1:-1]) for p in base._OP_NAME_PREFIX_LIST}
get_sy_fun = lambda fun, mod: _get_nd_fun_to_wrap(fun, mod, sy_submodules)
nd_submodules = {p:getattr(ndarray, p[1:-1]) for p in base._OP_NAME_PREFIX_LIST}
get_nd_fun = lambda fun, mod: _get_nd_fun_to_wrap(fun, mod, nd_submodules)
get_np_sy_fun = lambda fun, mod: _get_np_fun_to_wrap(fun, "mxnet.symbol")
get_np_nd_fun = lambda fun, mod: _get_np_fun_to_wrap(fun, "mxnet.ndarray")
get_np_fun = lambda fun, mode: _get_np_fun_to_wrap(fun, "mxnet")
todo = [
(symbol, False, get_aliases_nd, get_cond_aliases_nd, get_sy_fun),
(ndarray, False, get_aliases_nd, get_cond_aliases_nd, get_nd_fun),
(symbol.numpy, True, get_aliases_np, get_cond_aliases_np, get_np_sy_fun),
(ndarray.numpy, True, get_aliases_np, get_cond_aliases_np, get_np_nd_fun),
(numpy, True, get_aliases_np_pub, get_cond_aliases_np_pub, get_np_fun),
]
_loss_scaler = LossScaler()
_wrap_loss_output_functions(ndarray, _loss_scaler, target_dtype)
_wrap_loss_output_functions(symbol, _loss_scaler, target_dtype)
for module, is_numpy, get_aliases, get_cond_aliases, get_fun in todo:
_wrap_module_functions(module, is_numpy, target_dtype, get_aliases, get_cond_aliases,
get_fun, target_precision_ops, conditional_fp32_ops, fp32_ops)
_wrap_loss_output_functions(module, _loss_scaler, target_dtype)

def init_trainer(optimizer_or_trainer):
"""Initialize trainer or optimizer to work with AMP dynamic loss scaling.
@@ -340,7 +379,6 @@ def init_trainer(optimizer_or_trainer):
optimizer_or_trainer._amp_loss_scaler = loss_scaler
optimizer_or_trainer._amp_original_scale = optimizer_or_trainer._scale
elif isinstance(optimizer_or_trainer, opt.Optimizer):
# TODO(ptredak): make it work with the optimizer
raise TypeError("AMP is currently only compatible with Gluon Trainer")
else:
raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
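A minimal usage sketch of what this change is meant to enable, not part of the commit itself: after amp.init(), operators from the mxnet.numpy / mxnet.numpy_extension namespaces should be wrapped just like their legacy symbol/ndarray counterparts, so FP16-eligible ops get automatic amp_cast insertion. Treating np.dot as an FP16-eligible op and the GPU context are illustrative assumptions.

import mxnet as mx
from mxnet.contrib import amp

# Wrap symbol, ndarray and (with this change) numpy/numpy_extension ops.
amp.init()

# Float32 inputs on a GPU; if the called op is on the lp16 list, its inputs
# are cast to float16 before dispatch, otherwise they pass through unchanged.
a = mx.np.ones((128, 128), ctx=mx.gpu(0), dtype='float32')
b = mx.np.ones((128, 128), ctx=mx.gpu(0), dtype='float32')
c = mx.np.dot(a, b)
print(c.dtype)  # float16 if dot is FP16-eligible in the active AMP lists, else float32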
