diff --git a/CODEOWNERS b/CODEOWNERS index 74caea2c4cc7..90744501ec71 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -47,7 +47,8 @@ CMakeLists.txt @szha @pllarroy # MXNet CI dev_menu.py @pllarroy -/ci/ @pllarroy +/ci/ @pllarroy @marcoabreu +/docker/ @marcoabreu /tests/ci_build/ @marcoabreu Jenkinsfile @marcoabreu .travis.yml @marcoabreu diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 750a22a806d1..86c20cbc3d87 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -245,6 +245,7 @@ List of Contributors * [Rohit Srivastava](https://github.com/access2rohit) * [Caner Turkmen](https://github.com/canerturkmen) * [Disi A](https://github.com/adis300) +* [Vandana Kannan](https://github.com/vandanavk) Label Bot --------- diff --git a/benchmark/opperf/utils/benchmark_utils.py b/benchmark/opperf/utils/benchmark_utils.py index c530f9bf6177..a6ee38bf9f65 100644 --- a/benchmark/opperf/utils/benchmark_utils.py +++ b/benchmark/opperf/utils/benchmark_utils.py @@ -50,7 +50,7 @@ def _prepare_op_inputs(inputs, run_backward, dtype, ctx): return args_list, kwargs_list -def _run_nd_operator_performance_test(op, inputs, run_backward, warmup, runs, kwargs_list, profiler): +def _run_nd_operator_performance_test(op, inputs, run_backward, warmup, runs, args_list, kwargs_list, profiler): if profiler == 'native': if run_backward: benchmark_helper_func = cpp_profile(nd_forward_backward_and_profile) diff --git a/benchmark/opperf/utils/common_utils.py b/benchmark/opperf/utils/common_utils.py index e657702d3596..fa0331f3468e 100644 --- a/benchmark/opperf/utils/common_utils.py +++ b/benchmark/opperf/utils/common_utils.py @@ -97,6 +97,10 @@ def _prepare_op_benchmark_result(op, op_bench_result, profiler): max_mem_usage = "---" inputs = "---" avg_time = "---" + p50_time = "---" + p90_time = "---" + p99_time = "---" + for key, value in op_bench_result.items(): if "avg_time_forward" in key: avg_forward_time = value @@ -108,12 +112,19 @@ def _prepare_op_benchmark_result(op, op_bench_result, profiler): inputs = value elif "avg_time" in key: avg_time = value + elif "p50_time" in key: + p50_time = value + elif "p90_time" in key: + p90_time = value + elif "p99_time" in key: + p99_time = value + result = "" if profiler == "native": result = "| {} | {} | {} | {} | {} |".format(operator_name, avg_forward_time, avg_backward_time, max_mem_usage, inputs) elif profiler == "python": - result = "| {} | {} | {} |".format(operator_name, avg_time, inputs) + result = "| {} | {} | {} | {} | {} | {} |".format(operator_name, avg_time, p50_time, p90_time, p99_time, inputs) return result @@ -132,8 +143,8 @@ def _prepare_markdown(results, runtime_features=None, profiler='native'): " | Inputs |") elif profiler == 'python': results_markdown.append( - "| Operator | Avg Time (ms) | Inputs |") - results_markdown.append("| :---: | :---: | :---: |") + "| Operator | Avg Time (ms) | P50 Time (ms) | P90 Time (ms) | P99 Time (ms) | Inputs |") + results_markdown.append("| :---: | :---: | :---: | :---: | :---: | :---: |") for op, op_bench_results in sorted(results.items(), key=itemgetter(0)): for op_bench_result in op_bench_results: diff --git a/benchmark/opperf/utils/profiler_utils.py b/benchmark/opperf/utils/profiler_utils.py index 5d4eb5576220..21e2606ab94e 100644 --- a/benchmark/opperf/utils/profiler_utils.py +++ b/benchmark/opperf/utils/profiler_utils.py @@ -17,6 +17,7 @@ import time import functools +import numpy as np from .common_utils import merge_map_list from mxnet import profiler @@ -219,6 +220,9 @@ def python_profile(func): res, timing 
output. res being result returned after operator execution. profiler output is a dictionary with summary of operation execution. Example output : { "add": [{"avg_time_add": 0.4053089120425284, + 'p50_time_add': 16.761042876169086, + 'p90_time_add': 18.081666342914108, + 'p99_time_add': 19.060144051909447, "inputs": { "lhs": [1024, 1024], "rhs": [1024,1024] @@ -228,10 +232,16 @@ def python_profile(func): @functools.wraps(func) def python_profile_it(*args, **kwargs): - start_time = time.perf_counter() # 1 - res = func(*args, **kwargs) - end_time = time.perf_counter() # 2 - run_time = end_time - start_time # 3 + runs = args[1] + modified_args = (args[0], 1, args[2]) + times = [] + + for _ in range(runs): + start_time = time.perf_counter() # 1 + res = func(*modified_args, **kwargs) + end_time = time.perf_counter() # 2 + run_time = (end_time - start_time)*1000 # 3 + times.append(run_time) # NOTE : same as cpp_profile_it if len(args) > 0: @@ -241,6 +251,15 @@ def python_profile_it(*args, **kwargs): else: raise ValueError("Unable to identify operator name to extract profiler output!") - profiler_output = {'avg_time_'+str(operator_name): run_time} + avg_run_time = np.mean(times) + p50_run_time = np.percentile(times, 50) + p90_run_time = np.percentile(times, 90) + p99_run_time = np.percentile(times, 99) + + profiler_output = {'avg_time_'+str(operator_name): avg_run_time, + 'p50_time_'+str(operator_name): p50_run_time, + 'p90_time_'+str(operator_name): p90_run_time, + 'p99_time_'+str(operator_name): p99_run_time, + } return res, profiler_output return python_profile_it diff --git a/docs/api/python/contrib/onnx.md b/docs/api/python/contrib/onnx.md index 5691fb022959..7a1655c74afb 100644 --- a/docs/api/python/contrib/onnx.md +++ b/docs/api/python/contrib/onnx.md @@ -30,7 +30,7 @@ With ONNX format support for MXNet, developers can build and train models with a ``` ### Installation Instructions -- To use this module developers need to **install ONNX**, which requires the protobuf compiler to be installed separately. Please follow the [instructions to install ONNX and its dependencies](https://github.com/onnx/onnx#installation). **MXNet currently supports ONNX v1.2.1**. Once installed, you can go through the tutorials on how to use this module. +- To use this module developers need to **install ONNX**, which requires the protobuf compiler to be installed separately. Please follow the [instructions to install ONNX and its dependencies](https://github.com/onnx/onnx#installation). **MXNet currently supports ONNX v1.3**. Once installed, you can go through the tutorials on how to use this module. This document describes all the ONNX-MXNet APIs. diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md index fea746bb02f4..264f69093709 100644 --- a/docs/api/python/symbol/symbol.md +++ b/docs/api/python/symbol/symbol.md @@ -612,6 +612,7 @@ Composite multiple symbols into a new one by an operator. random.normal random.poisson random.randint + random.randn random.shuffle random.uniform mxnet.random.seed diff --git a/docs/tutorials/amp/amp_tutorial.md b/docs/tutorials/amp/amp_tutorial.md index 9da0505e9ff6..2b747c6c82f6 100644 --- a/docs/tutorials/amp/amp_tutorial.md +++ b/docs/tutorials/amp/amp_tutorial.md @@ -258,6 +258,7 @@ To do inference with mixed precision for a trained model in FP32, you can use th Below, we demonstrate for a gluon model and a symbolic model: - Conversion from FP32 model to mixed precision model. - Run inference on the mixed precision model. 
+- For AMP conversion of a bucketing module, please refer to [example/rnn/bucketing/README.md](https://github.com/apache/incubator-mxnet/blob/master/example/rnn/bucketing/README.md). ```python with mx.Context(mx.gpu(0)): @@ -336,7 +337,6 @@ with mx.Context(mx.gpu(0)): mod.save_checkpoint("amp_tutorial_model", 0, remove_amp_cast=False) ``` - ## Current limitations of AMP - AMP's dynamic loss scaling currently supports only Gluon trainer with `update_on_kvstore=False` option set diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 302a04449885..6c87f58b63e2 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -216,7 +216,8 @@ def save_params(fname, arg_params, aux_params, logger=None): if exclude_first_conv: excluded_sym_names += ['resnetv10_conv0_fwd'] elif args.model.find('resnet') != -1 and args.model.find('v2') != -1: - excluded_sym_names += ['resnetv20_flatten0_flatten0'] + # resnetv20_stage1_batchnorm0_fwd is excluded for the sake of accuracy + excluded_sym_names += ['resnetv20_flatten0_flatten0', 'resnetv20_stage1_batchnorm0_fwd'] if exclude_first_conv: excluded_sym_names += ['resnetv20_conv0_fwd'] elif args.model.find('vgg') != -1: diff --git a/example/rnn/bucketing/README.md b/example/rnn/bucketing/README.md index 707370af5a96..d44b23e69b23 100644 --- a/example/rnn/bucketing/README.md +++ b/example/rnn/bucketing/README.md @@ -55,6 +55,12 @@ You can check this improved [Gluon implementation](http://gluon-nlp.mxnet.io/mod $ python3 [cudnn_rnn_bucketing.py](cudnn_rnn_bucketing.py) --gpus 0,1,2,3 +- To run mixed precision inference with the trained model, use the `--dtype` option. + + This uses the AMP conversion API for the bucketing module to convert it to a mixed precision module. + + $ python [cudnn_rnn_bucketing.py](cudnn_rnn_bucketing.py) --gpus 0 --model-prefix saved_rnn_model --load-epoch 12 --test --dtype float16 + ### Performance Note: diff --git a/example/rnn/bucketing/cudnn_rnn_bucketing.py b/example/rnn/bucketing/cudnn_rnn_bucketing.py index 66d5a55c02cb..38275ae3dfb8 100644 --- a/example/rnn/bucketing/cudnn_rnn_bucketing.py +++ b/example/rnn/bucketing/cudnn_rnn_bucketing.py @@ -18,6 +18,7 @@ import numpy as np import mxnet as mx import argparse +from mxnet.contrib.amp import amp parser = argparse.ArgumentParser(description="Train RNN on Sherlock Holmes", formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -67,6 +68,10 @@ help='dropout probability (1.0 - keep probability)') parser.add_argument('--rnntype', type=str, default='lstm', help='rnn type: gru, lstm, rnn_tanh and rnn_relu are supported') +parser.add_argument('--dtype', type=str, default='float32', + help='if float16 is provided, the AMP conversion API ' + 'is used to convert the model to a mixed precision model ' + 'before running inference') #buckets = [32] buckets = [10, 20, 30, 40, 50, 60] @@ -234,12 +239,20 @@ def sym_gen(seq_len): context = contexts) model.bind(data_val.provide_data, data_val.provide_label, for_training=False) - # note here we load using SequentialRNNCell instead of FusedRNNCell.
_, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(stack, args.model_prefix, args.load_epoch) model.set_params(arg_params, aux_params) - model.score(data_val, mx.metric.Perplexity(invalid_label), - batch_end_callback=mx.callback.Speedometer(args.batch_size, 5)) + if args.dtype == "float32": + model.set_params(arg_params, aux_params) + model.score(data_val, mx.metric.Perplexity(invalid_label), + batch_end_callback=mx.callback.Speedometer(args.batch_size, 5)) + else: + assert args.dtype == "float16", "Only float32 and float16 are supported currently" + model = amp.convert_bucketing_module(model, target_dtype="float16") + model.bind(data_val.provide_data, data_val.provide_label, + for_training=False) + model.score(data_val, mx.metric.Perplexity(invalid_label), + batch_end_callback=mx.callback.Speedometer(args.batch_size, 5)) if __name__ == '__main__': import logging diff --git a/make/maven/maven_darwin_mkl.mk b/make/maven/maven_darwin_mkl.mk index 9bf3fc46ce0b..a7f2bdb027d4 100644 --- a/make/maven/maven_darwin_mkl.mk +++ b/make/maven/maven_darwin_mkl.mk @@ -77,7 +77,7 @@ USE_CUDNN = 0 # CUDA_ARCH := # whether use cuda runtime compiling for writing kernels in native language (i.e. Python) -ENABLE_CUDA_RTC = 0 +USE_NVRTC = 0 # use openmp for parallelization USE_OPENMP = 0 diff --git a/make/maven/maven_linux_cu90mkl.mk b/make/maven/maven_linux_cu90mkl.mk index 00cfd5dfa31d..e9ba46509973 100644 --- a/make/maven/maven_linux_cu90mkl.mk +++ b/make/maven/maven_linux_cu90mkl.mk @@ -79,9 +79,8 @@ USE_NCCL = 1 # CUDA_ARCH := # whether use cuda runtime compiling for writing kernels in native language (i.e. Python) -ENABLE_CUDA_RTC = 1 - USE_NVTX=1 +USE_NVRTC = 1 # use openmp for parallelization USE_OPENMP = 1 diff --git a/make/maven/maven_linux_cu92mkl.mk b/make/maven/maven_linux_cu92mkl.mk index 6ac920a12ff0..caa1c59c01d5 100644 --- a/make/maven/maven_linux_cu92mkl.mk +++ b/make/maven/maven_linux_cu92mkl.mk @@ -79,9 +79,8 @@ USE_NCCL = 1 # CUDA_ARCH := # whether use cuda runtime compiling for writing kernels in native language (i.e. Python) -ENABLE_CUDA_RTC = 1 - USE_NVTX=1 +USE_NVRTC = 1 # use openmp for parallelization USE_OPENMP = 1 diff --git a/make/maven/maven_linux_mkl.mk b/make/maven/maven_linux_mkl.mk index 10aee5f35a46..3c8534a7e2aa 100644 --- a/make/maven/maven_linux_mkl.mk +++ b/make/maven/maven_linux_mkl.mk @@ -76,7 +76,7 @@ USE_CUDNN = 0 # CUDA_ARCH := # whether use cuda runtime compiling for writing kernels in native language (i.e. Python) -ENABLE_CUDA_RTC = 0 +USE_NVRTC = 0 # use openmp for parallelization USE_OPENMP = 1 diff --git a/python/mxnet/contrib/amp/amp.py b/python/mxnet/contrib/amp/amp.py index ef2f7209d946..746a9a7f6d68 100755 --- a/python/mxnet/contrib/amp/amp.py +++ b/python/mxnet/contrib/amp/amp.py @@ -32,6 +32,7 @@ from ... import symbol from ...context import gpu from ...symbol import Symbol +from ...module import BucketingModule from ...symbol import contrib as symbol_contrib from ... import ndarray from ...ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP @@ -672,6 +673,69 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None, ret.collect_params().load_dict(arg_dict, ctx=ctx) return ret +def convert_bucketing_module(bucketing_mod, target_dtype="float16", target_dtype_ops=None, + fp32_ops=None, conditional_fp32_ops=None, + excluded_sym_names=None, cast_optional_params=False): + """Given a bucketing module, cast the symbols associated with the BucketingModule + and the params if cast_optional_params is set. + + Parameters + ----------
+ bucketing_mod : BucketingModule instance + target_dtype : str + Currently only supports float16. The target dtype indicates where cast layers + are added so that lower precision computation can be leveraged. + target_dtype_ops : list of strs + Override the list of operator names cast to target_dtype. + If None, uses the framework's default list to be cast to target dtype. + fp32_ops : list of strs + Override the list of operator names cast to FP32. + If None, uses the framework's default list to be cast to FP32. + conditional_fp32_ops : list of (string, string, list of string) + Override the list of operators to be cast to FP32. + The format of the list is + (name of the function, name of the parameter, + list of values of the parameter that make the operator to be cast to + FP32) + excluded_sym_names : list of strs + A list of strings that represent the names of symbols that users want to exclude + from being executed in lower precision. + cast_optional_params : bool, default False + Whether to cast the arg_params and aux_params that are not strictly required + to be in FP16 (because a cast layer follows them); casting them reduces the + computation and memory overhead of the model. + """ + assert isinstance(bucketing_mod, BucketingModule), "module should be instance of bucketing module" + assert len(bucketing_mod._buckets) > 0, "Bucketing Module should not be empty" + + sym_dict = {} + assert bucketing_mod.params_initialized, \ + "bucketing_mod params should be initialized for mixed precision conversion" + arg_params, aux_params = bucketing_mod._curr_module._arg_params, bucketing_mod._curr_module._aux_params + for key, val in bucketing_mod._buckets.items(): + sym_dict[key], result_arg_params, result_aux_params = convert_model(val._symbol, + arg_params, + aux_params, + target_dtype=target_dtype, + target_dtype_ops=target_dtype_ops, + fp32_ops=fp32_ops, + conditional_fp32_ops=conditional_fp32_ops, + excluded_sym_names=excluded_sym_names, + cast_optional_params=cast_optional_params) + result_mod = BucketingModule.load_dict(sym_dict, + sym_gen=bucketing_mod._sym_gen, + arg_params=result_arg_params, + aux_params=result_aux_params, + default_bucket_key=bucketing_mod._default_bucket_key, + logger=bucketing_mod.logger, + context=bucketing_mod._context, + work_load_list=bucketing_mod._work_load_list, + fixed_param_names=bucketing_mod._fixed_param_names, + state_names=bucketing_mod._state_names, + group2ctxs=bucketing_mod._group2ctxs, + compression_params=bucketing_mod._compression_params) + return result_mod + def list_fp16_ops(): """Get the default list of FP16 ops for AMP """ diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py b/python/mxnet/contrib/onnx/mx2onnx/export_model.py index f5e4c3b69e15..158c24b08fc6 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py @@ -37,7 +37,7 @@ def export_model(sym, params, input_shape, input_type=np.float32, """Exports the MXNet model file, passed as a parameter, into ONNX model. Accepts both symbol, parameter objects as well as json and params filepaths as input.
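A usage note on `convert_bucketing_module` above: it runs `convert_model` on every cached bucket symbol and rebuilds the module through `BucketingModule.load_dict`. A minimal sketch, where `model` and `data_val` are placeholders for a trained, parameter-initialized BucketingModule and a matching iterator, mirroring example/rnn/bucketing/cudnn_rnn_bucketing.py:

```python
import mxnet as mx
from mxnet.contrib.amp import amp

# `model` and `data_val` are placeholders: a trained BucketingModule and a
# BucketSentenceIter built with invalid_label=-1, as in the bucketing example
fp16_model = amp.convert_bucketing_module(model, target_dtype="float16")
# rebind for inference and score the converted mixed precision module
fp16_model.bind(data_val.provide_data, data_val.provide_label, for_training=False)
fp16_model.score(data_val, mx.metric.Perplexity(-1))
```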
Operator support and coverage - - https://cwiki.apache.org/confluence/display/MXNET/MXNet-ONNX+Integration + https://cwiki.apache.org/confluence/display/MXNET/ONNX+Operator+Coverage Parameters ---------- diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 1ff1ee04643f..9ec66e8cf286 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -423,6 +423,22 @@ def save_checkpoint(prefix, epoch, symbol, arg_params, aux_params, remove_amp_ca logging.info('Saved checkpoint to \"%s\"', param_name) +def load_params(prefix, epoch): + """Load params from a file + """ + save_dict = nd.load("%s-%04d.params" % (prefix, epoch)) + arg_params = {} + aux_params = {} + if not save_dict: + logging.warning("Params file '%s' is empty", '%s-%04d.params' % (prefix, epoch)) + for k, v in save_dict.items(): + tp, name = k.split(":", 1) + if tp == "arg": + arg_params[name] = v + if tp == "aux": + aux_params[name] = v + return (arg_params, aux_params) + def load_checkpoint(prefix, epoch): """Load model checkpoint from file. @@ -448,22 +464,7 @@ def load_checkpoint(prefix, epoch): - Parameters will be loaded from ``prefix-epoch.params``. """ symbol = sym.load('%s-symbol.json' % prefix) - save_dict = nd.load('%s-%04d.params' % (prefix, epoch)) - arg_params = {} - aux_params = {} - #load any params in the dict, skip if params are empty - if not save_dict: - logging.warning("Params file '%s' is empty", '%s-%04d.params' % (prefix, epoch)) - else: - for k, v in save_dict.items(): - tp, name = k.split(':', 1) - if tp == 'arg': - arg_params[name] = v - elif tp == 'aux': - aux_params[name] = v - else: - logging.warning("Params file '%s' contains unknown param '%s'", - '%s-%04d.params' % (prefix, epoch), k) + arg_params, aux_params = load_params(prefix, epoch) return (symbol, arg_params, aux_params) from .callback import LogValidationMetricsCallback # pylint: disable=wrong-import-position diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py index 66c666659d0b..dcf2ad7b8e1e 100644 --- a/python/mxnet/module/bucketing_module.py +++ b/python/mxnet/module/bucketing_module.py @@ -24,13 +24,17 @@ import logging import warnings +import numpy as np from .. import context as ctx from ..initializer import Uniform +from .. import ndarray as nd +from .. import symbol as sym from .base_module import BaseModule, _check_input_names from .module import Module +from ..model import load_params from ..name import NameManager class BucketingModule(BaseModule): @@ -170,7 +174,7 @@ def get_params(self): `(arg_params, aux_params)` A pair of dictionaries each mapping parameter names to NDArray values. 
""" - assert self.binded and self.params_initialized + assert self.params_initialized self._curr_module._params_dirty = self._params_dirty params = self._curr_module.get_params() self._params_dirty = False @@ -335,12 +339,16 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, self._grad_req = grad_req symbol, data_names, label_names = self._call_sym_gen(self._default_bucket_key) - module = Module(symbol, data_names, label_names, logger=self.logger, - context=self._context, work_load_list=self._work_load_list, - fixed_param_names=self._fixed_param_names, - state_names=self._state_names, - group2ctxs=self._group2ctxs, - compression_params=self._compression_params) + module = None + if not self._default_bucket_key in self._buckets: + module = Module(symbol, data_names, label_names, logger=self.logger, + context=self._context, work_load_list=self._work_load_list, + fixed_param_names=self._fixed_param_names, + state_names=self._state_names, + group2ctxs=self._group2ctxs, + compression_params=self._compression_params) + else: + module = self._buckets[self._default_bucket_key] module.bind(data_shapes, label_shapes, for_training, inputs_need_grad, force_rebind=False, shared_module=None, grad_req=self._grad_req) self._curr_module = module @@ -380,6 +388,13 @@ def switch_bucket(self, bucket_key, data_shapes, label_shapes=None): if self._monitor is not None: module.install_monitor(self._monitor) self._buckets[bucket_key] = module + else: + module = self._buckets[bucket_key] + if not module.binded: + module.bind(data_shapes, label_shapes, self._curr_module.for_training, + self._curr_module.inputs_need_grad, + force_rebind=False, shared_module=self._buckets[self._default_bucket_key], + grad_req=self._grad_req) self._curr_module = self._buckets[bucket_key] self._curr_bucket_key = bucket_key @@ -544,3 +559,144 @@ def install_monitor(self, mon): self._monitor = mon for mod in self._buckets.values(): mod.install_monitor(mon) + + def save_checkpoint(self, prefix, epoch, remove_amp_cast=False): + """Saves current progress to checkpoint for all buckets in BucketingModule + Use `mx.callback.module_checkpoint` as `epoch_end_callback` to save during training. + + Parameters + ---------- + prefix : str + The file prefix to checkpoint to. + epoch : int + The current epoch number. + """ + + assert len(self._buckets) > 0, "Empty BucketingModule cannot be saved" + param_name = "%s-%04d.params" % (prefix, epoch) + self.save_params(param_name) + for bucket_key in self._buckets: + symbol, _, _ = self._sym_gen(bucket_key) + symbol.save("%s-%s-symbol.json" % (prefix, bucket_key), remove_amp_cast=remove_amp_cast) + nd.save("%s.buckets" % (prefix), nd.array(list(self._buckets.keys()), dtype=np.int32)) + + @staticmethod + def load(prefix, epoch, sym_gen=None, default_bucket_key=None, **kwargs): + """Creates a model from previously saved checkpoint. + + Parameters + ---------- + prefix : str + path prefix of saved model files. You should have + "prefix-symbol.json", "prefix-xxxx.params", and + optionally "prefix-xxxx.states", where xxxx is the + epoch number. + epoch : int + epoch to load. + sym_gen : function + A function when called with a bucket key, returns a triple + ``(symbol, data_names, label_names)``. + provide sym_gen which was used when saving bucketing module. + logger : Logger + Default is `logging`. + context : Context or list of Context + Default is ``cpu()``. + work_load_list : list of number + Default ``None``, indicating uniform workload. 
+ fixed_param_names: list of str + Default ``None``, indicating no network parameters are fixed. + state_names : list of str + States are similar to data and label, but not provided by data iterator. + Instead they are initialized to 0 and can be set by set_states() + group2ctxs : dict of str to context or list of context, + or list of dict of str to context + Default is `None`. Mapping the `ctx_group` attribute to the context assignment. + compression_params : dict + Specifies type of gradient compression and additional arguments depending + on the type of compression being used. For example, 2bit compression requires a threshold. + Arguments would then be {'type':'2bit', 'threshold':0.5} + See mxnet.KVStore.set_gradient_compression method for more details on gradient compression. + """ + assert sym_gen is not None, \ + "sym_gen is required for loading BucketingModule" + assert default_bucket_key is not None, \ + "default_bucket_key is required for loading BucketingModule" + buckets = nd.load("%s.buckets" % prefix) + buckets = list(buckets[0].asnumpy().astype('int32')) + bucketing_mod = BucketingModule(sym_gen, default_bucket_key, **kwargs) + for bucket_key in buckets: + _, data_names, label_names = sym_gen(bucket_key) + symbol = sym.load("%s-%s-symbol.json" % (prefix, bucket_key)) + bucketing_mod._buckets[bucket_key] = Module(symbol, data_names, label_names, **kwargs) + if bucket_key == default_bucket_key: + bucketing_mod._curr_module = bucketing_mod._buckets[bucket_key] + arg_params, aux_params = load_params(prefix, epoch) + bucketing_mod._curr_module._arg_params = arg_params + bucketing_mod._curr_module._aux_params = aux_params + bucketing_mod._curr_module.params_initialized = True + bucketing_mod.params_initialized = True + return bucketing_mod + + @staticmethod + def load_dict(sym_dict=None, sym_gen=None, default_bucket_key=None, arg_params=None, + aux_params=None, **kwargs): + """Creates a model from a dict mapping bucket_key to symbols and shared arg_params + and aux_params. + + Parameters + ---------- + sym_dict : dict mapping bucket_key to symbol + Dict mapping bucket key to symbol + sym_gen : function + A function when called with a bucket key, returns a triple + ``(symbol, data_names, label_names)``. + provide sym_gen which was used when saving bucketing module. + default_bucket_key : str (or any python object) + The key for the default bucket. + arg_params : dict + Required for loading the BucketingModule. + Dict of name to parameter ndarrays. + aux_params : dict + Required for loading the BucketingModule. + Dict of name to auxiliary state ndarrays. + logger : Logger + Default is `logging`. + context : Context or list of Context + Default is ``cpu()``. + work_load_list : list of number + Default ``None``, indicating uniform workload. + fixed_param_names: list of str + Default ``None``, indicating no network parameters are fixed. + state_names : list of str + States are similar to data and label, but not provided by data iterator. + Instead they are initialized to 0 and can be set by set_states() + group2ctxs : dict of str to context or list of context, + or list of dict of str to context + Default is `None`. Mapping the `ctx_group` attribute to the context assignment. + compression_params : dict + Specifies type of gradient compression and additional arguments depending + on the type of compression being used. For example, 2bit compression requires a threshold. 
+ Arguments would then be {'type':'2bit', 'threshold':0.5} + See mxnet.KVStore.set_gradient_compression method for more details on gradient compression. + """ + + assert sym_dict is not None, \ + "sym_dict needs to be provided for BucketingModule.load_dict" + assert arg_params is not None, \ + "arg_params need to be provided for BucketingModule.load_dict" + assert aux_params is not None, \ + "aux_params need to be provided for BucketingModule.load_dict" + assert default_bucket_key is not None, \ + "default_bucket_key needs to be provided for BucketingModule.load_dict" + + bucketing_mod = BucketingModule(sym_gen, default_bucket_key, **kwargs) + for bucket_key, loaded_sym in sym_dict.items(): + _, data_names, label_names = sym_gen(default_bucket_key) + bucketing_mod._buckets[bucket_key] = Module(loaded_sym, data_names, label_names, **kwargs) + if bucket_key == default_bucket_key: + bucketing_mod._curr_module = bucketing_mod._buckets[bucket_key] + bucketing_mod._curr_module._arg_params = arg_params + bucketing_mod._curr_module._aux_params = aux_params + bucketing_mod._curr_module.params_initialized = True + bucketing_mod.params_initialized = True + return bucketing_mod diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index c1867282e215..3ba141e94f62 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -250,7 +250,7 @@ def get_params(self): `(arg_params, aux_params)` A pair of dictionaries each mapping parameter names to NDArray values. """ - assert self.binded and self.params_initialized + assert self.params_initialized if self._params_dirty: self._sync_params_from_devices() diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index be918615bfd9..71707d41c8e8 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -20,7 +20,71 @@ from ...context import current_context from . import _internal as _npi -__all__ = ['uniform'] + +__all__ = ['randint', 'uniform'] + + +def randint(low, high=None, size=None, dtype=None, **kwargs): + """Return random integers from `low` (inclusive) to `high` (exclusive). + + Return random integers from the "discrete uniform" distribution of + the specified dtype in the "half-open" interval [`low`, `high`). If + `high` is None (the default), then results are from [0, `low`). + + Parameters + ---------- + low : int + Lowest (signed) integer to be drawn from the distribution (unless + ``high=None``, in which case this parameter is one above the + *highest* such integer). + high : int, optional + If provided, one above the largest (signed) integer to be drawn + from the distribution (see above for behavior if ``high=None``). + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + dtype : dtype, optional + Desired dtype of the result. All dtypes are determined by their + name, i.e., 'int64', 'int', etc, so byteorder is not available + and a specific precision may have different C types depending + on the platform. The default value is 'np.int'. + ctx : Context, optional + Device context of output. Default is current context. + out : ndarray, optional + The output ndarray (default is `None`). + + Returns + ------- + out : ndarray of ints + `size`-shaped array of random integers from the appropriate + distribution, or a single such random int if `size` not provided. 
+ + Examples + -------- + >>> np.random.randint(2, size=10) + array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0]) + >>> np.random.randint(1, size=10) + array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + Generate a 2 x 4 array of ints between 0 and 4, inclusive: + + >>> np.random.randint(5, size=(2, 4)) + array([[4, 0, 2, 1], + [3, 2, 2, 0]]) + """ + ctx = kwargs.pop('ctx', None) + out = kwargs.pop('out', None) + if dtype is None: + dtype = 'int' + if ctx is None: + ctx = current_context() + if size is None: + size = 1 + if high is None: + high = low + low = 0 + return _npi.random_randint(low, high, shape=size, dtype=dtype, ctx=ctx, out=out) def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py index f19c1e03202f..b0683b439c2a 100644 --- a/python/mxnet/ndarray/random.py +++ b/python/mxnet/ndarray/random.py @@ -220,8 +220,8 @@ def randn(*shape, **kwargs): dtype = kwargs.pop('dtype', _Null) ctx = kwargs.pop('ctx', None) out = kwargs.pop('out', None) - assert isinstance(loc, (int, float)) - assert isinstance(scale, (int, float)) + assert isinstance(loc, (int, float, NDArray)) + assert isinstance(scale, (int, float, NDArray)) return _random_helper(_internal._random_normal, _internal._sample_normal, [loc, scale], shape, dtype, ctx, out, kwargs) diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index f85936345b7f..f0fd43eb0e70 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -20,7 +20,60 @@ from __future__ import absolute_import from ..ndarray import numpy as _mx_nd_np -__all__ = ['uniform'] + +__all__ = ["randint", "uniform"] + + +def randint(low, high=None, size=None, dtype=None, **kwargs): + """Return random integers from `low` (inclusive) to `high` (exclusive). + + Return random integers from the "discrete uniform" distribution of + the specified dtype in the "half-open" interval [`low`, `high`). If + `high` is None (the default), then results are from [0, `low`). + + Parameters + ---------- + low : int + Lowest (signed) integer to be drawn from the distribution (unless + ``high=None``, in which case this parameter is one above the + *highest* such integer). + high : int, optional + If provided, one above the largest (signed) integer to be drawn + from the distribution (see above for behavior if ``high=None``). + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + dtype : dtype, optional + Desired dtype of the result. All dtypes are determined by their + name, i.e., 'int64', 'int', etc, so byteorder is not available + and a specific precision may have different C types depending + on the platform. The default value is 'np.int'. + ctx : Context, optional + Device context of output. Default is current context. + out : ndarray, optional + The output ndarray (default is `None`). + + Returns + ------- + out : ndarray of ints + `size`-shaped array of random integers from the appropriate + distribution, or a single such random int if `size` not provided. 
+ + Examples + -------- + >>> np.random.randint(2, size=10) + array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0]) + >>> np.random.randint(1, size=10) + array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + Generate a 2 x 4 array of ints between 0 and 4, inclusive: + + >>> np.random.randint(5, size=(2, 4)) + array([[4, 0, 2, 1], + [3, 2, 2, 0]]) + """ + return _mx_nd_np.random.randint(low, high, size, dtype, **kwargs) def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 338a5e28be4e..86d0ba3095e1 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -21,7 +21,71 @@ from ...context import current_context from . import _internal as _npi -__all__ = ['uniform'] + +__all__ = ['randint', 'uniform'] + + +def randint(low, high=None, size=None, dtype=None, **kwargs): + """Return random integers from `low` (inclusive) to `high` (exclusive). + + Return random integers from the "discrete uniform" distribution of + the specified dtype in the "half-open" interval [`low`, `high`). If + `high` is None (the default), then results are from [0, `low`). + + Parameters + ---------- + low : int + Lowest (signed) integer to be drawn from the distribution (unless + ``high=None``, in which case this parameter is one above the + *highest* such integer). + high : int, optional + If provided, one above the largest (signed) integer to be drawn + from the distribution (see above for behavior if ``high=None``). + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + dtype : dtype, optional + Desired dtype of the result. All dtypes are determined by their + name, i.e., 'int64', 'int', etc, so byteorder is not available + and a specific precision may have different C types depending + on the platform. The default value is 'np.int'. + ctx : Context, optional + Device context of output. Default is current context. + out : symbol, optional + The output symbol (default is `None`). + + Returns + ------- + out : symbol + `size`-shaped array of random integers from the appropriate + distribution, or a single such random int if `size` not provided. 
+ + Examples + -------- + >>> np.random.randint(2, size=10) + array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0]) + >>> np.random.randint(1, size=10) + array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + + Generate a 2 x 4 array of ints between 0 and 4, inclusive: + + >>> np.random.randint(5, size=(2, 4)) + array([[4, 0, 2, 1], + [3, 2, 2, 0]]) + """ + ctx = kwargs.pop('ctx', None) + out = kwargs.pop('out', None) + if dtype is None: + dtype = 'int' + if ctx is None: + ctx = current_context() + if size is None: + size = 1 + if high is None: + high = low + low = 0 + return _npi.random_randint(low, high, shape=size, dtype=dtype, ctx=ctx, out=out) def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py index 4bdfe7045625..b2ff104ff0f3 100644 --- a/python/mxnet/symbol/random.py +++ b/python/mxnet/symbol/random.py @@ -22,7 +22,7 @@ from .symbol import Symbol -__all__ = ['uniform', 'normal', 'poisson', 'exponential', 'gamma', 'multinomial', +__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma', 'multinomial', 'negative_binomial', 'generalized_negative_binomial', 'shuffle', 'randint'] @@ -113,6 +113,36 @@ def normal(loc=0, scale=1, shape=_Null, dtype=_Null, **kwargs): [loc, scale], shape, dtype, kwargs) +def randn(*shape, **kwargs): + """Draw random samples from a normal (Gaussian) distribution. + + Samples are distributed according to a normal distribution parametrized + by *loc* (mean) and *scale* (standard deviation). + + Parameters + ---------- + loc : float or Symbol, optional + Mean (centre) of the distribution. + scale : float or Symbol, optional + Standard deviation (spread or width) of the distribution. + shape : int or tuple of ints + The number of samples to draw. If shape is, e.g., `(m, n)` and `loc` and + `scale` are scalars, output shape will be `(m, n)`. If `loc` and `scale` + are Symbols with shape, e.g., `(x, y)`, then output will have shape + `(x, y, m, n)`, where `m*n` samples are drawn for each `(loc, scale)` pair. + dtype : {'float16', 'float32', 'float64'}, optional + Data type of output samples. Default is 'float32'. + """ + loc = kwargs.pop('loc', 0) + scale = kwargs.pop('scale', 1) + dtype = kwargs.pop('dtype', _Null) + assert isinstance(loc, (int, float, Symbol)) + assert isinstance(scale, (int, float, Symbol)) + return _random_helper(_internal._random_normal, _internal._sample_normal, + [loc, scale], shape, dtype, kwargs) + + def poisson(lam=1, shape=_Null, dtype=_Null, **kwargs): """Draw random samples from a Poisson distribution.
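A brief usage note on the `symbol.random.randn` added above: the sample shape is passed positionally, with `loc` and `scale` as keyword arguments. A minimal sketch (illustrative, not part of the patch):

```python
import mxnet as mx

# Symbol that draws a (2, 3) sample from N(loc=5, scale=2)
s = mx.sym.random.randn(2, 3, loc=5.0, scale=2.0)
# Random symbols take no input data, so the symbol can be evaluated directly
out = s.eval(ctx=mx.cpu())[0]
print(out.shape)  # (2, 3)
```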
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index c326091dbd9f..bb730fd3a007 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -262,6 +262,17 @@ def assign_each2(input1, input2, function): return output +# For testing Large Tensors having total size > 2^32 elements +def create_2d_tensor(rows, columns, dtype=np.int64): + a = mx.nd.arange(0, rows, dtype=dtype).reshape(rows, 1) + b = mx.nd.broadcast_to(a, shape=(a.shape[0], columns)) + return b + +# For testing Large Vectors having total size > 2^32 elements +def create_vector(size, dtype=np.int64): + a = mx.nd.arange(0, size, dtype=dtype) + return a + def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None, data_init=None, rsp_indices=None, modifier_func=None, shuffle_csr_indices=False, ctx=None): diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index 2d2bf2c64596..510ca29d7f91 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -132,14 +132,13 @@ class MKLDNNBNForward { return *var_m; } - void SetDataHandle(const NDArray &data, const mkldnn::memory *mean, + void SetDataHandle(const mkldnn::memory *data, const mkldnn::memory *mean, const mkldnn::memory *var, const mkldnn::memory *out) { - auto _data = data.GetMKLDNNData(); if (data_m) { - data_m->set_data_handle(_data->get_data_handle()); + data_m->set_data_handle(data->get_data_handle()); } else { - data_m.reset(new mkldnn::memory(_data->get_primitive_desc(), - _data->get_data_handle())); + data_m.reset(new mkldnn::memory(data->get_primitive_desc(), + data->get_data_handle())); } if (out_m) { out_m->set_data_handle(out->get_data_handle()); @@ -175,7 +174,7 @@ class MKLDNNBNForward { void SetDataHandle(const NDArray &data, const NDArray &mean, const NDArray &var, const mkldnn::memory &out) { - SetDataHandle(data, mean.GetMKLDNNData(), var.GetMKLDNNData(), &out); + SetDataHandle(data.GetMKLDNNData(), mean.GetMKLDNNData(), var.GetMKLDNNData(), &out); } const mkldnn::batch_normalization_forward &GetFwd() const { diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc index df5e48744f2d..429a80e6b186 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc @@ -40,6 +40,27 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]); const BatchNormParam ¶m = nnvm::get(attrs.parsed); const NDArray &data = in_data[quantized_batchnorm::kData]; + auto data_mem = data.GetMKLDNNData(); + + // reorder if data type = uint8 + if (in_data[quantized_batchnorm::kData].dtype() == mshadow::kUint8) { + auto u8_pd = data_mem->get_primitive_desc(); + auto u8_md = u8_pd.desc(); + mkldnn::memory::desc s8_md( + mkldnn::memory::dims(u8_md.data.dims, u8_md.data.dims + u8_md.data.ndims), + mkldnn::memory::data_type::s8, static_cast(u8_md.data.format)); + auto s8_pd = mkldnn::memory::primitive_desc(s8_md, CpuEngine::Get()->get_engine()); + auto data_reorder_mem = TmpMemMgr::Get()->Alloc(s8_pd); + + std::vector reorder_scale; + reorder_scale = {static_cast(kInt8Range) / kUint8Range}; + primitive_attr reorder_attr; + reorder_attr.set_int_output_round_mode(round_mode::round_nearest); + reorder_attr.set_output_scales(0, reorder_scale); + const auto reorder_pd = 
mkldnn::reorder::primitive_desc(u8_pd, s8_pd, reorder_attr); + MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *data_mem, *data_reorder_mem)); + data_mem = data_reorder_mem; + } const size_t channelAxis = static_cast( param.axis < 0 ? static_cast(data.shape().ndim()) + param.axis : param.axis); const int channel_count = data.shape()[channelAxis]; @@ -92,7 +113,7 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const auto out_mem = CreateMKLDNNMem(outputs[batchnorm::kOut], fwd.GetPd().dst_primitive_desc(), req[batchnorm::kOut], &data); - fwd.SetDataHandle(data, rescaled_mean_mem, rescaled_var_mem, out_mem.second); + fwd.SetDataHandle(data_mem, rescaled_mean_mem, rescaled_var_mem, out_mem.second); MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); MKLDNNStream::Get()->Submit(); diff --git a/src/operator/quantization/quantized_batch_norm.cc b/src/operator/quantization/quantized_batch_norm.cc index 3187826fe996..3c46e1b8bd5c 100644 --- a/src/operator/quantization/quantized_batch_norm.cc +++ b/src/operator/quantization/quantized_batch_norm.cc @@ -67,7 +67,13 @@ bool QuantizedBatchNormType(const nnvm::NodeAttrs& attrs, std::vector* in_t CHECK_EQ(in_type->size(), 7U); CHECK_EQ(out_type->size(), 3U); +#if MXNET_USE_MKLDNN == 1 + CHECK(in_type->at(0) == mshadow::kInt8 || in_type->at(0) == mshadow::kUint8) + << "QuantizedBatchNorm with MKLDNN backend only supports int8/uint8 input, while " + << in_type->at(0) << " is given."; +#else TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8); +#endif for (size_t i = 1; i < 7; ++i) { TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kFloat32); } diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h index 5a0b9bb21acb..e5ecadaf0c24 100644 --- a/src/operator/random/sample_multinomial_op.h +++ b/src/operator/random/sample_multinomial_op.h @@ -70,11 +70,6 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs, const mxnet::TShape& ishape = (*in_attrs)[0]; if (!ndim_is_known(ishape)) return false; - MSHADOW_TYPE_SWITCH(param.dtype, DType, { - CHECK_LE(ishape[ishape.ndim() - 1], mxnet::common::MaxIntegerValue()) - << "'dtype' does not have a sufficient precision to represent the indices of the input array."; - }); - if (ishape.ndim() == 1) { if (param.shape.ndim() > 0) { SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape); @@ -121,7 +116,7 @@ inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs, struct SampleMultinomialKernel { template - MSHADOW_XINLINE static void Map(int i, index_t K, index_t M, + MSHADOW_XINLINE static void Map(index_t i, index_t K, index_t M, DType* dist, float* uniform, float* cum_table, IType* out, DType* prob) { double acc = 0.0; diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc index 543146257ddf..d5b5e288258c 100644 --- a/src/operator/random/sample_op.cc +++ b/src/operator/random/sample_op.cc @@ -182,6 +182,7 @@ Example:: MXNET_OPERATOR_REGISTER_SAMPLE(_random_randint, SampleRandIntParam) .add_alias("random_randint") +.add_alias("_npi_random_randint") .describe(R"code(Draw random samples from a discrete uniform distribution. 
Samples are uniformly distributed over the half-open interval *[low, high)* diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h index 80ab40ef6c50..db8676c028e4 100644 --- a/src/operator/softmax_output-inl.h +++ b/src/operator/softmax_output-inl.h @@ -117,9 +117,9 @@ class SoftmaxOutputOp : public Operator { CHECK_EQ(out_data.size(), 1U) << "SoftmaxOutput Output: [output]"; Stream *s = ctx.get_stream(); if (param_.multi_output) { - int n = in_data[softmaxout_enum::kData].size(0); - int k = in_data[softmaxout_enum::kData].size(1); - Shape<3> s3 = Shape3(n, k, static_cast(in_data[softmaxout_enum::kData].Size()/n/k)); + index_t n = in_data[softmaxout_enum::kData].size(0); + index_t k = in_data[softmaxout_enum::kData].size(1); + Shape<3> s3 = Shape3(n, k, static_cast(in_data[softmaxout_enum::kData].Size()/n/k)); Tensor data = in_data[softmaxout_enum::kData].get_with_shape(s3, s); Tensor out = @@ -131,8 +131,8 @@ class SoftmaxOutputOp : public Operator { Tensor out = out_data[softmaxout_enum::kOut].FlatTo2D(s); Softmax(out, data); } else { - int n = in_data[softmaxout_enum::kData].size(0); - int k = in_data[softmaxout_enum::kData].Size()/n; + index_t n = in_data[softmaxout_enum::kData].size(0); + index_t k = in_data[softmaxout_enum::kData].Size()/n; Shape<2> s2 = Shape2(n, k); Tensor data = in_data[softmaxout_enum::kData].get_with_shape(s2, s); @@ -171,9 +171,9 @@ class SoftmaxOutputOp : public Operator { grad = (out - label) * scalar(param_.grad_scale); } } else if (param_.multi_output) { - int n = out_data[softmaxout_enum::kOut].size(0); - int k = out_data[softmaxout_enum::kOut].size(1); - Shape<3> s3 = Shape3(n, k, static_cast(out_data[softmaxout_enum::kOut].Size()/n/k)); + index_t n = out_data[softmaxout_enum::kOut].size(0); + index_t k = out_data[softmaxout_enum::kOut].size(1); + Shape<3> s3 = Shape3(n, k, static_cast(out_data[softmaxout_enum::kOut].Size()/n/k)); Shape<2> s2 = Shape2(s3[0], s3[2]); Tensor label = in_data[softmaxout_enum::kLabel].get_with_shape(s2, s); @@ -224,7 +224,7 @@ class SoftmaxOutputOp : public Operator { // Tensor out = out_data[softmaxout_enum::kOut].FlatTo2D(s); // Tensor grad = in_grad[softmaxout_enum::kData].FlatTo2D(s); } else { - int n = out_data[softmaxout_enum::kOut].size(0); + index_t n = out_data[softmaxout_enum::kOut].size(0); data_shape = Shape2(n, out_data[softmaxout_enum::kOut].Size()/n); } Tensor label = in_data[softmaxout_enum::kLabel].get_with_shape( diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index b7776d648e18..be2716f139ed 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -426,15 +426,12 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, MKLDNNStream::Get()->Submit(); } else { std::vector new_inputs; - std::vector new_req; if (has_bias) { new_inputs = {data, cached_weight_, cached_bias_}; - new_req = {req[in_data], req[in_weight], req[in_bias]}; } else { new_inputs = {data, cached_weight_}; - new_req = {req[in_data], req[in_weight]}; } - MKLDNNConvolutionForwardFullFeature(full_conv_param, ctx, fwd_.get(), new_inputs, new_req, + MKLDNNConvolutionForwardFullFeature(full_conv_param, ctx, fwd_.get(), new_inputs, req, {output}); } diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc index 22432193dcf9..6ebbb79d2571 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc @@ -188,16 +188,13 @@ void 
SgMKLDNNFCOp::Forward(const OpContext &ctx, initialized_ = true; } std::vector new_inputs; - std::vector new_req; if (has_bias) { new_inputs = {data, weight, cached_bias_}; - new_req = {req[fullc::kData], req[fullc::kWeight], req[fullc::kBias]}; } else { new_inputs = {data, weight}; - new_req = {req[fullc::kData], req[fullc::kWeight]}; } - MKLDNNFCForwardFullFeature(full_param_, ctx, fwd_.get(), new_inputs, new_req, out_data); + MKLDNNFCForwardFullFeature(full_param_, ctx, fwd_.get(), new_inputs, req, out_data); if (mkldnn_param.quantized && !mkldnn_param.enable_float_output) { float *min_output_ptr = out_data[quantized_fullc::kOutMin].data().dptr(); diff --git a/src/operator/tensor/init_op.cc b/src/operator/tensor/init_op.cc index 172665fbbf12..4e8900be24ca 100644 --- a/src/operator/tensor/init_op.cc +++ b/src/operator/tensor/init_op.cc @@ -129,7 +129,7 @@ Examples:: .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("FInferShape", RangeLikeShape) -.set_attr("FInferType", InitType) +.set_attr("FInferType", ElemwiseType<1, 1>) .set_attr("FIgnoreInputs", [](const NodeAttrs& attrs) { return std::vector(1, 0); }) .set_attr("FCompute", RangeCompute) diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 51c84363489a..f3c405d7103c 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -179,7 +179,6 @@ struct RangeLikeParam : public dmlc::Parameter { double step; int repeat; std::string ctx; - int dtype; dmlc::optional axis; DMLC_DECLARE_PARAMETER(RangeLikeParam) { @@ -197,9 +196,6 @@ struct RangeLikeParam : public dmlc::Parameter { .set_default("") .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." "Only used for imperative calls."); - DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32) - MXNET_ADD_ALL_TYPES - .describe("Target data type."); DMLC_DECLARE_FIELD(axis) .set_default(dmlc::optional()) .describe("Arange elements according to the size of a certain axis of input array." 
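Note on the `arange_like` change above: with `FInferType` switched to `ElemwiseType<1, 1>` and the `dtype` field removed from `RangeLikeParam`, the output dtype is now inferred from the input. A minimal check mirroring the GPU test added later in this patch (type inference only, so it runs on any context):

```python
import numpy as np
import mxnet as mx

x = mx.sym.Variable('x', dtype=np.float16)
y = mx.sym.contrib.arange_like(x, axis=-1)
# type inference now propagates float16 from `x` to the output
_, out_types, _ = y.infer_type(x=np.float16)
assert out_types[0] == np.float16
```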
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 611dd7287206..58a535353e10 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -732,8 +732,8 @@ inline void GetIndexRange(const mxnet::TShape& dshape, } inline void SetSliceOpOutputDimSize(const mxnet::TShape& dshape, - const index_t i, const int b, - const int e, const int s, + const index_t i, const index_t b, + const index_t e, const index_t s, mxnet::TShape* oshape) { if (!mxnet::dim_size_is_known(dshape, i)) { (*oshape)[i] = -1; @@ -765,7 +765,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, common::StaticArray begin, end, step; GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step); for (int i = 0; i < param.begin.ndim(); ++i) { - const int b = begin[i], e = end[i], s = step[i]; + const index_t b = begin[i], e = end[i], s = step[i]; SetSliceOpOutputDimSize(dshape, i, b, e, s, &oshape); } }) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 84ac94ed8921..32d2f9895c2d 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -19,7 +19,7 @@ import numpy as np import mxnet as mx -from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward +from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor from mxnet import gluon, nd from tests.python.unittest.common import with_seed @@ -31,12 +31,6 @@ LARGE_SIZE = LARGE_X * SMALL_Y -def create_2d_tensor(rows, columns, dtype=np.int64): - a = nd.arange(0, rows, dtype=dtype).reshape(rows, 1) - b = nd.broadcast_to(a, shape=(a.shape[0], columns)) - return nd.array(b, dtype=dtype) - - def test_gluon_embedding(): m = gluon.nn.Embedding(SMALL_Y, MEDIUM_X) m.initialize() diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index 50fac80d680d..d3069bb06866 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -19,7 +19,7 @@ import numpy as np import mxnet as mx -from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context +from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, create_vector from mxnet import gluon, nd from tests.python.unittest.common import with_seed @@ -38,9 +38,144 @@ def create_large_vector(size, dtype="int64"): def test_slice(): a = nd.ones(LARGE_X) res = nd.slice(a, begin=(LARGE_X - MEDIUM_X), end=LARGE_X) + assert a[0] == 1 assert res.shape[0] == MEDIUM_X +def test_ndarray_zeros(): + a = nd.zeros(shape=LARGE_X) + assert a[-1] == 0 + assert a.shape == (LARGE_X,) + assert a.size == LARGE_X + + +def test_ndarray_ones(): + a = nd.ones(shape=LARGE_X) + assert a[-1] == 1 + assert nd.sum(a).asnumpy() == LARGE_X + + +@with_seed() +def test_ndarray_random_uniform(): + a = nd.random.uniform(shape=LARGE_X) + assert a[-1] != 0 + + +@with_seed() +def test_ndarray_random_randint(): + a = nd.random.randint(100, 10000, shape=LARGE_X) + assert a.shape == (LARGE_X,) + # check if randint can generate value greater than 2**32 (large) + low_large_value = 2**32 + high_large_value = 2**34 + a = nd.random.randint(low_large_value, high_large_value, dtype=np.int64) + low = mx.nd.array([low_large_value], dtype='int64') + high = mx.nd.array([high_large_value], dtype='int64') + assert a > low and a < high + + +def test_ndarray_empty(): + a = nd.empty(LARGE_X) + 
assert a.shape == (LARGE_X,) + + +def test_elementwise(): + a = nd.ones(shape=LARGE_X) + b = nd.ones(shape=LARGE_X) + res = a + b + assert res[-1].asnumpy() == 2 + res = a + 1 + assert res[-1].asnumpy() == 2 + res = nd.sqrt(a + 8) + assert res[-1].asnumpy() == 3 + + +def test_reduce(): + a = nd.ones(shape=(LARGE_X, 1)) + assert nd.sum(a).asnumpy() == a.shape[0] * a.shape[1] + + +def test_clip(): + a = create_vector(LARGE_X) + res = nd.clip(a, a_min=100, a_max=1000) + assert np.sum(res[-1].asnumpy() == 1000) == 1 + + +def test_argmin(): + a = create_vector(LARGE_X, dtype=np.float32) + assert a[0] == 0 + idx = mx.nd.argmin(a, axis=0) + assert idx[0] == 0 + assert idx.shape[0] == 1 + + +def test_take(): + a = nd.ones(shape=LARGE_X) + idx = nd.arange(LARGE_X - 1000, LARGE_X) + res = nd.take(a, idx) + assert np.sum(res.asnumpy() == 1) == res.shape[0] + + +def test_slice_assign(): + a = nd.ones(shape=LARGE_X) + a[LARGE_X-1:LARGE_X] = 1000 + assert np.sum(a[-1].asnumpy() == 1000) == 1 + + +def test_expand_dims(): + a = nd.ones(shape=LARGE_X) + res = nd.expand_dims(a, axis=0) + assert res[0][0] == 1 + assert res.shape == (1, a.shape[0]) + + +def test_squeeze(): + a = nd.ones(shape=LARGE_X) + data = nd.expand_dims(a, axis=0) + res = nd.squeeze(data) + assert a[0] == res[0] + assert res.shape == a.shape + + +def test_broadcast_div(): + a = nd.ones(shape=LARGE_X) + b = nd.ones(shape=LARGE_X) * 2 + res = a / b + assert np.sum(res.asnumpy() == 0.5) == a.shape[0] + + +def test_Dense(ctx=mx.cpu(0)): + data = mx.nd.ones(shape=LARGE_X) + linear = gluon.nn.Dense(2) + linear.initialize(ctx=ctx) + res = linear(data) + res.wait_to_read() + assert res.shape == (LARGE_X, 2) + + +def test_argsort(): + b = create_vector(size=LARGE_X) + s = nd.argsort(b, axis=0, is_ascend=False, dtype=np.int64) + mx.nd.waitall() + assert (s[0].asnumpy() == (LARGE_X - 1)).all() + + +def test_sort(): + b = create_vector(size=LARGE_X) + s = nd.sort(b, axis=0, is_ascend=False) + assert np.sum(s[-1].asnumpy() == 0).all() + s = nd.sort(b, is_ascend=True) + assert np.sum(s[0].asnumpy() == 0).all() + + +def test_topk(): + b = create_vector(size=LARGE_X) + ind = nd.topk(b, k=10, axis=0, dtype=np.int64) + assert np.sum(ind.asnumpy() == (LARGE_X - 1)) == 1 + ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", is_ascend=False) + assert np.all(ind == val) + val = nd.topk(b, k=1, axis=0, dtype=np.int64, ret_typ="value") + assert val.sum() == (LARGE_X - 1) @with_seed() def test_ndarray_random_exponential(): a = nd.random.exponential(shape=LARGE_X) diff --git a/tests/python/gpu/test_contrib_amp.py b/tests/python/gpu/test_contrib_amp.py index 3daab0f7bb6a..401bfcad3494 100644 --- a/tests/python/gpu/test_contrib_amp.py +++ b/tests/python/gpu/test_contrib_amp.py @@ -19,6 +19,7 @@ import sys import mxnet as mx import numpy as np +from random import randint import warnings import collections import ctypes @@ -31,6 +32,8 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) from common import with_seed, teardown, assert_raises_cudnn_not_satisfied +sys.path.insert(0, os.path.join(curr_path, '../train')) +from test_bucketing import train_model set_default_context(mx.gpu(0)) def test_amp_coverage(): @@ -301,10 +304,42 @@ def check_amp_convert_hybrid_block(): params = converted_model.collect_params() assert params["stage2_unit1_conv2_weight"].dtype == np.float16 + + def check_amp_convert_bucketing_module(): + model = train_model(context=mx.current_context()) 
diff --git a/tests/python/gpu/test_contrib_amp.py b/tests/python/gpu/test_contrib_amp.py
index 3daab0f7bb6a..401bfcad3494 100644
--- a/tests/python/gpu/test_contrib_amp.py
+++ b/tests/python/gpu/test_contrib_amp.py
@@ -19,6 +19,7 @@
 import sys
 import mxnet as mx
 import numpy as np
+from random import randint
 import warnings
 import collections
 import ctypes
@@ -31,6 +32,8 @@
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
 from common import with_seed, teardown, assert_raises_cudnn_not_satisfied
+sys.path.insert(0, os.path.join(curr_path, '../train'))
+from test_bucketing import train_model
 
 set_default_context(mx.gpu(0))
 
 def test_amp_coverage():
@@ -301,10 +304,42 @@ def check_amp_convert_hybrid_block():
         params = converted_model.collect_params()
         assert params["stage2_unit1_conv2_weight"].dtype == np.float16
 
+
+    def check_amp_convert_bucketing_module():
+        model = train_model(context=mx.current_context())
+        result_model = amp.convert_bucketing_module(model)
+        val_sent = []
+        batch_size = 128
+        invalid_label = -1
+        num_sentence = 1000
+        buckets = [5, 10, 20, 30, 40]
+        len_vocab = 50
+
+        for _ in range(num_sentence):
+            len_sentence = randint(6, max(buckets)-1)  # leave the last two buckets empty
+            val_sentence = []
+            for _ in range(len_sentence):
+                val_sentence.append(randint(1, len_vocab))
+            val_sent.append(val_sentence)
+
+        data_val = mx.rnn.BucketSentenceIter(val_sent, batch_size, buckets=buckets,
+                                             invalid_label=invalid_label)
+        result_model.bind(data_val.provide_data, data_val.provide_label, for_training=False)
+        result_model.score(data_val, mx.metric.Perplexity(invalid_label),
+                           batch_end_callback=mx.callback.Speedometer(batch_size, 1))
+
+        # AMP conversion with cast_optional_params set to True
+        result_model = amp.convert_bucketing_module(model, cast_optional_params=True)
+        result_model.bind(data_val.provide_data, data_val.provide_label, for_training=False)
+        result_model.score(data_val, mx.metric.Perplexity(invalid_label),
+                           batch_end_callback=mx.callback.Speedometer(batch_size, 1))
+
+
     with mx.Context(mx.gpu(0)):
         check_amp_convert_symbol()
         check_amp_convert_model()
         check_amp_convert_hybrid_block()
+        check_amp_convert_bucketing_module()
 
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index f8d8b4496afc..db550e4254b7 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -2331,6 +2331,22 @@ def test_math():
             for op in ops:
                 run_math(op, shape, dtype, check_value=check_value)
 
+@with_seed()
+def test_arange_like_dtype():
+    dtypes = [np.float16, np.float32, np.float64]
+
+    for t in dtypes:
+        x = mx.sym.Variable('x', dtype=t)
+        y = mx.sym.reshape(x, shape=(0, 0, -1))
+        z = mx.sym.contrib.arange_like(y, axis=-1)
+
+        mod = z.simple_bind(ctx=mx.gpu(0), x=(3, 4, 5, 6), grad_req='null')
+        mod.arg_arrays[0][:] = np.random.normal(size=mod.arg_arrays[0].shape).astype(t)
+        out = mod.forward(is_train=False)
+        for v in out:
+            assert v.dtype == t
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
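The new GPU test pins down that contrib.arange_like propagates its input dtype instead of defaulting to float32. With the fix this diff targets, the same behavior should be visible through the NDArray API; a small sketch (the shapes are arbitrary):

    import mxnet as mx
    import numpy as np

    x = mx.nd.ones((3, 4, 5, 6), dtype=np.float16)
    # collapse the trailing axes, then generate a 1-D arange along the last axis
    y = mx.nd.contrib.arange_like(x.reshape(0, 0, -1), axis=-1)
    assert y.shape == (30,) and y.dtype == np.float16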
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index a1c23fb23208..31bc1638b010 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -607,19 +607,21 @@ def get_mean_var(data):
         return mean, var
 
     def check_quantized_bn(data_shape, qdtype):
-        if qdtype == 'uint8':
-            print('skipped testing quantize_bn for uint8 since it is not supported yet')
-            return
-        elif is_test_for_native_cpu():
+        if is_test_for_native_cpu():
             print('skipped testing quantize_bn for native cpu since it is not supported yet')
             return
         elif is_test_for_gpu():
             print('skipped testing quantize_bn for gpu since it is not supported yet')
             return
 
-        # qdtype = int8
-        data_low = -127.0
-        data_high = 127.0
+        # input data range depends on qdtype
+        if qdtype == 'uint8':
+            data_low = 0.0
+            data_high = 127.0
+        else:
+            data_low = -127.0
+            data_high = 127.0
+
+        # output type is always int8
         quantized_range = 127.0
         # run fp32 bn
         data_sym = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
@@ -639,9 +641,6 @@
         bn_fp32_exe.arg_dict[arg_names[2]][:] = beta
         bn_fp32_exe.aux_dict[aux_names[0]][:] = moving_mean
         bn_fp32_exe.aux_dict[aux_names[1]][:] = moving_var
-        min_data = mx.nd.min(data)
-        max_data = mx.nd.max(data)
-        data_range = mx.nd.maximum(mx.nd.abs(min_data), mx.nd.abs(max_data))
 
         output= bn_fp32_exe.forward()[0]
 
@@ -654,11 +653,12 @@
         calib_data = NDArrayIter(data=data, batch_size=data_shape[0])
         calib_data = DummyIter(calib_data)
+        # quantize bn with quantized_dtype='int8': MKLDNN BN only supports int8 output
         qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=bn_fp32,
                                                                          arg_params=arg_params,
                                                                          aux_params=bn_fp32_exe.aux_dict,
                                                                          ctx=mx.current_context(),
-                                                                         quantized_dtype=qdtype,
+                                                                         quantized_dtype='int8',
                                                                          calib_mode='naive',
                                                                          calib_data=calib_data,
                                                                          num_calib_examples=20)
@@ -668,13 +668,14 @@
         mod.set_params(qarg_params, qaux_params)
         batch = mx.io.DataBatch([data], [])
         mod.forward(batch, is_train=False)
-        output_int8_to_fp32= mod.get_outputs()[0]
+        output_int8_to_fp32 = mod.get_outputs()[0]
 
-        assert_almost_equal(output.asnumpy(), output_int8_to_fp32.asnumpy(), rtol=1e-1, atol=3)
+        assert_almost_equal(output.asnumpy(), output_int8_to_fp32.asnumpy(), rtol=1e-1, atol=4)
 
-    check_quantized_bn((32, 512, 4, 4), 'int8')
-    check_quantized_bn((32, 1024, 8, 8), 'int8')
-    check_quantized_bn((32, 3, 224, 224), 'int8')
+    for qdtype in ['int8', 'uint8']:
+        check_quantized_bn((32, 512, 4, 4), qdtype)
+        check_quantized_bn((32, 1024, 8, 8), qdtype)
+        check_quantized_bn((32, 3, 224, 224), qdtype)
 
 @with_seed()
 def test_quantize_params():
@@ -918,16 +919,12 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape=N
             lshape_list.append(None)
 
         for s, dshape, lshape, name in zip(sym_list, dshape_list, lshape_list, name_list):
-            if qdtype == 'int8' and is_test_for_mkldnn() and name in ['sym1', 'sym2', 'sym3']:
-                print('skipped testing test_quantize_model_with_forward for mkldnn cpu int8 since it is not supported yet')
-                continue
-            elif qdtype == 'uint8' and is_test_for_mkldnn() and name in ['sym1']:
-                print('skipping test_quantize_model_with_forward for mkldnn cpu uint8 since it is not supported yet')
-                continue
-            elif qdtype == 'int8' and is_test_for_gpu() and name in ['sym1']:
-                print('skipped testing test_quantize_model_with_forward for gpu int8 since it is not supported yet')
-                continue
-
+            if qdtype == 'int8' and name in ['sym1', 'sym2', 'sym3']:
+                print('mkldnn_quantized_conv op only supports uint8 as input type, skip test with int8.')
+                continue
+            if qdtype == 'uint8' and name in ['sym1']:
+                print('mkldnn_quantized_bn doesn\'t support calib_mode=None, skip test with uint8.')
+                continue
             if lshape is None:
                 mod = Module(symbol=s, label_names=None)
                 mod.bind(for_training=False,
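The branch added to check_quantized_bn encodes the supported ranges: uint8 input data is drawn from [0, 127], int8 data from [-127, 127], and the quantized output always uses the symmetric int8 range of 127 because the MKLDNN batch-norm kernel only emits int8. A compact restatement of that mapping (input_range is a hypothetical helper, written here only to make the intent explicit):

    import numpy as np

    def input_range(qdtype):
        # mirrors the if/else added in check_quantized_bn above
        return (0.0, 127.0) if qdtype == 'uint8' else (-127.0, 127.0)

    for qdtype in ('uint8', 'int8'):
        low, high = input_range(qdtype)
        data = np.random.uniform(low, high, size=(32, 3, 8, 8)).astype('float32')
        assert low <= data.min() and data.max() <= high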
diff --git a/tests/python/train/test_bucketing.py b/tests/python/train/test_bucketing.py
index 882c4a4a513d..a233e46e0992 100644
--- a/tests/python/train/test_bucketing.py
+++ b/tests/python/train/test_bucketing.py
@@ -20,9 +20,32 @@
 import mxnet as mx
 import random
 from random import randint
+from mxnet.contrib.amp import amp
 
 
-def test_bucket_module():
+def prepare_bucketing_data(buckets, len_vocab, batch_size, invalid_label, num_sentence):
+    train_sent = []
+    val_sent = []
+
+    for _ in range(num_sentence):
+        len_sentence = randint(6, max(buckets)-1)  # leave the last two buckets empty
+        train_sentence = []
+        val_sentence = []
+        for _ in range(len_sentence):
+            train_sentence.append(randint(1, len_vocab))
+            val_sentence.append(randint(1, len_vocab))
+        train_sent.append(train_sentence)
+        val_sent.append(val_sentence)
+
+    data_train = mx.rnn.BucketSentenceIter(train_sent, batch_size, buckets=buckets,
+                                           invalid_label=invalid_label)
+    data_val = mx.rnn.BucketSentenceIter(val_sent, batch_size, buckets=buckets,
+                                         invalid_label=invalid_label)
+
+    return (data_train, data_val)
+
+
+def train_model(context=mx.cpu()):
     import logging
     head = '%(asctime)-15s %(message)s'
     logging.basicConfig(level=logging.DEBUG, format=head)
@@ -41,23 +64,7 @@
     invalid_label = -1
     num_sentence = 1000
 
-    train_sent = []
-    val_sent = []
-
-    for _ in range(num_sentence):
-        len_sentence = randint(6, max(buckets)-1) # leave out the two last buckets empty
-        train_sentence = []
-        val_sentence = []
-        for _ in range(len_sentence):
-            train_sentence.append(randint(1, len_vocab))
-            val_sentence.append(randint(1, len_vocab))
-        train_sent.append(train_sentence)
-        val_sent.append(val_sentence)
-
-    data_train = mx.rnn.BucketSentenceIter(train_sent, batch_size, buckets=buckets,
-                                           invalid_label=invalid_label)
-    data_val = mx.rnn.BucketSentenceIter(val_sent, batch_size, buckets=buckets,
-                                         invalid_label=invalid_label)
+    data_train, data_val = prepare_bucketing_data(buckets, len_vocab, batch_size, invalid_label, num_sentence)
 
     stack = mx.rnn.SequentialRNNCell()
     for i in range(num_layers):
@@ -80,7 +87,7 @@ def sym_gen(seq_len):
 
         return loss, ('data',), ('softmax_label',)
 
-    contexts = mx.cpu(0)
+    contexts = context
 
     model = mx.mod.BucketingModule(
         sym_gen=sym_gen,
@@ -101,9 +108,14 @@ def sym_gen(seq_len):
         num_epoch=num_epochs,
         batch_end_callback=mx.callback.Speedometer(batch_size, 50))
     logging.info('Finished fit...')
+    return model
+
+
+def test_bucket_module():
     # This test forecasts random sequence of words to check bucketing.
     # We cannot guarantee the accuracy of such an impossible task, and comments out the following line.
     # assert model.score(data_val, mx.metric.MSE())[0][1] < 350, "High mean square error."
+    model = train_model()
 
 
 if __name__ == "__main__":
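With the data preparation and training loop factored out, other test modules can import the exact same setup instead of duplicating it; the AMP test above and the save/load test below do exactly that. A sketch of the intended reuse (assuming tests/python/train is on sys.path, as those tests arrange):

    import mxnet as mx
    from test_bucketing import prepare_bucketing_data, train_model

    model = train_model(context=mx.cpu())  # fits the bucketing LSTM once
    data_train, data_val = prepare_bucketing_data(
        buckets=[5, 10, 20, 30, 40], len_vocab=50, batch_size=128,
        invalid_label=-1, num_sentence=1000)
    model.score(data_val, mx.metric.Perplexity(-1))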
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index c82afdfe033a..b82933126d67 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import os
 import mxnet as mx
 import mxnet.ndarray as nd
 from mxnet.test_utils import *
@@ -23,6 +24,9 @@
 from mxnet.module.executor_group import DataParallelExecutorGroup
 from common import setup_module, with_seed, assertRaises, teardown
 from collections import namedtuple
+curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+sys.path.insert(0, os.path.join(curr_path, "../train"))
+from test_bucketing import train_model, prepare_bucketing_data
 
 
 @with_seed()
@@ -216,6 +220,73 @@ def dict_equ(a, b):
     os.putenv('MXNET_UPDATE_ON_KVSTORE', previous_update_on_kvstore)
 
 
+@with_seed()
+def test_bucketing_save_load():
+    previous_update_on_kvstore = os.getenv('MXNET_UPDATE_ON_KVSTORE', "1")
+    os.putenv('MXNET_UPDATE_ON_KVSTORE', '1')
+
+    def dict_equ(a, b):
+        assert set(a) == set(b)
+        for k in a:
+            assert (a[k].asnumpy() == b[k].asnumpy()).all()
+
+    len_vocab = 50
+    num_embed = 25
+    num_epochs = 5
+    batch_size = 128
+    num_layers = 2
+    num_hidden = 25
+    buckets = [5, 10, 20, 30, 40]
+    invalid_label = -1
+    num_sentence = 1000
+
+    stack = mx.rnn.SequentialRNNCell()
+    for i in range(num_layers):
+        stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_' % i))
+
+    def sym_gen(seq_len):
+        data = mx.sym.Variable('data')
+        label = mx.sym.Variable('softmax_label')
+        embed = mx.sym.Embedding(data=data, input_dim=len_vocab,
+                                 output_dim=num_embed, name='embed')
+        stack.reset()
+        outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)
+
+        pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden))
+        pred = mx.sym.FullyConnected(data=pred, num_hidden=len_vocab, name='pred')
+
+        label = mx.sym.Reshape(label, shape=(-1,))
+        loss = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')
+
+        return loss, ('data',), ('softmax_label',)
+
+    model = train_model(context=mx.current_context())
+    model.save_checkpoint("test", 0)
+    data_train, data_val = prepare_bucketing_data(buckets, len_vocab, batch_size, invalid_label, num_sentence)
+    mod2 = mx.mod.BucketingModule.load('test', 0, sym_gen=sym_gen,
+                                       default_bucket_key=data_train.default_bucket_key)
+
+    mod2.bind(data_shapes=data_train.provide_data,
+              label_shapes=data_train.provide_label)
+
+    # the loaded module must reproduce the trained parameters exactly
+    dict_equ(model._buckets[model._default_bucket_key].get_params()[0],
+             mod2._buckets[mod2._default_bucket_key].get_params()[0])
+
+    mod2.fit(
+        train_data=data_train,
+        eval_data=data_val,
+        eval_metric=mx.metric.Perplexity(invalid_label),  # use Perplexity for multiclass classification
+        kvstore='device',
+        optimizer='sgd',
+        optimizer_params={'learning_rate': 0.01,
+                          'momentum': 0,
+                          'wd': 0.00001},
+        initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
+        num_epoch=num_epochs,
+        batch_end_callback=mx.callback.Speedometer(batch_size, 50))
+    os.putenv('MXNET_UPDATE_ON_KVSTORE', previous_update_on_kvstore)
+
+
 @with_seed()
 def test_module_reshape():
     data = mx.sym.Variable('data')
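The test also illustrates the environment-variable discipline used throughout test_module.py: snapshot MXNET_UPDATE_ON_KVSTORE, force it to '1' for the duration of the test, and restore the previous value at the end. The same pattern in isolation, with a try/finally added so the restore also happens if an assertion fails (the test above restores unconditionally on the straight-line path):

    import os

    previous = os.getenv('MXNET_UPDATE_ON_KVSTORE', '1')  # snapshot the current setting
    os.putenv('MXNET_UPDATE_ON_KVSTORE', '1')             # force updates on the kvstore
    try:
        pass  # ... body of the test ...
    finally:
        os.putenv('MXNET_UPDATE_ON_KVSTORE', previous)    # restore the snapshot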
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 399cdead6177..0a1a4fb2b9b1 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -28,6 +28,8 @@
 from common import assertRaises, with_seed
 import random
 import collections
+import scipy.stats as ss
+from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry
 
 
 @with_seed()
@@ -1080,6 +1082,50 @@ def hybrid_forward(self, F, a, *args):
         assert same(mx_out.asnumpy(), np_out)
 
 
+@with_seed()
+@use_np
+def test_np_randint():
+    ctx = mx.context.current_context()
+    # test shapes
+    params = [
+        (0, 10),
+        (5, None)
+    ]
+    shapes = [
+        (3, 3),
+        (3, 4),
+        (0, 0),
+        (3, 3, 3),
+        (0, 0, 0),
+        (2, 2, 4, 3),
+        (2, 2, 4, 3),
+        (2, 0, 3, 0),
+        (2, 0, 2, 3)
+    ]
+    for shape in shapes:
+        for (low, high) in params:
+            data_mx = np.random.randint(low, high, size=shape)
+            assert data_mx.shape == shape
+
+    # test generator
+    for dtype in ['int32', 'int64']:
+        for low, high in [(50000000, 50001000), (-50000100, -50000000), (-500, 199)]:
+            scale = high - low
+            buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=scale), 5)
+            # quantize bucket boundaries to reflect the actual dtype and adjust probs accordingly
+            buckets = _np.array(buckets, dtype=dtype).tolist()
+            probs = [(buckets[i][1] - buckets[i][0]) / float(scale) for i in range(5)]
+            generator_mx = lambda x: np.random.randint(low, high, size=x, dtype=dtype, ctx=ctx).asnumpy()
+            verify_generator(generator=generator_mx, buckets=buckets, probs=probs, nrepeat=100)
+            # SciPy uses alpha = 0.01 when testing discrete distribution generators, but we keep the
+            # default alpha = 0.05 since the higher threshold makes the check more robust.
+            # See https://github.com/scipy/scipy/blob/9f12af697763fb5f9767d5cb1280ce62456a3974/scipy/stats/tests/test_discrete_basic.py#L45
+            generator_mx_same_seed = \
+                lambda x: _np.concatenate(
+                    [np.random.randint(low, high, size=x // 10, dtype=dtype, ctx=ctx).asnumpy()
+                     for _ in range(10)])
+            verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs, nrepeat=100)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
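The check that test_np_randint leans on is essentially a bucketed frequency test: split the target CDF into equal-probability intervals via the inverse CDF (ppf), then compare empirical bucket frequencies against the expected probabilities. A self-contained sketch of that idea (not the actual verify_generator implementation):

    import numpy as np
    import scipy.stats as ss

    low, high = -500, 199
    scale = high - low
    # five equal-probability buckets from the uniform ppf (inverse CDF)
    edges = ss.uniform.ppf(np.linspace(0.0, 1.0, 6), loc=low, scale=scale)
    buckets = list(zip(edges[:-1], edges[1:]))
    probs = [0.2] * 5

    samples = np.random.randint(low, high, size=100000)
    freqs = [np.mean((samples >= lo) & (samples < hi)) for lo, hi in buckets]
    # with 100k samples the per-bucket deviation should be well under 1%
    assert max(abs(f - p) for f, p in zip(freqs, probs)) < 0.01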
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 4de7d16e6f61..ceee51a3e503 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6851,6 +6851,7 @@ def test_laop_5():
 
 # Tests for linalg.inverse
 @with_seed()
+@unittest.skip("Test crashes https://github.com/apache/incubator-mxnet/issues/15975")
 def test_laop_6():
     dtype = np.float64
     rtol_fw = 1e-7
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index 720c25d2711e..fe276685bfe3 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -61,8 +61,10 @@ def check_with_device(device, dtype):
         },
         {
             'name': 'randn',
+            'symbol': mx.sym.random.randn,
             'ndop': mx.nd.random.randn,
             'params': { 'loc': 10.0, 'scale': 0.5 },
+            'inputs': [ ('loc', [ [ 0.0, 2.5 ], [ -9.75, -7.0 ] ]), ('scale', [ [ 1.0, 3.7 ], [ 4.2, 1.5 ] ]) ],
             'checks': [
                 ('mean', lambda x, params: np.mean(x.astype(np.float64) - params['loc']), tol),
                 ('std', lambda x, params: np.std(x.astype(np.float64)) - params['scale'], tol)
@@ -250,6 +252,9 @@ def check_with_device(device, dtype):
             params = {'shape': shape, 'dtype': dtype, 'ctx': device}
             params.update({k : mx.nd.array(v, ctx=device, dtype=dtype) for k, v in symbdic['inputs']})
+            if name == 'randn':
+                params.pop('shape')  # randn does not accept a shape param
+                args = shape
             mx.random.seed(128)
             ret1 = ndop(*args, **params).asnumpy()
             mx.random.seed(128)
@@ -263,14 +268,12 @@ def check_with_device(device, dtype):
                     err = np.abs(check_func(ret2[i,j], stats))
                     assert err < tol, "%f vs %f: symbolic test: %s check for `%s` did not pass" % (err, tol, check_name, name)
 
-        if 'symbol' not in symbdic: continue  # randn does not have symbol
-
         # check symbolic
         symbol = symbdic['symbol']
         X = mx.sym.Variable("X")
         params = symbdic['params'].copy()
         params.update(shape=shape, dtype=dtype)
-        if name.endswith('_like'):
+        if name.endswith('_like') or name == 'randn':
             params['data'] = mx.sym.ones(params.pop('shape'))
         Y = symbol(**params) + X
         x = mx.nd.zeros(shape, dtype=dtype, ctx=device)
@@ -298,7 +301,12 @@ def check_with_device(device, dtype):
             single_param = len(symbdic['inputs']) == 1
             v1 = mx.sym.Variable('v1')
             v2 = mx.sym.Variable('v2')
-            Y = symbol(v1,**params) if single_param else symbol(v1,v2,**params)
+            if name == 'randn':
+                params.pop('shape')  # randn does not accept a shape param
+                args = shape
+                Y = symbol(v1, **params) if single_param else symbol(*args, loc=v1, scale=v2, **params)
+            else:
+                Y = symbol(v1, **params) if single_param else symbol(v1, v2, **params)
             bindings = { 'v1' : mx.nd.array(symbdic['inputs'][0][1]) }
             if not single_param :
                 bindings.update({ 'v2' : mx.nd.array(symbdic['inputs'][1][1]) })
@@ -315,9 +323,10 @@ def check_with_device(device, dtype):
             for check_name, check_func, tol in symbdic['checks']:
                 assert np.abs(check_func(samples, params)) < tol, "symbolic test: %s check for `%s` did not pass" % (check_name, name)
 
+        if 'pdfsymbol' not in symbdic: continue  # randn not tested for pdf
+
         # check pdfs with only a subset of the generated samples
         un1 = np.resize(un1, (un1.shape[0], un1.shape[1], pdfshape[0], pdfshape[1]))
-        print(name)
         symbol = symbdic['pdfsymbol']
         pdffunc = symbdic['pdffunc']
         v0 = mx.sym.Variable('v0')
@@ -355,7 +364,6 @@ def check_with_device(device, dtype):
             check_symbolic_forward(test_pdf, [un1, p1, p2], [res], atol=forw_atol, rtol=forw_rtol, dtype=dtype)
             if dtype == np.float64:
                 grad_nodes = ['v1', 'v2'] if symbdic['discrete'] else ['v0', 'v1', 'v2']
-                print(backw_rtol)
                 check_numeric_gradient(test_pdf, [un1, p1, p2], grad_nodes=grad_nodes, atol=backw_atol, rtol=backw_rtol, dtype=dtype)
 
 @with_seed(1000)