[Mxnet-1397] Support symbolic api for requantize and dequantize #14749
Conversation
```python
sym_min_range = mx.sym.Variable('min_range')
sym_max_range = mx.sym.Variable('max_range')
dequant = mx.sym.contrib.dequantize(sym_data, sym_min_range,
                                    sym_max_range, out_type='float32')
```
nit: indent?
done
```python
sym_max_range = mx.sym.Variable('max_range')
if min_calib_range is None or max_calib_range is None:
    requant = mx.sym.contrib.requantize(sym_data, sym_min_range, sym_max_range)
    out = requant.bind(ctx=mx.cpu(), args={'data': qdata, 'min_range': min_range,
                                           'max_range': max_range})
```
use ctx=mx.current_context() so this test can cover both CPU and GPU computation?
done
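As a sketch, the suggested binding would look like this (variable names taken from the test snippet above):

```python
# Bind on the active context so the same test exercises CPU or GPU,
# depending on where the test suite runs.
out = requant.bind(ctx=mx.current_context(),
                   args={'data': qdata, 'min_range': min_range,
                         'max_range': max_range})
```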
```python
sym_max_range = mx.sym.Variable('max_range')
dequant = mx.sym.contrib.dequantize(sym_data, sym_min_range,
                                    sym_max_range, out_type='float32')
out = dequant.bind(ctx=mx.cpu(), args={'data': qdata, 'min_range': min_range,
                                       'max_range': max_range})
```
use ctx=mx.current_context() so this test can cover both CPU and GPU computation?
done
```python
dequant = mx.sym.contrib.dequantize(sym_data, sym_min_range,
                                    sym_max_range, out_type='float32')
out = dequant.bind(ctx=mx.cpu(), args={'data': qdata, 'min_range': min_range,
                                       'max_range': max_range})
```
indent?
done
```python
else:
    requant = mx.sym.contrib.requantize(sym_data, sym_min_range, sym_max_range,
                                        min_calib_range=min_calib_range,
                                        max_calib_range=max_calib_range)
    out = requant.bind(ctx=mx.cpu(), args={'data': qdata, 'min_range': min_range,
                                           'max_range': max_range})
```
use ctx=mx.current_context() so this test can cover both CPU and GPU computation?
done
Thanks a lot for your contribution! LGTM overall. I left a few small comments. Please resolve the conflict with master and update the PR.
```cpp
@@ -84,6 +84,10 @@ by keep zero centered for the quantized value:
.set_attr_parser(ParamParser<DequantizeParam>)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"data", "min_range", "max_range"};
  })
```
If these names will be exposed to front-end users, I hope they can align with the other quantization operators. In quantized convolution and quantized FC, I see they are `min_data` and `max_data`.
Since the names are documented well in most of the quantized ops, I think it should be ok. Especially in quantized conv and FC, where there are many quantized parameters, I think it is easier to understand the API with `min_data` and `max_data`.
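For reference, a quick sketch of how the registered names surface on the front end (assuming the registration above; the printed list is the expected result):

```python
import mxnet as mx

# FListInputNames gives the operator's inputs stable names, so they can be
# listed and bound by name from the symbolic API.
sym_data = mx.sym.Variable('data')
sym_min_range = mx.sym.Variable('min_range')
sym_max_range = mx.sym.Variable('max_range')
dequant = mx.sym.contrib.dequantize(sym_data, sym_min_range, sym_max_range,
                                    out_type='float32')
print(dequant.list_arguments())  # ['data', 'min_range', 'max_range']
```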
```cpp
@@ -61,6 +61,10 @@ inference accuracy.
.set_attr_parser(ParamParser<RequantizeParam>)
.set_num_inputs(3)
.set_num_outputs(3)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"data", "min_range", "max_range"};
  })
```
ditto
as above.
LGTM
Thank you for the contribution, @shoubhik. LGTM.
…he#14749)
* Adding support for symbolic API for requantize and dequantize
* Adding name to contributors list
* Removing redundant code
* Addressing indentation and using current_context() instead of cpu()
* merge from master
* merge from master
Description
If I have a pre-quantized model from another framework, e.g. TensorFlow or PyTorch, and want to use MXNet as the inference engine, I would like to be able to set the int8 weights, scales, and shifts manually instead of having MXNet convert the model for me. I would like to do so in an `nn.HybridBlock`: for example, I can create a quantized convolution network like the sketch below and later set the weights and ranges from my saved model.
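A minimal sketch of the kind of block this enables (illustrative only; the original example was not captured in this excerpt, and `DequantBlock` is a hypothetical name):

```python
import mxnet as mx
from mxnet.gluon import nn

class DequantBlock(nn.HybridBlock):
    """Dequantize int8 data back to float32 inside a HybridBlock."""
    def hybrid_forward(self, F, data, min_range, max_range):
        # After hybridize(), F resolves to mx.sym, the code path this PR fixes.
        return F.contrib.dequantize(data, min_range, max_range,
                                    out_type='float32')

block = DequantBlock()
block.hybridize()  # previously failed here for dequantize/requantize
qdata = mx.nd.array([[120, -50], [3, 7]], dtype='int8')
out = block(qdata, mx.nd.array([-1.0]), mx.nd.array([1.0]))
```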
Currently `F.contrib.requantize(...)` and `F.contrib.dequantize(...)` fail when `F` is a symbol, i.e. after calling `hybridize()` on the network. For a more detailed error example, please look at the JIRA. In this CR I am enabling `requantize` and `dequantize` to be called from the symbolic API.

Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
- `mx.sym.contrib.dequantize` to be called directly, tests, (and when applicable, API doc)
- `mx.sym.contrib.requantize` to be called directly, tests, (and when applicable, API doc)

Comments