Skip to content

Commit

Permalink
Move scale handling from thresholds to quantize (apache#10)
Browse files Browse the repository at this point in the history
* Move scale handling from thresholds to quantize

* Add clip requantization

* minor

* Comments
  • Loading branch information
anijain2305 authored Oct 21, 2020
1 parent adc7a1a commit d916431
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 77 deletions.
88 changes: 43 additions & 45 deletions python/tvm/hago/_op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import functools
import numpy as np
import logging
from .base import to_scalar

RUNTIME_DEBUG = False

Expand Down Expand Up @@ -132,7 +133,7 @@ def check_overflow(data, in_dtype, output):
if 'float' in in_dtype:
# skip overflow check for float input dtype
data.copyto(output)
return
return

if not allclose(arr, data.asnumpy(), rtol=1e-03, atol=1.0):
logging.warning('overflow happens')
Expand Down Expand Up @@ -238,7 +239,7 @@ def product_scale(input_scales):

def identity_scale(input_scales):
input_scales = [scale.value for scale in input_scales]
scale0 = input_scales[0]
scale0 = input_scales[0]
for scale in input_scales:
assert math.isclose(scale, scale, rel_tol=1e-6)
return scale0
Expand All @@ -249,56 +250,14 @@ def identity_scale(input_scales):
register_infer_scale("layout_transform", identity_scale)
register_infer_scale("nn.pad", identity_scale)
register_infer_scale("nn.relu", identity_scale)
register_infer_scale("clip", identity_scale)
register_infer_scale("nn.max_pool2d", identity_scale)
register_infer_scale("nn.avg_pool2d", identity_scale)
register_infer_scale("nn.global_avg_pool2d", identity_scale)
register_infer_scale("nn.adaptive_avg_pool2d", identity_scale)
register_infer_scale("nn.batch_flatten", identity_scale)

# threshold rectify function registered for ops

def register_threshold_rectify(op_name, frectify=None, level=10):
return tvm.ir.register_op_attr(op_name, "FHagoRectify", frectify, level)

@register_threshold_rectify("add")
def unify_scale(input_bits, output_bits, input_thresholds, output_thresholds):
sign_bit = 1
# convert from tvm object to POD
ibits = [bit.value for bit in input_bits]
# FIXME - Uncommenting next line can cause failures when add is followed by a non-quantized op.
# Exmaple is add --> clip and model is MXNet mobilenetv2
# obits = [32 if bit is None else bit.value for bit in output_bits]
itholds = [thold.value for thold in input_thresholds]
otholds = [thold.value for thold in output_thresholds]

# choose scale of the one with max threshold
idx = np.argmax(itholds)
chosen_thold = itholds[idx]
chosen_bit = ibits[idx]
unified_scale = itholds[idx] / (2 ** (ibits[idx] - sign_bit))

print(' in bits : {}'.format(ibits))
# print(' out bits : {}'.format(obits))
print(' in tholds : {}'.format(', '.join(["{:.3f}".format(thold) for thold in itholds])))
print(' out tholds: {}'.format(', '.join(["{:.3f}".format(thold) for thold in otholds])))
print(' choose unifed scale {:.3e} for op add'.format(unified_scale))
new_tholds = []
for i, bit in enumerate(ibits):
# integer_range = 2 ** (bit - sign_bit) - 1
# thold = integer_range * unified_scale
thold = (2 ** (bit - chosen_bit)) * chosen_thold
print(' rectify threshold from {} to {} for op add'.format(itholds[i], thold))
new_tholds.append(thold)
for thold in otholds:
new_tholds.append(thold)

print(' new tholds: {}'.format(', '.join(["{:.3f}".format(thold) for thold in new_tholds])))

return new_tholds


# realize registration for ops

def register_realize(op_name, frealize=None, level=10):
return tvm.ir.register_op_attr(op_name, "FHagoRealize", frealize, level)

Expand Down Expand Up @@ -332,3 +291,42 @@ def realize_conv2d(node, in_types, out_types):
attrs_dict['out_dtype'] = DataType(out_types[0])
attrs = tvm.ir.make_node("relay.attrs.Conv2DAttrs", **attrs_dict)
return relay.Call(node.op, node.args, attrs, node.type_args)

@register_realize("clip")
def realize_clip(node, in_types, out_types):
data = node.args[0]
assert data.op.name == 'qnn.requantize'
scale, zero_point = data.args[3], data.args[4]
scale_val = to_scalar(scale)
zero_point_val = to_scalar(zero_point)
dtype = data.attrs.out_dtype

clip_min = node.attrs.a_min
clip_max = node.attrs.a_max

# Quantize a float value to an quantized integer value
quantize = lambda x: float(int(round(x / scale_val)) + zero_point_val)

# Get min/max of the output dtype. This will be used to ensure that clip a_min/a_max are not
# beyond the dtype range.
qmin = float(tvm.tir.op.min_value(dtype).value)
qmax = float(tvm.tir.op.max_value(dtype).value)
return relay.clip(data,
a_min=max(qmin, quantize(clip_min)),
a_max=min(qmax, quantize(clip_max)))

def register_rectify_scale(op_name, frectify_scale=None, level=10):
return tvm.ir.register_op_attr(op_name, "FHagoRectifyScale", frectify_scale, level)

@register_rectify_scale("add")
def add_rectify_scale(args, old_in_scales, old_out_scales):
new_scale = old_out_scales[0] if old_out_scales[0] > old_out_scales[1] else old_out_scales[1]
return [new_scale, new_scale]

def return_input_scale(args, old_in_scales, old_out_scales):
# Skip the requantize before relu
return [old_in_scales[0]]

register_rectify_scale("nn.relu", return_input_scale)
register_rectify_scale("clip", return_input_scale)

1 change: 1 addition & 0 deletions python/tvm/hago/hardware.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def create_accelerator_description():
hardware.add_op_desc('concatenate', OpDesc(in_dtypes='float32', out_dtypes='float32'))

hardware.add_op_desc('nn.relu', OpDesc(in_dtypes='int32', out_dtypes='int32'))
hardware.add_op_desc('clip', OpDesc(in_dtypes='int32', out_dtypes='int32'))
hardware.add_op_desc('nn.avg_pool2d', OpDesc(in_dtypes='float32', out_dtypes='float32'))
# hardware.add_op_desc('nn.avg_pool2d', OpDesc(in_dtypes='int32', out_dtypes='int32'))
hardware.add_op_desc('nn.max_pool2d', OpDesc(in_dtypes='int32', out_dtypes='int32'))
Expand Down
77 changes: 56 additions & 21 deletions python/tvm/hago/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .topology import Topology

import tvm
from tvm.tir import expr
import sys
import math
import numpy as np
Expand Down Expand Up @@ -98,7 +99,7 @@ def fvisit(node):
print(' {0}'.format(desc))
descs[node2idx[node]] = desc
relay.analysis.post_order_visit(graph, fvisit)
return descs
return descs


def infer_quantized_dtypes(topology, constraints):
Expand Down Expand Up @@ -139,7 +140,7 @@ def fvisit(node):
eidx = edge2idx[edge]
assign_dtype(req_dtypes, eidx, dtype)
relay.analysis.post_order_visit(topology.graph, fvisit)
return prov_dtypes, req_dtypes
return prov_dtypes, req_dtypes


class CalibrationDataset(object):
Expand Down Expand Up @@ -203,7 +204,7 @@ def eval(self, bits, thresholds, dataset, ctx, target):
"""compile simulated model and run it on the dataset"""
# prepare parameters
internal_params, output_params = self.calculate_params(bits, thresholds)
param_map = {}
param_map = {}
for nodes, p in zip(self.internal_param_nodes, internal_params):
vals = [p.in_scale, p.out_scale, p.clip_min, p.clip_max]
for node, val in zip(nodes, vals):
Expand Down Expand Up @@ -313,7 +314,7 @@ def visit_function(self, fn):
assert isinstance(out_scale_node, relay.Var)
oshape = out_scale_node.type_annotation.shape
if oshape != ():
in_scale_shape = oshape
in_scale_shape = oshape

in_scale = relay.var('in_scale' + str(self._name_cnt), shape=in_scale_shape)
out_scale = relay.var('out_scale' + str(self._name_cnt), 'float32')
Expand Down Expand Up @@ -345,7 +346,7 @@ def calculate_params(self, bits, thresholds):
"""calculate parameters of simulated quantize op from bits and thresholds"""
graph, topology, constraints = self.graph, self.topology, self.constraints
sign_bit = 1
edge2idx = self._edge2idx
edge2idx = self._edge2idx
node2idx = self._node2idx
assert len(thresholds) == len(node2idx)
edge2bit = topology.build_edge_info(bits)
Expand Down Expand Up @@ -375,14 +376,48 @@ def infer_scale_for_node(node):
else:
# per channel scales
assert len(input_scales) == 2
lhs, rhs = input_scales
lhs, rhs = input_scales
# support conv2d now, so only do scale multiplication
scale = lhs * rhs
return scale

print('\ncalculate parameters')
def fvisit(node):
if isinstance(node, relay.Call):
in_scales = list()
in_dtypes = list()
out_dtypes = list()
out_scales = list()
for edge in list_in_edges(node):
src, _ = edge
eidx = edge2idx[edge]
in_scales.append(infer_scale_for_node(src))
in_dtypes.append(prov_dtypes[node2idx[src]])
out_dtypes.append(req_dtypes[eidx])
if 'float' in str(req_dtypes[eidx]):
out_scale = 1.0
else:
bit = edge2bit[(src, node)]
integer_range = 2 ** (bit - sign_bit)
thold = thresholds[node2idx[src]]
out_scale = thold / integer_range
out_scales.append(out_scale)

rectified_output_scales = None
frectify_scale = node.op.get_attr('FHagoRectifyScale')
if frectify_scale:
new_output_scales = frectify_scale(node.args,
in_scales,
out_scales)
rectified_output_scales = list()
for idx, scale in enumerate(new_output_scales):
if isinstance(scale, expr.FloatImm):
rectified_output_scales.append(scale.value)
else:
raise NotImplementedError()


arg_idx = 0
for edge in list_in_edges(node):
src, _ = edge
eidx = edge2idx[edge]
Expand All @@ -401,8 +436,11 @@ def fvisit(node):
else:
bit = edge2bit[(src, node)]
integer_range = 2 ** (bit - sign_bit)
thold = thresholds[node2idx[src]]
out_scale = thold / integer_range
if rectified_output_scales is not None:
out_scale = rectified_output_scales[arg_idx]
else:
thold = thresholds[node2idx[src]]
out_scale = thold / integer_range
clip_min = - (integer_range - 1)
clip_max = integer_range - 1
print(' bit={}, threshold={}'.format(bit, thold))
Expand All @@ -412,10 +450,11 @@ def fvisit(node):
in_dtype, out_dtype)
print(' {}'.format(param))
internal_params.append(param)
arg_idx += 1
return
if isinstance(node, relay.Function):
# handle output of function
assert isinstance(node.body, relay.Call)
# handle output of function
assert isinstance(node.body, relay.Call)
node = node.body
print('---------')
print("{} -> OUT".format(node_str(node, node2idx)))
Expand Down Expand Up @@ -445,7 +484,7 @@ def fvisit(node):
def bind_simulated_graph(self, bits, thresholds):
# prepare parameters
internal_params, output_params = self.calculate_params(bits, thresholds)
param_map = {}
param_map = {}
for nodes, p in zip(self.internal_param_nodes, internal_params):
vals = [p.in_scale, p.out_scale, p.clip_min, p.clip_max]
for node, val in zip(nodes, vals):
Expand Down Expand Up @@ -482,7 +521,7 @@ def visit_call(self, node):
cstr = self._constraints[nidx]
frealize = node.op.get_attr("FHagoRealize")
if frealize and cstr is not None:
in_dtypes = [str(cstr.in_dtype(i)) for i in range(node.op.num_inputs)]
in_dtypes = [str(cstr.in_dtype(i)) for i in range(node.op.num_inputs)]
out_dtypes = [str(cstr.out_dtype(0))]
new_node = frealize(new_node, in_dtypes, out_dtypes)
return new_node
Expand All @@ -498,7 +537,7 @@ def _realize_simulated_quantize(self, node):
out_dtype = attrs.out_dtype
print(' in_scale: {}'.format(in_scale))
print(' out_scale: {}'.format(out_scale))

if in_dtype == 'float32' and out_dtype == 'float32':
# do nothing
return data
Expand All @@ -518,10 +557,6 @@ def _realize_simulated_quantize(self, node):
data = relay.clip(data, clip_min, clip_max)
data = relay.cast(data, out_dtype)
return data
elif in_scale == out_scale and in_dtype == out_dtype:
# do nothing
# TODO(ziheng) whether to clip?
return data
else:
# requantize
dtype = in_dtype
Expand Down Expand Up @@ -552,10 +587,10 @@ def _transform_scale(self, data, in_scale, out_scale, dtype):

def use_shift(in_val, out_val):
# whether to use shift, consider floating point numeric error
in_cond, in_exp = exponent_based_two(in_val)
out_cond, out_exp = exponent_based_two(out_val)
in_cond, in_exp = exponent_based_two(in_val)
out_cond, out_exp = exponent_based_two(out_val)
if in_cond and out_cond:
return True, in_exp - out_exp
return True, in_exp - out_exp
return exponent_based_two(in_val / out_val)

factor = in_scale / out_scale
Expand Down Expand Up @@ -588,7 +623,7 @@ def use_shift(in_val, out_val):
raise ValueError
return out


class Quantizer(object):
def __init__(self, graph, hardware, topology, bits, thresholds):
self.original_graph = graph
Expand Down
8 changes: 2 additions & 6 deletions python/tvm/hago/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def fvisit_rectify(node):
return thresholds


def threshold_estimate(graph, topology, stats, bits=None, rectify=True):
def threshold_estimate(graph, topology, stats, bits=None):
print('calculating threshold...')
cfg = current_qconfig()
print('threshold method:')
Expand All @@ -84,9 +84,5 @@ def threshold_estimate(graph, topology, stats, bits=None, rectify=True):
else:
raise ValueError

print('before rectify, thresholds: {}'.format(thresholds))
if rectify:
assert bits is not None
thresholds = threshold_rectify(graph, topology, bits, thresholds)
print('after rectify, thresholds: {}'.format(thresholds))
print('thresholds: {}'.format(thresholds))
return thresholds
3 changes: 1 addition & 2 deletions tests/python/nightly/quantization/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ def eval_acc(func, dataset, batch_fn, args, var_name, target='cuda', ctx=tvm.gpu
# Quantize helper
#################
def quantize_hago(mod, params, calib_dataset):
qconfig = hago.qconfig(skip_conv_layers=[0],
log_file='temp.log')
qconfig = hago.qconfig(log_file='temp.log')

with qconfig:
graph = hago.prerequisite_optimize(mod['main'], params=params)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/nightly/quantization/test_mxnet_hago.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tvm import hago
from mxnet import gluon

from common_hago import *
from common_utils import *


parser = argparse.ArgumentParser()
Expand Down
2 changes: 1 addition & 1 deletion tests/python/nightly/quantization/test_pytorch_hago.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from tvm import hago
from mxnet import gluon

from common_hago import *
from common_utils import *

parser = argparse.ArgumentParser()
parser.add_argument("--model", default="resnet50_v1", help="model to quantize")
Expand Down
2 changes: 1 addition & 1 deletion tests/python/nightly/quantization/test_tf_hago.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tvm import hago
from mxnet import gluon

from common_hago import *
from common_utils import *

try:
# %tensorflow_version only exists in Colab.
Expand Down

0 comments on commit d916431

Please sign in to comment.