Add FP requantize flow for llvm target
Icemist committed Dec 13, 2021
1 parent 5557b8c commit 18824b4
Showing 5 changed files with 358 additions and 12 deletions.
38 changes: 38 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
@@ -92,6 +92,44 @@ def requantize(
)


def upward(data):
    r"""Upward operator.

    UPWARD is the standard rounding, except at midpoints, where the value
    is rounded toward positive infinity (for example, -1.5 rounds to -1).

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """

    return _make.upward(data)


def tonearest(data):
    r"""Tonearest operator.

    TONEAREST is the standard rounding where the value is rounded away
    from zero at midpoints (for example, -1.5 rounds to -2).

    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """

    return _make.tonearest(data)
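The two conventions differ only at exact midpoints. A minimal NumPy sketch of the two rules (a reference for the semantics, not the Relay implementation):

```python
import numpy as np

def upward_ref(x):
    # Round half toward positive infinity: floor(x + 0.5).
    return np.floor(x + 0.5)

def tonearest_ref(x):
    # Round half away from zero: sign(x) * floor(|x| + 0.5).
    return np.sign(x) * np.floor(np.abs(x) + 0.5)

x = np.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])
print(upward_ref(x))     # [-2. -1.  0.  1.  2.  3.]
print(tonearest_ref(x))  # [-3. -2. -1.  1.  2.  3.]
```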


def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
    r"""Quantize op
    This operator takes float32 as input and produces quantized int8 or uint8 as output.
27 changes: 25 additions & 2 deletions python/tvm/topi/x86/utils.py
@@ -16,8 +16,26 @@
# under the License.
"""Common x86 related utilities"""
import tvm
import tvm._ffi


@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse41")
def target_has_sse41(target):
    return (
        target_has_sse42(target)
        or target_has_avx(target)
        or target_has_avx2(target)
        or target_has_avx512(target)
        or target_has_vnni(target)
        or target
        in {
            "btver2",
            "penryn",
        }
    )
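The new `@tvm._ffi.register_func` decorators expose these predicates through TVM's global function registry, which is how the C++ requantize lowering below reaches them. A short usage sketch ("penryn" is taken from the SSE4.1 set above):

```python
import tvm
from tvm.topi.x86.utils import target_has_sse41

# Direct Python use.
assert target_has_sse41("penryn")

# The same predicate resolved through the global registry, as the C++ side does.
f = tvm.get_global_func("tvm.topi.x86.utils.target_has_sse41")
assert f("penryn")
```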


@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse42")
def target_has_sse42(target):
    return (
        target_has_avx(target)
@@ -42,6 +60,7 @@ def target_has_sse42(target):
    )


@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx")
def target_has_avx(target):
    return (
        target_has_avx2(target)
@@ -51,6 +70,7 @@ def target_has_avx(target):
    )


@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx2")
def target_has_avx2(target):
    return (
        target_has_avx512(target)
@@ -70,6 +90,7 @@ def target_has_avx2(target):
    )


@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx512")
def target_has_avx512(target):
    return target in {
        "skylake-avx512",
@@ -82,26 +103,28 @@ def target_has_avx512(target):
        "cascadelake",
        "icelake-client",
        "rocketlake",
        "icelake",
        "icelake-server",
        "tigerlake",
        "cooperlake",
        "sapphirerapids",
    }


@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_vnni")
def target_has_vnni(target):
    return target in {
        "cascadelake",
        "icelake-client",
        "rocketlake",
        "icelake",
        "icelake-server",
        "tigerlake",
        "cooperlake",
        "sapphirerapids",
        "alderlake",
    }


@tvm._ffi.register_func("tvm.topi.x86.utils.get_simd_32bit_lanes")
def get_simd_32bit_lanes():
    mcpu = tvm.target.Target.current().mcpu
    fp32_vec_len = 4
233 changes: 228 additions & 5 deletions src/relay/qnn/op/requantize.cc
@@ -26,6 +26,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>

#include "../../op/op_common.h"
#include "../../transforms/infer_layout_utils.h"
#include "../../transforms/pattern_utils.h"
#include "../utils.h"
@@ -111,6 +112,107 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs,
  return InferCorrectLayoutOutput(input_layouts, output_layouts, Attrs(param));
}

bool has_current_target_sse41_support() {
  auto target = Target::Current(true);
  Optional<String> mcpu =
      target.defined() ? target->GetAttr<String>("mcpu") : Optional<String>(nullptr);
  auto target_has_sse41_fn_ptr = tvm::runtime::Registry::Get("tvm.topi.x86.utils.target_has_sse41");
  ICHECK(target_has_sse41_fn_ptr) << "Function tvm.topi.x86.utils.target_has_sse41 not found";
  return mcpu && (*target_has_sse41_fn_ptr)(mcpu.value());
}
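A rough Python equivalent of this helper (an illustration of the control flow, not part of the change) shows how the target's `mcpu` attribute feeds the registered predicate:

```python
import tvm

def has_current_target_sse41_support():
    # Mirror of the C++ helper: read mcpu from the current target, then ask
    # the registered Python predicate.
    target = tvm.target.Target.current(allow_none=True)
    mcpu = target.mcpu if target is not None else None
    f = tvm.get_global_func("tvm.topi.x86.utils.target_has_sse41")
    return bool(mcpu) and bool(f(mcpu))

with tvm.target.Target("llvm -mcpu=penryn"):
    assert has_current_target_sse41_support()
```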

/*
 * \brief TONEAREST is the standard rounding where the value is rounded away
 * from zero at midpoints (for example, -1.5 rounds to -2).
 * \param input_tensor The input tensor to rounding op.
 * \return The sequence of existing Relay ops.
 */
Expr Tonearest(const Expr& input_tensor) {
  if (has_current_target_sse41_support()) return Round(input_tensor);

  // Without SSE4.1 there is no native rounding instruction, so emulate:
  // add +/-0.5 with the sign of the input, mirror into the positive range,
  // truncate toward zero via an int64 cast, then mirror back.
  auto half = MakeConstantScalar(DataType::Float(64), 0.5f);
  auto zero = MakeConstantScalar(DataType::Float(64), 0.f);
  auto pos_one = MakeConstantScalar(DataType::Float(64), +1.f);
  auto neg_one = MakeConstantScalar(DataType::Float(64), -1.f);
  auto multiplier = Where(Less(input_tensor, zero), neg_one, pos_one);
  auto half_multiplied = Multiply(half, multiplier);
  auto input_tensor_biased = Add(input_tensor, half_multiplied);
  auto input_tensor_biased_multiplied = Multiply(input_tensor_biased, multiplier);
  auto input_tensor_biased_multiplied_int64 =
      Cast(input_tensor_biased_multiplied, DataType::Int(64));
  auto input_tensor_biased_multiplied_float64 =
      Cast(input_tensor_biased_multiplied_int64, DataType::Float(64));
  auto input_tensor_rounded = Multiply(input_tensor_biased_multiplied_float64, multiplier);
  // The int64 round-trip is only valid for finite values; pass NaN/Inf through.
  return Where(IsFinite(input_tensor), input_tensor_rounded, input_tensor);
}
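The emulation is easiest to check in NumPy form; a sketch of the same op sequence (illustration only):

```python
import numpy as np

def tonearest_emulated(x):
    mult = np.where(x < 0.0, -1.0, 1.0)        # Where(Less(x, 0), -1, +1)
    biased = (x + 0.5 * mult) * mult           # |x| + 0.5
    truncated = biased.astype(np.int64).astype(np.float64)  # int64 cast truncates
    rounded = truncated * mult                 # restore the sign
    return np.where(np.isfinite(x), rounded, x)

print(tonearest_emulated(np.array([-1.5, -0.5, 0.5, 1.5, 2.49])))
# [-2. -1.  1.  2.  2.]
```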

/*
 * \brief UPWARD is the standard rounding except at midpoints where the value
 * is rounded toward positive infinity (for example, -1.5 rounds to -1).
 * \param input_tensor The input tensor to rounding op.
 * \return The sequence of existing Relay ops.
 */
Expr Upward(const Expr& input_tensor) {
  // UPWARD is floor(x + 0.5).
  auto half = MakeConstantScalar(DataType::Float(64), 0.5f);
  auto input_tensor_biased = Add(input_tensor, half);
  if (has_current_target_sse41_support()) return Floor(input_tensor_biased);

  // Without SSE4.1, emulate floor: an int64 cast truncates toward zero, which
  // equals floor unless the value is negative and non-integral, in which case
  // subtract one.
  auto zero = MakeConstantScalar(DataType::Float(64), 0.f);
  auto one = MakeConstantScalar(DataType::Float(64), +1.f);
  auto input_tensor_biased_int64 = Cast(input_tensor_biased, DataType::Int(64));
  auto input_tensor_biased_float64 = Cast(input_tensor_biased_int64, DataType::Float(64));
  auto is_subtraction_not_necessary =
      LogicalOr(Equal(input_tensor_biased, input_tensor_biased_float64),
                GreaterEqual(input_tensor_biased, zero));
  auto input_tensor_rounded = Where(is_subtraction_not_necessary, input_tensor_biased_float64,
                                    Subtract(input_tensor_biased_float64, one));
  // The int64 round-trip is only valid for finite values; pass NaN/Inf through.
  return Where(IsFinite(input_tensor), input_tensor_rounded, input_tensor);
}
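Likewise for the floor emulation; a NumPy sketch (illustration only):

```python
import numpy as np

def upward_emulated(x):
    biased = x + 0.5
    truncated = biased.astype(np.int64).astype(np.float64)  # truncates toward zero
    # Truncation already equals floor unless the value is negative and
    # non-integral; otherwise subtract one.
    no_fix = (biased == truncated) | (biased >= 0.0)
    rounded = np.where(no_fix, truncated, truncated - 1.0)
    return np.where(np.isfinite(x), rounded, x)

print(upward_emulated(np.array([-2.5, -1.5, -1.2, 0.5, 1.5])))
# [-2. -1. -1.  1.  2.]
```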

// Positional relay function to create tonearest operator
// used by frontend FFI.
Expr MakeTonearest(const Attrs& attrs, const Array<Expr>& new_args,
                   const Array<tvm::relay::Type>& types) {
  ICHECK_EQ(new_args.size(), 1);
  auto& data = new_args[0];
  return Tonearest(data);
}

RELAY_REGISTER_OP("tonearest")
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_type_rel("Identity", IdentityRel)
    .set_attr<TOpPattern>("TOpPattern", kElemWise)
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", MakeTonearest);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.tonearest").set_body_typed([](Expr data) {
  static const Op& op = Op::Get("tonearest");
  return Call(op, {data}, Attrs(), {});
});

// Positional relay function to create upward operator
// used by frontend FFI.
Expr MakeUpward(const Attrs& attrs, const Array<Expr>& new_args,
                const Array<tvm::relay::Type>& types) {
  ICHECK_EQ(new_args.size(), 1);
  auto& data = new_args[0];
  return Upward(data);
}

RELAY_REGISTER_OP("upward")
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_type_rel("Identity", IdentityRel)
    .set_attr<TOpPattern>("TOpPattern", kElemWise)
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", MakeUpward);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.upward").set_body_typed([](Expr data) {
  static const Op& op = Op::Get("upward");
  return Call(op, {data}, Attrs(), {});
});

// Lowering of qnn.requantize op

/*
@@ -119,7 +221,7 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs,
 * \param param The requantize op attrs.
 * \param input_shape The input tensor shape of the requantize op.
 * \return The sequence of existing Relay ops.
 * \note Requantization using only integer computation. Here, the computation is
 * \note RequantizationInt uses only integer computation. Here, the computation is
 *       converted to a fixed point computation by computing output multiplier
 *       and shift. This is useful if the target device does not support
 *       floating point, or if floating-point computations are overly expensive.
@@ -131,10 +233,10 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs,
 * 4) Add the output zero point.
 * 5) Cast to the out_dtype.
 */
Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                     const Expr& input_zero_point, const Expr& output_scale,
                     const Expr& output_zero_point, const RequantizeAttrs* param,
                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
Expr RequantizeLowerInt(const Expr& input_tensor, const Expr& input_scale,
                        const Expr& input_zero_point, const Expr& output_scale,
                        const Expr& output_zero_point, const RequantizeAttrs* param,
                        const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
  auto tensor = Cast(input_tensor, DataType::Int(32));
  auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
  if (!IsEqualScalar(input_zero_point, zero_scalar)) {
Expand Down Expand Up @@ -208,6 +310,127 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
  return Cast(clipped_t, out_dtype);
}
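For contrast with the floating-point flow added below, a much-simplified NumPy sketch of the integer path above, assuming per-tensor scales and UPWARD rounding (the real lowering goes through Relay's fixed-point-multiply ops and handles many more cases):

```python
import math
import numpy as np

def requantize_int_ref(x, in_scale, in_zp, out_scale, out_zp, qmin=-128, qmax=127):
    t = x.astype(np.int64) - in_zp                   # 1) subtract input zero point
    m = in_scale / out_scale
    frac, exp = math.frexp(m)                        # m = frac * 2**exp, frac in [0.5, 1)
    q_mult = int(round(frac * (1 << 31)))            # 2) Q31 fixed-point multiplier
    shift = 31 - exp
    t = (t * q_mult + (1 << (shift - 1))) >> shift   # 3) multiply, round UPWARD
    t = t + out_zp                                   # 4) add output zero point
    return np.clip(t, qmin, qmax).astype(np.int8)    # 5) clamp, cast to out_dtype

x = np.array([-50, 0, 50], dtype=np.int32)
print(requantize_int_ref(x, in_scale=0.1, in_zp=0, out_scale=0.05, out_zp=10))
# [-90  10 110]
```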

// Lowering of qnn.requantize op

/*
 * \brief Lower requantize to a sequence of ops.
 * \param input_tensor The input tensor to requantize op.
 * \param param The requantize op attrs.
 * \param input_shape The input tensor shape of the requantize op.
 * \return The sequence of existing Relay ops.
 * \note RequantizationFP uses floating-point computation. All multiplications,
 *       subtractions, and additions occur in float64; only at the end is the
 *       result converted to int32 and clamped to the output data type.
 *
 *       The whole computation can be broken down into the following steps:
 *       1) Subtract the input zero point.
 *       2) Perform multiplication.
 *       3) Add the output zero point.
 *       4) Cast to the out_dtype.
 */
Expr RequantizeLowerFP(const Expr& input_tensor, const Expr& input_scale,
                       const Expr& input_zero_point, const Expr& output_scale,
                       const Expr& output_zero_point, const RequantizeAttrs* param,
                       const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
  auto tensor = Cast(input_tensor, DataType::Float(64));
  auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
  // 1) Subtract the input zero point.
  if (!IsEqualScalar(input_zero_point, zero_scalar)) {
    // Broadcast input zero point if needed.
    int rank = static_cast<int>(input_shape.size());
    int axis = (param->axis < 0) ? ((rank > 0) ? rank + param->axis : 0) : param->axis;
    Expr input_zero_broadcast = ExpandBiasToMatchAxis(Reshape(input_zero_point,
                                                              {
                                                                  -1,
                                                              }),
                                                      rank, {axis});
    tensor = Subtract(tensor, Cast(input_zero_broadcast, DataType::Float(64)));
  }

  // 2) If the input and output scales are the same, the multiplication can be
  // skipped. Check whether the input scale is per-tensor or per-channel. For
  // per-tensor quantization there is a single scale for the whole tensor; for
  // per-channel (aka per-axis), there is a vector of scales for the input
  // tensor. Depending on the quantization type, a scalar or per-channel
  // float64 multiplication is emitted.
  auto scaled_fp64_t = tensor;
  double output_scale_float = GetScalarFromConstant<float>(output_scale);
  if (IsConstScalar(input_scale)) {
    // This is per-tensor quantization. Single scale.
    double input_scale_float = GetScalarFromConstant<float>(input_scale);
    double double_multiplier = input_scale_float / output_scale_float;
    // Skip if input and output scales are the same.
    if (!IsEqualScalar(input_scale, output_scale)) {
      double multiplier = double_multiplier;
      auto m_scalar = MakeConstantScalar(DataType::Float(64), multiplier);
      scaled_fp64_t = Multiply(m_scalar, scaled_fp64_t);
    }

  } else {
    // This is per-channel (per-axis) quantization.
    std::vector<double> double_multipliers;
    auto input_axis_scales = GetFloatVectorFromConstant(input_scale);
    double output_scale_float = GetScalarFromConstant<float>(output_scale);
    for (auto input_axis_scale : input_axis_scales) {
      double multiplier = static_cast<double>(input_axis_scale) / output_scale_float;
      double_multipliers.push_back(multiplier);
    }
    int axis = param->axis;
    axis = (axis == -1) ? input_shape.size() - 1 : axis;

    auto fixed_pt_multiplier_expr = MakeConstantTensor(
        DataType::Float(64), {(int64_t)double_multipliers.size()}, double_multipliers);
    size_t n_dim = input_shape.size();
    auto exp_fixed_pt_multiplier_expr =
        ExpandBiasToMatchAxis(fixed_pt_multiplier_expr, n_dim, {axis});

    scaled_fp64_t = Multiply(scaled_fp64_t, exp_fixed_pt_multiplier_expr);
  }

  // 3) Add the output zero point.
  auto shifted_fp64_t = scaled_fp64_t;
  if (!IsEqualScalar(output_zero_point, zero_scalar)) {
    shifted_fp64_t = Add(shifted_fp64_t, Cast(output_zero_point, DataType::Float(64)));
  }

  if (param->rounding == "UPWARD") {
    shifted_fp64_t = Upward(shifted_fp64_t);
  } else /*if (param->rounding == "TONEAREST")*/ {
    shifted_fp64_t = Tonearest(shifted_fp64_t);
  }

  shifted_fp64_t = Cast(shifted_fp64_t, DataType::Int(32));
  // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32: the
  // rounded float64 result was just cast to int32 above.
  if (out_dtype == DataType::Int(32)) {
    return shifted_fp64_t;
  }

  auto q_min = GetQmin(out_dtype);
  auto q_max = GetQmax(out_dtype);
  auto clipped_t = Clip(shifted_fp64_t, q_min, q_max);
  return Cast(clipped_t, out_dtype);
}
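A per-tensor NumPy sketch of this floating-point flow, assuming UPWARD rounding (illustration only; note it agrees with the fixed-point sketch above on the same inputs):

```python
import numpy as np

def requantize_fp_ref(x, in_scale, in_zp, out_scale, out_zp, qmin=-128, qmax=127):
    t = x.astype(np.float64) - in_zp        # 1) subtract input zero point
    t = t * (in_scale / out_scale)          # 2) single float64 multiply
    t = t + out_zp                          # 3) add output zero point
    t = np.floor(t + 0.5)                   # UPWARD rounding
    t = t.astype(np.int32)
    return np.clip(t, qmin, qmax).astype(np.int8)  # 4) clamp, cast to out_dtype

x = np.array([-50, 0, 50], dtype=np.int32)
print(requantize_fp_ref(x, in_scale=0.1, in_zp=0, out_scale=0.05, out_zp=10))
# [-90  10 110]
```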

// Lowering of qnn.requantize op
/*
 * \brief Lower requantize to a sequence of ops.
 * \param input_tensor The input tensor to requantize op.
 * \param param The requantize op attrs.
 * \param input_shape The input tensor shape of the requantize op.
 * \return The sequence of existing Relay ops.
 */
Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                     const Expr& input_zero_point, const Expr& output_scale,
                     const Expr& output_zero_point, const RequantizeAttrs* param,
                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
  auto target = Target::Current(true);
  // Use the float64 flow only when compiling for llvm; all other targets keep
  // the integer fixed-point flow.
  if (target.defined() && target->kind->name == "llvm") {
    return RequantizeLowerFP(input_tensor, input_scale, input_zero_point, output_scale,
                             output_zero_point, param, input_shape, out_dtype);
  } else {
    return RequantizeLowerInt(input_tensor, input_scale, input_zero_point, output_scale,
                              output_zero_point, param, input_shape, out_dtype);
  }
}
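Which flow is taken is decided by `Target::Current` when the op is lowered, so the selection happens inside the usual build context. A hedged end-to-end sketch:

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(4,), dtype="int32")
y = relay.qnn.op.requantize(
    x,
    input_scale=relay.const(0.1, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.05, "float32"),
    output_zero_point=relay.const(10, "int32"),
    out_dtype="int8",
)
mod = tvm.IRModule.from_expr(relay.Function([x], y))

# Building for llvm selects the new float64 flow; any other target kind keeps
# the integer fixed-point flow.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -mcpu=penryn")
```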

/*
* \brief Forward rewrite the requantize op.
* \param ref_call The original call that will be lowered.
