
Commit

Move quantization to contrib
reminisce committed Feb 2, 2018
1 parent f339c19 commit 07e205c
Showing 7 changed files with 29 additions and 27 deletions.
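For downstream code, the practical effect of this commit is an import-path change: the quantization module moves from mxnet.quantization to mxnet.contrib.quantization, as the example scripts below show. A minimal sketch of the new usage follows; the quantize_model mention in the comment is assumed to be the module's public entry point and is shown only for orientation.

    import mxnet as mx

    # Old path, removed by this commit:
    #   from mxnet.quantization import *
    # New path under contrib:
    from mxnet.contrib.quantization import *  # e.g. quantize_model, assuming it is exported

    # The module is also reachable as mx.contrib.quantization once
    # python/mxnet/contrib/__init__.py imports it (see that diff below).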
2 changes: 1 addition & 1 deletion example/quantization/imagenet_gen_qsym.py
@@ -18,7 +18,7 @@
import argparse
from common import modelzoo
import mxnet as mx
-from mxnet.quantization import *
+from mxnet.contrib.quantization import *


def download_calib_dataset(dataset_url, calib_dataset, logger=None):
4 changes: 1 addition & 3 deletions example/quantization/imagenet_inference.py
@@ -18,9 +18,7 @@
import argparse
import mxnet as mx
import time
-import os
-import logging
-from mxnet.quantization import *
+from mxnet.contrib.quantization import *


def download_dataset(dataset_url, dataset_dir, logger=None):
1 change: 0 additions & 1 deletion python/mxnet/__init__.py
@@ -77,7 +77,6 @@
from . import test_utils

from . import rnn
-from . import quantization

from . import gluon

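With this import removed from the top-level package, the old path no longer resolves, so existing user code fails at import time. A short sketch of the failure mode (illustrative only):

    try:
        from mxnet.quantization import *  # old path, gone after this commit
    except ImportError:
        print('mxnet.quantization has moved to mxnet.contrib.quantization')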
3 changes: 3 additions & 0 deletions python/mxnet/contrib/__init__.py
@@ -28,3 +28,6 @@
from . import tensorboard

from . import text
+
+from . import quantization
+from . import quantization as quant
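Because the module is imported here both under its own name and as quant, the two attributes refer to the same module object. A short sketch of what that means for callers:

    import mxnet as mx

    # Both names resolve to the same module after this change.
    assert mx.contrib.quantization is mx.contrib.quant

    # The updated tests below use the shorter alias, e.g.
    # mx.contrib.quant._quantize_symbol(...)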
32 changes: 17 additions & 15 deletions python/mxnet/{quantization.py → contrib/quantization.py}
@@ -26,15 +26,17 @@
import ctypes
import logging
import os
-from .base import _LIB, check_call
-from .base import c_array, c_str, mx_uint, c_str_array
-from .base import NDArrayHandle, SymbolHandle
-from .symbol import Symbol, load
-from . import ndarray as nd
-from .ndarray import NDArray
-from .io import DataIter
-from .context import cpu, Context
-from .module import Module
+from ..base import _LIB, check_call
+from ..base import c_array, c_str, mx_uint, c_str_array
+from ..base import NDArrayHandle, SymbolHandle
+from ..symbol import Symbol
+from ..symbol import load as sym_load
+from .. import ndarray
+from ..ndarray import load as nd_load
+from ..ndarray import NDArray
+from ..io import DataIter
+from ..context import cpu, Context
+from ..module import Module


def _quantize_params(qsym, params):
@@ -55,8 +57,8 @@ def _quantize_params(qsym, params):
if name.endswith(('weight_quantize', 'bias_quantize')):
original_name = name[:-len('_quantize')]
param = params[original_name]
-val, vmin, vmax = nd.contrib.quantize(data=param, min_range=nd.min(param),
-                                      max_range=nd.max(param), out_type='int8')
+val, vmin, vmax = ndarray.contrib.quantize(data=param, min_range=ndarray.min(param),
+                                           max_range=ndarray.max(param), out_type='int8')
quantized_params[name] = val
quantized_params[name+'_min'] = vmin
quantized_params[name+'_max'] = vmax
@@ -139,8 +141,8 @@ def collect(self, name, ndarray):
return
handle = ctypes.cast(ndarray, NDArrayHandle)
ndarray = NDArray(handle, writable=False)
-min_range = nd.min(ndarray).asscalar()
-max_range = nd.max(ndarray).asscalar()
+min_range = ndarray.min(ndarray).asscalar()
+max_range = ndarray.max(ndarray).asscalar()
if name in self.min_max_dict:
cur_min_max = self.min_max_dict[name]
self.min_max_dict[name] = (min(cur_min_max[0], min_range), max(cur_min_max[1], max_range))
@@ -335,7 +337,7 @@ def _load_sym(sym, logger=logging):
cur_path = os.path.dirname(os.path.realpath(__file__))
symbol_file_path = os.path.join(cur_path, sym)
logger.info('Loading symbol from file %s' % symbol_file_path)
-return load(symbol_file_path)
+return sym_load(symbol_file_path)
elif isinstance(sym, Symbol):
return sym
else:
@@ -351,7 +353,7 @@ def _load_params(params, logger=logging):
cur_path = os.path.dirname(os.path.realpath(__file__))
param_file_path = os.path.join(cur_path, params)
logger.info('Loading params from file %s' % param_file_path)
-save_dict = nd.load(param_file_path)
+save_dict = nd_load(param_file_path)
arg_params = {}
aux_params = {}
for k, v in save_dict.items():
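The _quantize_params hunk above converts each offline parameter to int8 with the contrib.quantize operator, storing the observed float range alongside the quantized values. A standalone sketch of that operator call, mirroring the changed lines (shape and values are arbitrary, and this assumes the operator is available on the default CPU context):

    from mxnet import ndarray

    param = ndarray.uniform(low=-1.0, high=1.0, shape=(2, 2))
    val, vmin, vmax = ndarray.contrib.quantize(data=param,
                                               min_range=ndarray.min(param),
                                               max_range=ndarray.max(param),
                                               out_type='int8')
    # val holds the int8 payload; vmin/vmax record the float range it represents,
    # which _quantize_params stores under '<name>_min' and '<name>_max'.
    print(val.dtype, vmin.asscalar(), vmax.asscalar())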
2 changes: 1 addition & 1 deletion python/mxnet/model.py
@@ -28,7 +28,7 @@
import numpy as np

from . import io
-from . import nd
+from . import ndarray as nd
from . import symbol as sym
from . import optimizer as opt
from . import metric
12 changes: 6 additions & 6 deletions tests/python/quantization/test_quantization.py
@@ -20,7 +20,7 @@
"""
import mxnet as mx
import numpy as np
-from mxnet.test_utils import assert_almost_equal, rand_ndarray, rand_shape_nd, same, set_default_context
+from mxnet.test_utils import assert_almost_equal, rand_ndarray, rand_shape_nd, same


def test_quantize_float32_to_int8():
@@ -313,8 +313,8 @@ def test_quantize_params():
params = {}
for name in offline_params:
params[name] = mx.nd.uniform(shape=(2, 2))
-qsym = mx.quantization._quantize_symbol(sym, offline_params=offline_params)
-qparams = mx.quantization._quantize_params(qsym, params)
+qsym = mx.contrib.quant._quantize_symbol(sym, offline_params=offline_params)
+qparams = mx.contrib.quant._quantize_params(qsym, params)
param_names = params.keys()
qparam_names = qparams.keys()
for name in qparam_names:
@@ -336,12 +336,12 @@ def test_quantize_sym_with_calib():
out_grad=False, preserve_shape=False, use_ignore=False, name='softmax')
offline_params = [name for name in sym.list_arguments()
if not name.startswith('data') and not name.endswith('label')]
-qsym = mx.quantization._quantize_symbol(sym, offline_params=offline_params)
+qsym = mx.contrib.quant._quantize_symbol(sym, offline_params=offline_params)
requantize_op_names = ['requantize_conv', 'requantize_fc']
th_dict = {'conv_output': (np.random.uniform(low=100.0, high=200.0), np.random.uniform(low=100.0, high=200.0)),
'fc_output': (np.random.uniform(low=100.0, high=200.0), np.random.uniform(low=100.0, high=200.0))}
op_name_to_th_name = {'requantize_conv': 'conv_output', 'requantize_fc': 'fc_output'}
-cqsym = mx.quantization._calibrate_quantized_sym(qsym, th_dict)
+cqsym = mx.contrib.quant._calibrate_quantized_sym(qsym, th_dict)
attr_dict = cqsym.attr_dict()
for name in requantize_op_names:
assert name in attr_dict
@@ -364,7 +364,7 @@ def get_threshold(nd):

nd_dict = {'layer1': mx.nd.uniform(low=-10.532, high=11.3432, shape=(8, 3, 23, 23))}
expected_threshold = get_threshold(nd_dict['layer1'])
-th_dict = mx.quantization._get_optimal_thresholds(nd_dict)
+th_dict = mx.contrib.quant._get_optimal_thresholds(nd_dict)
assert 'layer1' in th_dict
assert_almost_equal(np.array([th_dict['layer1'][1]]), expected_threshold)

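The test module exercises the float32-to-int8 conversion (test_quantize_float32_to_int8) as well as calibration threshold selection. For orientation, here is a hedged numpy sketch of the symmetric int8 mapping that out_type='int8' is generally understood to implement; it illustrates the idea and is not a copy of the test code:

    import numpy as np

    def quantize_symmetric_int8(data):
        # Illustrative symmetric scheme: scale by 127 / max(|min|, |max|).
        real_range = max(abs(float(data.min())), abs(float(data.max())))
        scale = 127.0 / real_range
        quantized = np.round(data * scale).clip(-127, 127).astype(np.int8)
        return quantized, -real_range, real_range

    data = np.random.uniform(-1.0, 1.0, size=(3, 4)).astype(np.float32)
    q, qmin, qmax = quantize_symmetric_int8(data)
    # Dequantizing recovers the input to within one quantization step.
    step = qmax / 127.0
    assert np.allclose(q.astype(np.float32) * step, data, atol=step)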
