
Commit

Merge branch 'master' into lts_vector
ChaiBapchya authored Aug 27, 2019
2 parents dd17bec + 0e71fbd commit 57a10ac
Showing 47 changed files with 971 additions and 150 deletions.
3 changes: 2 additions & 1 deletion CODEOWNERS
@@ -47,7 +47,8 @@ CMakeLists.txt @szha @pllarroy

# MXNet CI
dev_menu.py @pllarroy
/ci/ @pllarroy
/ci/ @pllarroy @marcoabreu
/docker/ @marcoabreu
/tests/ci_build/ @marcoabreu
Jenkinsfile @marcoabreu
.travis.yml @marcoabreu
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -245,6 +245,7 @@ List of Contributors
* [Rohit Srivastava](https://github.com/access2rohit)
* [Caner Turkmen](https://github.com/canerturkmen)
* [Disi A](https://github.com/adis300)
* [Vandana Kannan](https://github.com/vandanavk)

Label Bot
---------
2 changes: 1 addition & 1 deletion benchmark/opperf/utils/benchmark_utils.py
@@ -50,7 +50,7 @@ def _prepare_op_inputs(inputs, run_backward, dtype, ctx):
return args_list, kwargs_list


def _run_nd_operator_performance_test(op, inputs, run_backward, warmup, runs, kwargs_list, profiler):
def _run_nd_operator_performance_test(op, inputs, run_backward, warmup, runs, args_list, kwargs_list, profiler):
if profiler == 'native':
if run_backward:
benchmark_helper_func = cpp_profile(nd_forward_backward_and_profile)
17 changes: 14 additions & 3 deletions benchmark/opperf/utils/common_utils.py
@@ -97,6 +97,10 @@ def _prepare_op_benchmark_result(op, op_bench_result, profiler):
max_mem_usage = "---"
inputs = "---"
avg_time = "---"
p50_time = "---"
p90_time = "---"
p99_time = "---"

for key, value in op_bench_result.items():
if "avg_time_forward" in key:
avg_forward_time = value
@@ -108,12 +112,19 @@ def _prepare_op_benchmark_result(op, op_bench_result, profiler):
inputs = value
elif "avg_time" in key:
avg_time = value
elif "p50_time" in key:
p50_time = value
elif "p90_time" in key:
p90_time = value
elif "p99_time" in key:
p99_time = value

result = ""
if profiler == "native":
result = "| {} | {} | {} | {} | {} |".format(operator_name,
avg_forward_time, avg_backward_time, max_mem_usage, inputs)
elif profiler == "python":
result = "| {} | {} | {} |".format(operator_name, avg_time, inputs)
result = "| {} | {} | {} | {} | {} | {} |".format(operator_name, avg_time, p50_time, p90_time, p99_time, inputs)
return result


@@ -132,8 +143,8 @@ def _prepare_markdown(results, runtime_features=None, profiler='native'):
" | Inputs |")
elif profiler == 'python':
results_markdown.append(
"| Operator | Avg Time (ms) | Inputs |")
results_markdown.append("| :---: | :---: | :---: |")
"| Operator | Avg Time (ms) | P50 Time (ms) | P90 Time (ms) | P99 Time (ms) | Inputs |")
results_markdown.append("| :---: | :---: | :---: | :---: | :---: | :---: |")

for op, op_bench_results in sorted(results.items(), key=itemgetter(0)):
for op_bench_result in op_bench_results:
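For the `python` profiler path, each table row now carries P50/P90/P99 latencies next to the average. A quick illustration of the row this formatting produces, with made-up timings for a single `add` entry:

```python
# Illustration only: mimic the markdown emitted for the 'python' profiler path,
# using made-up timing numbers for a single "add" benchmark entry.
header = "| Operator | Avg Time (ms) | P50 Time (ms) | P90 Time (ms) | P99 Time (ms) | Inputs |"
align = "| :---: | :---: | :---: | :---: | :---: | :---: |"
row = "| {} | {} | {} | {} | {} | {} |".format(
    "add", 0.405, 0.398, 0.421, 0.455,
    {"lhs": [1024, 1024], "rhs": [1024, 1024]})
print("\n".join([header, align, row]))
```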
29 changes: 24 additions & 5 deletions benchmark/opperf/utils/profiler_utils.py
@@ -17,6 +17,7 @@

import time
import functools
import numpy as np

from .common_utils import merge_map_list
from mxnet import profiler
@@ -219,6 +220,9 @@ def python_profile(func):
res, timing output. res being result returned after operator execution.
profiler output is a dictionary with summary of operation execution.
Example output : { "add": [{"avg_time_add": 0.4053089120425284,
'p50_time_add': 16.761042876169086,
'p90_time_add': 18.081666342914108,
'p99_time_add': 19.060144051909447,
"inputs": {
"lhs": [1024, 1024],
"rhs": [1024,1024]
@@ -228,10 +232,16 @@ def python_profile(func):

@functools.wraps(func)
def python_profile_it(*args, **kwargs):
start_time = time.perf_counter() # 1
res = func(*args, **kwargs)
end_time = time.perf_counter() # 2
run_time = end_time - start_time # 3
runs = args[1]
modified_args = (args[0], 1, args[2])
times = []

for _ in range(runs):
start_time = time.perf_counter() # 1
res = func(*modified_args, **kwargs)
end_time = time.perf_counter() # 2
run_time = (end_time - start_time)*1000 # 3
times.append(run_time)

# NOTE : same as cpp_profile_it
if len(args) > 0:
Expand All @@ -241,6 +251,15 @@ def python_profile_it(*args, **kwargs):
else:
raise ValueError("Unable to identify operator name to extract profiler output!")

profiler_output = {'avg_time_'+str(operator_name): run_time}
avg_run_time = np.mean(times)
p50_run_time = np.percentile(times, 50)
p90_run_time = np.percentile(times, 90)
p99_run_time = np.percentile(times, 99)

profiler_output = {'avg_time_'+str(operator_name): avg_run_time,
'p50_time_'+str(operator_name): p50_run_time,
'p90_time_'+str(operator_name): p90_run_time,
'p99_time_'+str(operator_name): p99_run_time,
}
return res, profiler_output
return python_profile_it
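The reworked `python_profile` decorator now runs the wrapped call `runs` times itself (forcing a single run per sample via `modified_args`) and reports average and percentile wall-clock times in milliseconds. A simplified, self-contained sketch of the same pattern — the `dummy_op` workload and the argument layout are illustrative, not the exact opperf call convention:

```python
import time
import functools
import numpy as np

def python_profile(func):
    """Time func over `runs` iterations; report avg/p50/p90/p99 in milliseconds."""
    @functools.wraps(func)
    def python_profile_it(op_name, runs, *args, **kwargs):
        res = None
        times = []
        for _ in range(runs):
            start = time.perf_counter()
            res = func(op_name, 1, *args, **kwargs)  # force a single run per timed sample
            times.append((time.perf_counter() - start) * 1000)  # seconds -> ms
        stats = {'avg_time_' + op_name: np.mean(times),
                 'p50_time_' + op_name: np.percentile(times, 50),
                 'p90_time_' + op_name: np.percentile(times, 90),
                 'p99_time_' + op_name: np.percentile(times, 99)}
        return res, stats
    return python_profile_it

@python_profile
def dummy_op(op_name, runs, size):
    # Hypothetical workload standing in for an operator execution.
    return sum(i * i for i in range(size))

_, profiler_output = dummy_op("dummy", 25, 100000)
print(profiler_output)
```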
2 changes: 1 addition & 1 deletion docs/api/python/contrib/onnx.md
@@ -30,7 +30,7 @@ With ONNX format support for MXNet, developers can build and train models with a
```

### Installation Instructions
- To use this module developers need to **install ONNX**, which requires the protobuf compiler to be installed separately. Please follow the [instructions to install ONNX and its dependencies](https://github.com/onnx/onnx#installation). **MXNet currently supports ONNX v1.2.1**. Once installed, you can go through the tutorials on how to use this module.
- To use this module developers need to **install ONNX**, which requires the protobuf compiler to be installed separately. Please follow the [instructions to install ONNX and its dependencies](https://github.com/onnx/onnx#installation). **MXNet currently supports ONNX v1.3**. Once installed, you can go through the tutorials on how to use this module.


This document describes all the ONNX-MXNet APIs.
1 change: 1 addition & 0 deletions docs/api/python/symbol/symbol.md
@@ -612,6 +612,7 @@ Composite multiple symbols into a new one by an operator.
random.normal
random.poisson
random.randint
random.randn
random.shuffle
random.uniform
mxnet.random.seed
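`random.randn` joins the other samplers in the symbol API reference. A brief, hedged example of constructing and evaluating the sampler (evaluation details may differ slightly across MXNet versions):

```python
import mxnet as mx

# Build a symbol that samples a 2x3 matrix of standard-normal values.
s = mx.sym.random.randn(2, 3)
print(s.list_arguments())        # [] -- the sampler takes no input arguments
out = s.eval(ctx=mx.cpu())[0]    # evaluate the input-free symbol to an NDArray
print(out.shape)                 # (2, 3)
```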
2 changes: 1 addition & 1 deletion docs/tutorials/amp/amp_tutorial.md
@@ -258,6 +258,7 @@ To do inference with mixed precision for a trained model in FP32, you can use th
Below, we demonstrate for a gluon model and a symbolic model:
- Conversion from FP32 model to mixed precision model.
- Run inference on the mixed precision model.
- For AMP conversion of a bucketing module, please refer to [example/rnn/bucketing/README.md](https://github.com/apache/incubator-mxnet/blob/master/example/rnn/bucketing/README.md).

```python
with mx.Context(mx.gpu(0)):
@@ -336,7 +337,6 @@ with mx.Context(mx.gpu(0)):
mod.save_checkpoint("amp_tutorial_model", 0, remove_amp_cast=False)
```


## Current limitations of AMP

- AMP's dynamic loss scaling currently supports only Gluon trainer with `update_on_kvstore=False` option set
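For the symbolic path, the conversion step referenced above reduces to loading an FP32 checkpoint and passing it through `amp.convert_model`. A condensed, hedged outline — the checkpoint prefix, data name, and input shape are placeholders, not part of this commit:

```python
import mxnet as mx
from mxnet.contrib import amp

# Placeholder checkpoint prefix/epoch; substitute a trained FP32 model.
sym, arg_params, aux_params = mx.model.load_checkpoint("amp_tutorial_model", 0)

# Cast eligible operators to float16, keeping numerically sensitive ops in float32.
result_sym, result_arg_params, result_aux_params = amp.convert_model(
    sym, arg_params, aux_params, target_dtype="float16")

# Bind the converted symbol for inference (data name and shape are illustrative).
mod = mx.mod.Module(result_sym, data_names=["data"], label_names=None,
                    context=mx.gpu(0))
mod.bind(data_shapes=[("data", (1, 3, 224, 224))], for_training=False)
mod.set_params(result_arg_params, result_aux_params)
```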
3 changes: 2 additions & 1 deletion example/quantization/imagenet_gen_qsym_mkldnn.py
@@ -216,7 +216,8 @@ def save_params(fname, arg_params, aux_params, logger=None):
if exclude_first_conv:
excluded_sym_names += ['resnetv10_conv0_fwd']
elif args.model.find('resnet') != -1 and args.model.find('v2') != -1:
excluded_sym_names += ['resnetv20_flatten0_flatten0']
# resnetv20_stage1_batchnorm0_fwd is excluded for the sake of accuracy
excluded_sym_names += ['resnetv20_flatten0_flatten0', 'resnetv20_stage1_batchnorm0_fwd']
if exclude_first_conv:
excluded_sym_names += ['resnetv20_conv0_fwd']
elif args.model.find('vgg') != -1:
6 changes: 6 additions & 0 deletions example/rnn/bucketing/README.md
@@ -55,6 +55,12 @@ You can check this improved [Gluon implementation](http://gluon-nlp.mxnet.io/mod

$ python3 [cudnn_rnn_bucketing.py](cudnn_rnn_bucketing.py) --gpus 0,1,2,3

- To run mixed precision inference with the trained model, use the `--dtype` option.

This uses the AMP conversion API for the bucketing module to convert the model to a mixed precision module before scoring.

$ python [cudnn_rnn_bucketing.py](cudnn_rnn_bucketing.py) --gpus 0 --model-prefix saved_rnn_model --load-epoch 12 --test --dtype float16


### Performance Note:

19 changes: 16 additions & 3 deletions example/rnn/bucketing/cudnn_rnn_bucketing.py
@@ -18,6 +18,7 @@
import numpy as np
import mxnet as mx
import argparse
from mxnet.contrib.amp import amp

parser = argparse.ArgumentParser(description="Train RNN on Sherlock Holmes",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -67,6 +68,10 @@
help='dropout probability (1.0 - keep probability)')
parser.add_argument('--rnntype', type=str, default='lstm',
help='rnn type: gru, lstm, rnn_tanh and rnn_relu are supported')
parser.add_argument('--dtype', type=str, default='float32',
                    help='if float16 is provided, the AMP convert model API '
                         'is used to convert the model to a mixed precision model '
                         'before running inference')

#buckets = [32]
buckets = [10, 20, 30, 40, 50, 60]
@@ -234,12 +239,20 @@ def sym_gen(seq_len):
context = contexts)
model.bind(data_val.provide_data, data_val.provide_label, for_training=False)

# note here we load using SequentialRNNCell instead of FusedRNNCell.
_, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(stack, args.model_prefix, args.load_epoch)
model.set_params(arg_params, aux_params)

model.score(data_val, mx.metric.Perplexity(invalid_label),
batch_end_callback=mx.callback.Speedometer(args.batch_size, 5))
if args.dtype == "float32":
model.set_params(arg_params, aux_params)
model.score(data_val, mx.metric.Perplexity(invalid_label),
batch_end_callback=mx.callback.Speedometer(args.batch_size, 5))
else:
assert args.dtype == "float16", "Only float32 and float16 are supported currently"
model = amp.convert_bucketing_module(model, target_dtype="float16")
model.bind(data_val.provide_data, data_val.provide_label,
for_training=False)
model.score(data_val, mx.metric.Perplexity(invalid_label),
batch_end_callback=mx.callback.Speedometer(args.batch_size, 5))

if __name__ == '__main__':
import logging
2 changes: 1 addition & 1 deletion make/maven/maven_darwin_mkl.mk
@@ -77,7 +77,7 @@ USE_CUDNN = 0
# CUDA_ARCH :=

# whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
ENABLE_CUDA_RTC = 0
USE_NVRTC = 0

# use openmp for parallelization
USE_OPENMP = 0
3 changes: 1 addition & 2 deletions make/maven/maven_linux_cu90mkl.mk
@@ -79,9 +79,8 @@ USE_NCCL = 1
# CUDA_ARCH :=

# whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
ENABLE_CUDA_RTC = 1

USE_NVTX=1
USE_NVRTC = 1

# use openmp for parallelization
USE_OPENMP = 1
3 changes: 1 addition & 2 deletions make/maven/maven_linux_cu92mkl.mk
@@ -79,9 +79,8 @@ USE_NCCL = 1
# CUDA_ARCH :=

# whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
ENABLE_CUDA_RTC = 1

USE_NVTX=1
USE_NVRTC = 1

# use openmp for parallelization
USE_OPENMP = 1
2 changes: 1 addition & 1 deletion make/maven/maven_linux_mkl.mk
@@ -76,7 +76,7 @@ USE_CUDNN = 0
# CUDA_ARCH :=

# whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
ENABLE_CUDA_RTC = 0
USE_NVRTC = 0

# use openmp for parallelization
USE_OPENMP = 1
64 changes: 64 additions & 0 deletions python/mxnet/contrib/amp/amp.py
@@ -32,6 +32,7 @@
from ... import symbol
from ...context import gpu
from ...symbol import Symbol
from ...module import BucketingModule
from ...symbol import contrib as symbol_contrib
from ... import ndarray
from ...ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
@@ -672,6 +673,69 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
ret.collect_params().load_dict(arg_dict, ctx=ctx)
return ret

def convert_bucketing_module(bucketing_mod, target_dtype="float16", target_dtype_ops=None,
fp32_ops=None, conditional_fp32_ops=None,
excluded_sym_names=None, cast_optional_params=False):
"""Given a bucketing module cast the symbols associated with the BucketingModule
and params if cast_optional_params is set.
bucketing_mod : BucketingModule instance
target_dtype : str
Currently only supports float16. The target dtype indicates to add cast layers
when possible so that lower precision computation can be leveraged.
target_dtype_ops : list of strs
Override the list of operator names casted to target_dtype.
If None, uses the framework's default list to be casted to target dtype.
fp32_ops : list of strs
Override the lists of operator names casted to FP32.
If None, uses the framework's default list to be casted to FP32.
widest_dtype_ops : list of strs
A list of op names provided by user which should run in widest precision among its inputs.
If None, uses the framework's default list of widest_precision_ops.
conditional_fp32_ops : list of (string, string, list of string)
Override the list of operators to be casted to FP32.
The format of the list is
(name of the function, name of the parameter,
list of values of the parameter that make the operator to be casted to
fp32)
excluded_sym_names : list of strs
A list of strings that represent the names of symbols that users want to exclude
from being executed in lower precision.
cast_optional_params : bool, default False
Whether to cast the arg_params and aux_params that are not required to be in FP16
because a cast layer follows them; casting them reduces the computation and memory
overhead of the model.
"""
assert isinstance(bucketing_mod, BucketingModule), "module should be instance of bucketing module"
assert len(bucketing_mod._buckets) > 0, "Bucketing Module should not be empty"

sym_dict = {}
assert bucketing_mod.params_initialized, \
"bucketing_mod params should be initialized for mixed precision conversion"
arg_params, aux_params = bucketing_mod._curr_module._arg_params, bucketing_mod._curr_module._aux_params
for key, val in bucketing_mod._buckets.items():
sym_dict[key], result_arg_params, result_aux_params = convert_model(val._symbol,
arg_params,
aux_params,
target_dtype=target_dtype,
target_dtype_ops=target_dtype_ops,
fp32_ops=fp32_ops,
conditional_fp32_ops=conditional_fp32_ops,
excluded_sym_names=excluded_sym_names,
cast_optional_params=cast_optional_params)
result_mod = BucketingModule.load_dict(sym_dict,
sym_gen=bucketing_mod._sym_gen,
arg_params=result_arg_params,
aux_params=result_aux_params,
default_bucket_key=bucketing_mod._default_bucket_key,
logger=bucketing_mod.logger,
context=bucketing_mod._context,
work_load_list=bucketing_mod._work_load_list,
fixed_param_names=bucketing_mod._fixed_param_names,
state_names=bucketing_mod._state_names,
group2ctxs=bucketing_mod._group2ctxs,
compression_params=bucketing_mod._compression_params)
return result_mod

def list_fp16_ops():
"""Get the default list of FP16 ops for AMP
"""
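A hedged usage sketch for the new API: given a BucketingModule that is already bound and has its FP32 parameters set (as in the cudnn_rnn_bucketing.py change above), convert it and score on a validation iterator. The wrapper function and its arguments are illustrative:

```python
import mxnet as mx
from mxnet.contrib.amp import amp

def score_in_float16(bucketing_mod, data_val, invalid_label=-1):
    """Hedged sketch: convert a bound, parameter-initialized BucketingModule to
    mixed precision and score it on a validation iterator."""
    fp16_mod = amp.convert_bucketing_module(bucketing_mod, target_dtype="float16")
    fp16_mod.bind(data_val.provide_data, data_val.provide_label, for_training=False)
    return fp16_mod.score(data_val, mx.metric.Perplexity(invalid_label))
```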
2 changes: 1 addition & 1 deletion python/mxnet/contrib/onnx/mx2onnx/export_model.py
@@ -37,7 +37,7 @@ def export_model(sym, params, input_shape, input_type=np.float32,
"""Exports the MXNet model file, passed as a parameter, into ONNX model.
Accepts both symbol,parameter objects as well as json and params filepaths as input.
Operator support and coverage -
https://cwiki.apache.org/confluence/display/MXNET/MXNet-ONNX+Integration
https://cwiki.apache.org/confluence/display/MXNET/ONNX+Operator+Coverage
Parameters
----------
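For context, a hedged sketch of calling `export_model` with file paths; the checkpoint file names and input shape are placeholders:

```python
import numpy as np
from mxnet.contrib import onnx as onnx_mxnet

# Placeholder checkpoint files and input shape; substitute a real model.
sym_file = "resnet-symbol.json"
params_file = "resnet-0000.params"

onnx_file = onnx_mxnet.export_model(sym_file, params_file,
                                    [(1, 3, 224, 224)], np.float32,
                                    "exported_model.onnx")
print("ONNX model written to", onnx_file)
```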
33 changes: 17 additions & 16 deletions python/mxnet/model.py
@@ -423,6 +423,22 @@ def save_checkpoint(prefix, epoch, symbol, arg_params, aux_params, remove_amp_ca
logging.info('Saved checkpoint to \"%s\"', param_name)


def load_params(prefix, epoch):
"""Load params from a file
"""
save_dict = nd.load("%s-%04d.params" % (prefix, epoch))
arg_params = {}
aux_params = {}
if not save_dict:
logging.warning("Params file '%s' is empty", '%s-%04d.params' % (prefix, epoch))
for k, v in save_dict.items():
tp, name = k.split(":", 1)
if tp == "arg":
arg_params[name] = v
if tp == "aux":
aux_params[name] = v
return (arg_params, aux_params)

def load_checkpoint(prefix, epoch):
"""Load model checkpoint from file.
@@ -448,22 +464,7 @@ def load_checkpoint(prefix, epoch):
- Parameters will be loaded from ``prefix-epoch.params``.
"""
symbol = sym.load('%s-symbol.json' % prefix)
save_dict = nd.load('%s-%04d.params' % (prefix, epoch))
arg_params = {}
aux_params = {}
#load any params in the dict, skip if params are empty
if not save_dict:
logging.warning("Params file '%s' is empty", '%s-%04d.params' % (prefix, epoch))
else:
for k, v in save_dict.items():
tp, name = k.split(':', 1)
if tp == 'arg':
arg_params[name] = v
elif tp == 'aux':
aux_params[name] = v
else:
logging.warning("Params file '%s' contains unknown param '%s'",
'%s-%04d.params' % (prefix, epoch), k)
arg_params, aux_params = load_params(prefix, epoch)
return (symbol, arg_params, aux_params)

from .callback import LogValidationMetricsCallback # pylint: disable=wrong-import-position
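With this refactor, `load_params` can be called on its own when only the parameters are needed, while `load_checkpoint` still returns the symbol as well. A hedged sketch — the checkpoint prefix and epoch are placeholders:

```python
import mxnet as mx

# Placeholder prefix/epoch: expects my_model-symbol.json and my_model-0010.params.
prefix, epoch = "my_model", 10

# Full checkpoint: symbol plus parameters.
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

# Parameters only, via the newly factored-out helper.
arg_params, aux_params = mx.model.load_params(prefix, epoch)
```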