
[REFACTOR] Asymmetric Quantization: deduplicate methods (#20514)
* deduplicate Quantize->FC and FC->FC shifted quantization

* onednn->mkldnn

* review fixes

* review fix: start LOG messages with uppercase

* review fix: unit test name (shifted->asymmetric)

* review fix by bgawrych, fixes sporadic test failures
sfraczek authored Sep 2, 2021
1 parent 59e9b94 commit 42a48a7
Showing 5 changed files with 60 additions and 97 deletions.
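
Background for the diff below: in asymmetric ("shifted") quantization, signed fp32 activations are mapped onto uint8 by adding a constant shift, and the FullyConnected bias is compensated so the result is unchanged, b' = b - shift * sum_j(W[:, j]), because FC(x + shift, W, b) = (x + shift) * W + b (the identity quoted in the updated unit test). A minimal NumPy sketch of that idea; the shapes, calibration values, and the simplified bias rescale are illustrative only, not the operator implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(1, 8)).astype(np.float32)  # fp32 activations, signed range
W = rng.uniform(-0.5, 0.5, size=(4, 8)).astype(np.float32)  # FC weights, shape (out, in)
b = rng.uniform(-0.5, 0.5, size=(4,)).astype(np.float32)    # FC bias

min_data, max_data = float(x.min()), float(x.max())   # calibrated data range, min < 0
data_scale = 255.0 / (max_data - min_data)             # kUint8Range / (max - min)
shift = np.round(data_scale * -min_data)               # cached_shift_ in the quantize op

x_u8 = np.round(data_scale * x) + shift                # asymmetric uint8 data in [0, 255]

weight_scale = 127.0 / np.abs(W).max()                 # symmetric int8 weight scale
W_s8 = np.round(W * weight_scale)
bias_rescale = data_scale * weight_scale               # simplified bias_int32_rescale
b_s32 = np.round(b * bias_rescale) - shift * W_s8.sum(axis=1)  # shifted int32 bias

out_int = x_u8 @ W_s8.T + b_s32                        # integer-style accumulation
out_ref = x @ W.T + b                                  # fp32 reference
assert np.allclose(out_int / bias_rescale, out_ref, atol=0.1)
```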
3 changes: 2 additions & 1 deletion python/mxnet/contrib/quantization.py
@@ -999,7 +999,8 @@ def __exit__(self, exc_type, exc_value, traceback):
net.collect_params().load(param_name, cast_dtype=True, dtype_source='saved')
net.collect_params().reset_ctx(ctx)
if quantized_dtype == 'auto':
net.optimize_for(x=data_nd, backend="OneDNNShiftedQuantization")
mx.nd.waitall()
net.optimize_for(x=data_nd, backend="MKLDNNShiftedQuantization")
tmp_file = os.path.join(tmpdirname, 'model')
net.export(tmp_file)
net = SymbolBlock.imports(tmp_file + '-symbol.json', data_names, tmp_file + '-0000.params')
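
For context, the backend used above is the graph pass registered later in this commit, and it is off by default: the pass checks MXNET_DISABLE_SHIFTED_QUANTIZATION_OPTIMIZATIONS (default true) plus two per-pattern switches. A hedged sketch of enabling it from Python before calibration; that dmlc::GetEnv accepts a "0" string for these booleans is an assumption about your build:

```python
import os

# Enable the asymmetric-quantization graph pass (disabled by default).
os.environ['MXNET_DISABLE_SHIFTED_QUANTIZATION_OPTIMIZATIONS'] = '0'
# Per-pattern switches, both on unless explicitly disabled:
os.environ['MXNET_DISABLE_SHIFTED_QUANTIZE_FC_OPTIMIZATION'] = '0'   # Quantize->FC rewrite
os.environ['MXNET_DISABLE_SHIFTED_FC_FC_OPTIMIZATION'] = '0'         # FC->FC rewrite

# With quantized_dtype='auto', quantization.py then runs the pass as shown above:
# net.optimize_for(x=data_nd, backend="MKLDNNShiftedQuantization")
```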
121 changes: 41 additions & 80 deletions src/operator/quantization/asymmetric_quantize_graph_pass.cc
@@ -33,7 +33,7 @@ using nnvm::Graph;
using nnvm::ObjectPtr;

template <bool require_bias>
static bool IsOneDNNFullyConnected(const ObjectPtr& n) {
static bool IsMKLDNNFullyConnected(const ObjectPtr& n) {
if (n->op() == Op::Get("_sg_mkldnn_fully_connected")) {
auto const& param = nnvm::get<MKLDNNFCFullParam>(n->attrs.parsed);
FCInputIndex idx(param);
@@ -68,20 +68,19 @@ static NDArray* FindInArgByName(const Graph& g, const std::string& name) {
}

// Rescales weights, min_weight and max_weight. Returns bias_int32_rescale.
static float RescaleWeights(const Graph& g, const ObjectPtr& fc, NDArray* weight_tensor) {
FCInputIndex idx(nnvm::get<MKLDNNFCFullParam>(fc->attrs.parsed));
static float RescaleWeights(const Graph& g,
const ObjectPtr& fc,
NDArray* weight_tensor,
float min_data,
float max_data,
FCInputIndex idx) {
auto fc_input_node_name = [&fc](int input) { return fc->inputs[input].node->attrs.name; };

float* min_weight = FindInArgByName(g, fc_input_node_name(idx.weight_min))->data().dptr<float>();
float* max_weight = FindInArgByName(g, fc_input_node_name(idx.weight_max))->data().dptr<float>();
float min_bias = *FindInArgByName(g, fc_input_node_name(idx.bias_min))->data().dptr<float>();
float max_bias = *FindInArgByName(g, fc_input_node_name(idx.bias_max))->data().dptr<float>();

float* min_weight =
FindInArgByName(g, fc->inputs[idx.weight_min].node->attrs.name)->data().dptr<float>();
float* max_weight =
FindInArgByName(g, fc->inputs[idx.weight_max].node->attrs.name)->data().dptr<float>();
float min_bias =
*FindInArgByName(g, fc->inputs[idx.bias_min].node->attrs.name)->data().dptr<float>();
float max_bias =
*FindInArgByName(g, fc->inputs[idx.bias_max].node->attrs.name)->data().dptr<float>();

float min_data = std::stof(fc->inputs[idx.data].node->attrs.dict.at("min_calib_range"));
float max_data = std::stof(fc->inputs[idx.data].node->attrs.dict.at("max_calib_range"));
float data_scale_ = kUint8Range / (max_data - min_data);
float weight_scale = GetQuantizeScale(mshadow::kInt8, *min_weight, *max_weight);
float bias_scale = GetQuantizeScale(mshadow::kInt8, min_bias, max_bias);
@@ -92,9 +91,9 @@ static float RescaleWeights(const Graph& g, const ObjectPtr& fc, NDArray* weight
float bias_max_rescale =
mshadow::red::limits::MaxValue<int32_t>() / 2 / MaxAbs(min_bias, max_bias) / bias_scale;
if (bias_int32_rescale > bias_max_rescale) {
LOG(INFO) << "RESCALING WEIGHTS in shifted quantization because bias scale "
"is too big in layer "
<< fc->attrs.name;
LOG(INFO)
<< "RESCALING WEIGHTS in asymmetric quantization because bias scale is too big in layer "
<< fc->attrs.name;
// avoid overflow on bias
bias_int32_rescale = bias_max_rescale;
float weight_rescale = bias_int32_rescale * bias_scale / data_scale_ / weight_scale;
@@ -126,60 +125,23 @@ static void ShiftBias(int32_t* bias_ptr_int32,
enum class Pattern { QuantizeFc, FcFc, None };

static Pattern FindPattern(const ObjectPtr& node) {
if (IsOneDNNFullyConnected<true>(node)) {
if (IsMKLDNNFullyConnected<true>(node)) {
if (IsQuantize(node->inputs[0].node)) {
return Pattern::QuantizeFc;
} else if (IsOneDNNFullyConnected<false>(node->inputs[0].node)) {
} else if (IsMKLDNNFullyConnected<false>(node->inputs[0].node)) {
return Pattern::FcFc;
}
}
return Pattern::None;
}

static void QuantizeFcShiftedQuantization(const ObjectPtr& node,
Graph&& g,
std::vector<NDArray*>* new_arg_vector,
std::vector<std::string>* new_arg_names) {
ObjectPtr& quantize = node->inputs[0].node;
ObjectPtr& bias_node = node->inputs[2].node;
std::string bias_name_old = bias_node->attrs.name;
NDArray* bias_in_arg_ptr = FindInArgByName(g, bias_name_old);
if (bias_in_arg_ptr->dtype() != mshadow::kInt8)
return;
std::string bias_name_s32 = bias_node->attrs.name + "_s32";
bias_node = CreateNode("nullptr", bias_name_s32);
new_arg_names->push_back(bias_name_s32);

quantize->attrs.dict["shifted"] = "True";
if (quantize->op()->attr_parser)
quantize->op()->attr_parser(&(quantize->attrs));

NDArray* weight_tensor = FindInArgByName(g, node->inputs[1].node->attrs.name);

float bias_int32_rescale = RescaleWeights(g, node, weight_tensor);

new_arg_vector->push_back(new NDArray(
kDefaultStorage, bias_in_arg_ptr->shape(), Context::CPU(), false, mshadow::kInt32));
int32_t* bias_ptr_int32 = new_arg_vector->back()->data().dptr<int32_t>();
size_t bias_size = bias_in_arg_ptr->shape().Size();
int8_t* bias_ptr_old = bias_in_arg_ptr->data().dptr<int8_t>();

for (size_t i = 0; i < bias_size; ++i) {
bias_ptr_int32[i] = static_cast<int32_t>(std::round(bias_ptr_old[i] * bias_int32_rescale));
}
float min_data = std::stof(quantize->attrs.dict.at("min_calib_range"));
float max_data = std::stof(quantize->attrs.dict.at("max_calib_range"));
float data_scale = kUint8Range / (max_data - min_data);
int32_t shift_value = static_cast<int32_t>(std::round(data_scale * -min_data));
ShiftBias(bias_ptr_int32, bias_size, weight_tensor, shift_value);
}
static void FCShiftedQuantization(const ObjectPtr& node,
const Graph& g,
std::vector<NDArray*>* new_arg_vector,
std::vector<std::string>* new_arg_names) {
FCInputIndex idx(nnvm::get<MKLDNNFCFullParam>(node->attrs.parsed));

static void FcFcShiftedQuantization(const ObjectPtr& node,
Graph&& g,
std::vector<NDArray*>* new_arg_vector,
std::vector<std::string>* new_arg_names) {
ObjectPtr& first_fc = node->inputs[0].node;
ObjectPtr& bias_node = node->inputs[2].node;
ObjectPtr& bias_node = node->inputs[idx.bias].node;
std::string bias_name_old = bias_node->attrs.name;
NDArray* bias_in_arg_ptr = FindInArgByName(g, bias_name_old);
if (bias_in_arg_ptr->dtype() != mshadow::kInt8)
@@ -188,13 +150,15 @@ static void FcFcShiftedQuantization(const ObjectPtr& node,
bias_node = CreateNode("nullptr", bias_name_s32);
new_arg_names->push_back(bias_name_s32);

first_fc->attrs.dict["shifted_output"] = "True";
if (first_fc->op()->attr_parser)
first_fc->op()->attr_parser(&(first_fc->attrs));

NDArray* weight_tensor = FindInArgByName(g, node->inputs[1].node->attrs.name);
ObjectPtr& input_node = node->inputs[idx.data].node;
input_node->attrs.dict["shifted_output"] = "True";
if (input_node->op()->attr_parser)
input_node->op()->attr_parser(&(input_node->attrs));

float bias_int32_rescale = RescaleWeights(g, node, weight_tensor);
float min_data = std::stof(input_node->attrs.dict.at("min_calib_range"));
float max_data = std::stof(input_node->attrs.dict.at("max_calib_range"));
NDArray* weight_tensor = FindInArgByName(g, node->inputs[1].node->attrs.name);
float bias_int32_rescale = RescaleWeights(g, node, weight_tensor, min_data, max_data, idx);

new_arg_vector->push_back(new NDArray(
kDefaultStorage, bias_in_arg_ptr->shape(), Context::CPU(), false, mshadow::kInt32));
@@ -207,20 +171,18 @@ static void FcFcShiftedQuantization(const ObjectPtr& node,
bias_ptr_int32[i] = static_cast<int32_t>(std::round(bias_ptr_old[i] * bias_int32_rescale));
}

float min_data = std::stof(first_fc->attrs.dict.at("min_calib_range"));
float max_data = std::stof(first_fc->attrs.dict.at("max_calib_range"));
float data_scale = kUint8Range / (max_data - min_data);
int32_t shift_value = static_cast<int32_t>(std::round(data_scale * -min_data));
ShiftBias(bias_ptr_int32, bias_size, weight_tensor, shift_value);
}

static Graph OneDNNShiftedQuantization(Graph&& g) {
static Graph MKLDNNShiftedQuantization(Graph&& g) {
bool disable_shifted_quant =
dmlc::GetEnv("MXNET_DISABLE_SHIFTED_QUANTIZATION_OPTIMIZATIONS", true);
bool quantize_fc = !dmlc::GetEnv("MXNET_DISABLE_SHIFTED_QUANTIZE_FC_OPTIMIZATION", false);
bool fc_fc = !dmlc::GetEnv("MXNET_DISABLE_SHIFTED_FC_FC_OPTIMIZATION", false);
if (!disable_shifted_quant) {
LOG(INFO) << "Running OneDNN shifted quantization";
LOG(INFO) << "Running MKLDNN asymmetric quantization";
}
// No change to aux params
g.attrs["new_aux_names"] = std::make_shared<nnvm::any>(std::vector<std::string>());
Expand All @@ -238,14 +200,13 @@ static Graph OneDNNShiftedQuantization(Graph&& g) {
switch (p) {
case Pattern::QuantizeFc:
if (quantize_fc) {
QuantizeFcShiftedQuantization(
node, std::forward<Graph>(g), &new_arg_vector, &new_arg_names);
FCShiftedQuantization(node, g, &new_arg_vector, &new_arg_names);
++quantize_fc_counter;
}
break;
case Pattern::FcFc:
if (fc_fc) {
FcFcShiftedQuantization(node, std::forward<Graph>(g), &new_arg_vector, &new_arg_names);
FCShiftedQuantization(node, g, &new_arg_vector, &new_arg_names);
++fc_fc_counter;
}
break;
@@ -254,21 +215,21 @@
}
});
if (quantize_fc_counter > 0) {
LOG(INFO) << "applied shifted quantization on QUANTIZE->FC " << quantize_fc_counter
LOG(INFO) << "Applied asymmetric quantization on QUANTIZE->FC " << quantize_fc_counter
<< " times";
}
if (fc_fc_counter > 0) {
LOG(INFO) << "applied shifted quantization on FC->FC " << fc_fc_counter << " times";
LOG(INFO) << "Applied asymmetric quantization on FC->FC " << fc_fc_counter << " times";
}
}
g.attrs["new_arg_names"] = std::make_shared<nnvm::any>(new_arg_names);
g.attrs["new_args"] = std::make_shared<nnvm::any>(new_arg_vector);
return g;
}

NNVM_REGISTER_PASS(OneDNNShiftedQuantization)
.describe("Enables shifted quantization.")
.set_body(OneDNNShiftedQuantization)
NNVM_REGISTER_PASS(MKLDNNShiftedQuantization)
.describe("Enables asymmetric quantization.")
.set_body(MKLDNNShiftedQuantization)
.set_change_graph(true);

} // namespace asym_quant
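
The deduplicated RescaleWeights above keeps the int32 bias away from overflow by capping bias_int32_rescale and folding the remainder back into the weights. A small Python sketch of that decision; the scale formulas and the example numbers are simplified stand-ins for GetQuantizeScale and the calibrated ranges the real pass reads from the graph's input arguments:

```python
import numpy as np

INT32_MAX = float(np.iinfo(np.int32).max)

def plan_bias_rescale(min_data, max_data, min_weight, max_weight, min_bias, max_bias):
    data_scale = 255.0 / (max_data - min_data)                    # uint8 activation scale
    weight_scale = 127.0 / max(abs(min_weight), abs(max_weight))  # int8 weight scale
    bias_scale = 127.0 / max(abs(min_bias), abs(max_bias))        # int8 bias scale

    bias_int32_rescale = data_scale * weight_scale / bias_scale
    # Largest rescale that keeps |bias * rescale| below INT32_MAX / 2.
    bias_max_rescale = INT32_MAX / 2 / max(abs(min_bias), abs(max_bias)) / bias_scale
    weight_rescale = 1.0
    if bias_int32_rescale > bias_max_rescale:
        # Avoid overflow on the bias: cap its rescale and shrink the weights instead.
        bias_int32_rescale = bias_max_rescale
        weight_rescale = bias_int32_rescale * bias_scale / data_scale / weight_scale
    return bias_int32_rescale, weight_rescale

# Contrived layer: narrow data/weight ranges plus a large bias trigger the cap.
print(plan_bias_rescale(-0.005, 0.005, -0.001, 0.001, -1.0, 1.0))
```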
16 changes: 8 additions & 8 deletions src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
@@ -119,11 +119,11 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext& ctx,
}

// Write output min/max
auto out_type = GetQuantizeOutputType(param_);
const bool shifted = param_.shifted.has_value() && param_.shifted.value();
if (shifted) {
// if shifted == true we have guarantee that data_min is negative because
// we require that in shifted quantization pass in quantize_graph_pass
auto out_type = GetQuantizeOutputType(param_);
const bool shifted_output = param_.shifted_output.has_value() && param_.shifted_output.value();
if (shifted_output) {
// if shifted_output == true we have guarantee that data_min is negative because
// we require that in asymmetric quantization pass in quantize_graph_pass
// Modify out min/max range to reflect shifted data
out_type = mshadow::kUint8;
*outputs[1].data().dptr<float>() = 0;
@@ -142,7 +142,7 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext& ctx,
if (!initalized_) {
cached_data_min_ = data_min;
cached_data_max_ = data_max;
if (shifted) {
if (shifted_output) {
CHECK_LT(data_min, 0); // assert that we are working on signed
cached_scale_ = kUint8Range / (data_max - data_min);
cached_shift_ = static_cast<uint8_t>(std::round(cached_scale_ * -cached_data_min_));
@@ -153,7 +153,7 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext& ctx,
const int mask = 0;
std::vector<float> scales = {cached_scale_};
attr.set_output_scales(mask, scales);
if (shifted) {
if (shifted_output) {
// TODO(sfraczek): change to zero point when optimized in oneDNN
dnnl::post_ops po;
po.append_sum();
@@ -180,7 +180,7 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext& ctx,
args_[MKLDNN_ARG_TO] = *o_mem.second;
MKLDNNStream::Get()->RegisterPrimArgs(*fwd_pd_, args_);
CommitOutput(outputs[0], o_mem);
if (shifted) {
if (shifted_output) {
uint8_t* raw_out_mem = static_cast<uint8_t*>(o_mem.second->get_data_handle());
std::fill_n(raw_out_mem, outputs[0].shape().Size(), cached_shift_);
}
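
The quantize kernel above realizes the shift by pre-filling the output buffer with cached_shift_ and letting a oneDNN sum post-op add the scaled data on top (the zero-point TODO notes this is a workaround). A NumPy equivalent of that output path, assuming calibrated inputs inside [data_min, data_max] with a negative data_min:

```python
import numpy as np

def quantize_shifted(x, data_min, data_max):
    # Mirrors cached_scale_ / cached_shift_ from the operator (illustrative only).
    scale = 255.0 / (data_max - data_min)
    shift = np.uint8(round(scale * -data_min))

    out = np.full(x.shape, shift, dtype=np.float32)  # buffer pre-filled with the shift
    out += np.round(scale * x)                       # the sum post-op adds scale * x
    return np.clip(out, 0, 255).astype(np.uint8), scale, shift

x = np.array([-0.5, 0.0, 0.25, 0.5], dtype=np.float32)
q, scale, shift = quantize_shifted(x, data_min=-0.5, data_max=0.5)
print(q, shift)  # -0.5 maps to 0, +0.5 saturates at 255
```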
7 changes: 4 additions & 3 deletions src/operator/quantization/quantize_v2-inl.h
@@ -41,7 +41,7 @@ struct QuantizeV2Param : public dmlc::Parameter<QuantizeV2Param> {
int out_type;
dmlc::optional<float> min_calib_range;
dmlc::optional<float> max_calib_range;
dmlc::optional<bool> shifted;
dmlc::optional<bool> shifted_output;
DMLC_DECLARE_PARAMETER(QuantizeV2Param) {
DMLC_DECLARE_FIELD(out_type)
.add_enum("auto", QuantizeOutType::qAuto)
@@ -58,7 +58,7 @@ struct QuantizeV2Param : public dmlc::Parameter<QuantizeV2Param> {
.set_default(dmlc::optional<float>())
.describe("The maximum scalar value in the form of float32. If present, it will be used to "
"quantize the fp32 data into int8 or uint8.");
DMLC_DECLARE_FIELD(shifted)
DMLC_DECLARE_FIELD(shifted_output)
.set_default(dmlc::optional<bool>())
.describe("Whether quantization ouptut should be shifted.");
}
@@ -134,7 +134,8 @@ static inline bool QuantizeV2Type(const nnvm::NodeAttrs &attrs, std::vector<int>
CHECK(in_attrs->at(0) == mshadow::kFloat32 || in_attrs->at(0) == mshadow::kUint8 ||
in_attrs->at(0) == mshadow::kInt8);
auto out_type = GetQuantizeOutputType(param);
if (out_type == mshadow::kUint8 || (param.shifted.has_value() && param.shifted.value())) {
if (out_type == mshadow::kUint8 ||
(param.shifted_output.has_value() && param.shifted_output.value())) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8);
} else if (out_type == mshadow::kInt8) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt8);
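
The type-inference change above simply forces a uint8 output whenever shifted_output is set, regardless of the configured out_type. A one-function restatement of that rule (a sketch; 'auto' resolution by GetQuantizeOutputType happens before this check):

```python
def quantized_out_dtype(out_type, shifted_output):
    # uint8 whenever the output is shifted, otherwise follow the resolved out_type.
    if out_type == 'uint8' or shifted_output:
        return 'uint8'
    if out_type == 'int8':
        return 'int8'
    raise ValueError(out_type)
```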
10 changes: 5 additions & 5 deletions tests/python/quantization/test_quantization.py
@@ -1255,7 +1255,7 @@ def get_threshold(nd):


@with_seed()
def test_onednn_shifted_quantize_fc():
def test_mkldnn_asymmetric_quantize_fc():
batch_size = 1
if not is_test_for_mkldnn():
print("Test only for mkldnn")
@@ -1309,7 +1309,7 @@ def get_fc_layer():
fc_layer.initialize()
return fc_layer

# Shifted quantization should set new bias to FC and add shift to output of quantize
# Asymmetric quantization should set new bias to FC and add shift to output of quantize
# b'=b-shift*w because FC(x+shift,w,b)=(x+shift)*w+b
def check(number, qdtype):
random_data = mx.nd.random_uniform(low=0 if qdtype == 'uint8' else -1, high=1, shape=(batch_size, 32))
@@ -1333,13 +1333,13 @@ def check(number, qdtype):
assert_almost_equal_with_err(out_q.asnumpy(), out.asnumpy(), rtol=0.1, atol=atol, etol=0.2)

if qdtype == 'auto':
assert quantize_attrs['shifted'] == 'True'
assert quantize_attrs['shifted_output'] == 'True'
bias_s32 = collect_param(fc_layer_quantized, 'dense%d_bias_quantize_s32' % number)
assert bias_s32.dtype == np.int32
bias_shifted = get_shifted_bias(quantize_attrs, weights_int8, weights_scale, bias_int8, bias_scale)
assert_almost_equal(bias_s32, bias_shifted, rtol=1e-3, atol=1e-3)
else:
assert 'shifted' not in quantize_attrs
assert 'shifted_output' not in quantize_attrs
bias = collect_param(fc_layer_quantized, 'dense%d_bias_quantize' % number)
assert bias.dtype == np.int8

@@ -1349,7 +1349,7 @@


@with_seed()
def test_onednn_shifted_quantize_fc_fc():
def test_mkldnn_asymmetric_quantize_fc_fc():
batch_size = 2
if not is_test_for_mkldnn():
print("Test only for mkldnn")
