[QNN EP] Initial INT4 support (#21171)
### Description
- Adds support for int4 quantized weights (per-tensor and per-channel)
on QNN EP
- Adds a test script that creates an INT4 QDQ model with a Conv
- Adds unit tests demonstrating accuracy issues.



### Motivation and Context
This is the next step in being able to run models that use 4-bit
quantized weights on QNN EP.
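
For context, a minimal sketch of the signed-int4 quantize/dequantize arithmetic such a QDQ Conv weight relies on (illustrative values only, not taken from the test script):

```cpp
// Illustrative sketch of ONNX QuantizeLinear/DequantizeLinear math for signed int4.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const float scale = 0.02f;  // per-tensor scale
  const int zero_point = 1;   // int4 zero-point, range [-8, 7]
  const int w_int4 = -8;      // most negative int4 weight value

  // DequantizeLinear: w_float = (w_int4 - zero_point) * scale
  const float w_float = (w_int4 - zero_point) * scale;  // (-8 - 1) * 0.02f = -0.18f
  std::cout << "dequantized: " << w_float << "\n";

  // QuantizeLinear: round, add zero-point, saturate to the int4 range [-8, 7]
  // (round-to-nearest here; the ONNX spec uses round-half-to-even)
  const int requant = std::clamp(static_cast<int>(std::lround(w_float / scale)) + zero_point, -8, 7);
  std::cout << "re-quantized: " << requant << "\n";  // -8
  return 0;
}
```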
adrianlizarraga authored Jul 10, 2024
1 parent 1b82d83 commit 5753f8d
Showing 12 changed files with 522 additions and 42 deletions.
onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc
@@ -127,7 +127,8 @@ Status ConvOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
int32_t elem_data_type = 0;
ORT_RETURN_IF_ERROR(utils::GetOnnxTensorElemDataType(input_1.node_arg, elem_data_type));

const bool is_signed_type = (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) ||
const bool is_signed_type = (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) ||
(elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT8) ||
(elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT16);
ORT_RETURN_IF_NOT(is_signed_type, "Conv weights must be of a signed quantized type if quantized per-channel");

39 changes: 35 additions & 4 deletions onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc
@@ -5,6 +5,8 @@
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <utility>
#include <vector>

#include "qnn_model_wrapper.h"
#include "core/common/safeint.h"
@@ -313,7 +315,8 @@ bool QnnModelWrapper::GetOnnxShape(const NodeArg& node_arg, std::vector<uint32_t
}

Status QnnModelWrapper::UnpackZeroPoints(const std::string& initializer_name,
std::vector<int32_t>& zero_points) const {
/*out*/ std::vector<int32_t>& zero_points,
/*out*/ int32_t& onnx_data_type) const {
const auto& graph_initializers = GetInitializerTensors();
auto iter = graph_initializers.find(initializer_name);
ORT_RETURN_IF(iter == graph_initializers.end(), "Unable to find initializer for zero-point(s): ",
@@ -323,13 +326,14 @@ Status QnnModelWrapper::UnpackZeroPoints(const std::string& initializer_name,
ORT_RETURN_IF_NOT(zp_tensor_proto->has_data_type(), "Expected zero-point initializer ", initializer_name.c_str(),
" to have a proto data type.");

const int32_t onnx_data_type = zp_tensor_proto->data_type();
onnx_data_type = zp_tensor_proto->data_type();
std::vector<uint8_t> initializer_bytes;

ORT_RETURN_IF_ERROR(UnpackInitializerData(*zp_tensor_proto, initializer_bytes));

switch (onnx_data_type) {
// QNN uses -offset for some reason
case ONNX_NAMESPACE::TensorProto_DataType_INT4: // INT4 zero-points are unpacked as 8-bit values for QNN
case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
auto int8_span = ReinterpretAsSpan<const int8_t>(gsl::make_span(initializer_bytes));
std::transform(int8_span.begin(), int8_span.end(), std::back_inserter(zero_points),
@@ -338,6 +342,7 @@ Status QnnModelWrapper::UnpackZeroPoints(const std::string& initializer_name,
});
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT4: // UINT4 zero-points are unpacked as 8-bit values for QNN
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
auto uint8_span = ReinterpretAsSpan<const uint8_t>(gsl::make_span(initializer_bytes));
std::transform(uint8_span.begin(), uint8_span.end(), std::back_inserter(zero_points),
@@ -584,10 +589,36 @@ void QnnModelWrapper::GetGraphInputOutputTensorWrapper(const std::vector<std::st
Status QnnModelWrapper::UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer,
std::vector<uint8_t>& unpacked_tensor) const {
if (initializer.data_location() == onnx::TensorProto_DataLocation_EXTERNAL) {
return onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer_.ModelPath(), unpacked_tensor);
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer_.ModelPath(),
unpacked_tensor));
} else {
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor));
}

int32_t onnx_data_type = initializer.data_type();

// If this is an int4, we need to unpack it because QNN treats int4 as a full int8.
if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) {
TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer);
const size_t num_elems = shape.Size();
std::vector<uint8_t> packed_int4_bytes = std::move(unpacked_tensor);
unpacked_tensor = std::vector<uint8_t>(num_elems);

auto dst = gsl::make_span(reinterpret_cast<int8_t*>(unpacked_tensor.data()), unpacked_tensor.size());
auto src = gsl::make_span(reinterpret_cast<const Int4x2*>(packed_int4_bytes.data()), packed_int4_bytes.size());
ORT_RETURN_IF_NOT(Int4x2::Unpack(dst, src), "Failed to unpack Tensor<Int4x2> for QNN");
} else if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer);
const size_t num_elems = shape.Size();
std::vector<uint8_t> packed_int4_bytes = std::move(unpacked_tensor);
unpacked_tensor = std::vector<uint8_t>(num_elems);

auto dst = gsl::make_span(reinterpret_cast<uint8_t*>(unpacked_tensor.data()), unpacked_tensor.size());
auto src = gsl::make_span(reinterpret_cast<const UInt4x2*>(packed_int4_bytes.data()), packed_int4_bytes.size());
ORT_RETURN_IF_NOT(UInt4x2::Unpack(dst, src), "Failed to unpack Tensor<UInt4x2> for QNN");
}

return onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor);
return Status::OK();
}

} // namespace qnn
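The unpack path above leans on two conventions: ONNX packs two int4 elements per byte (first element in the low nibble), and QNN consumes each value widened to a full int8, with the zero-point stored negated as the QNN offset. A minimal sketch of both steps, assuming that packing order; `Int4x2::Unpack` in ONNX Runtime is the authoritative implementation:

```cpp
// Sketch: widen packed signed-int4 pairs to int8, and negate a zero-point for QNN.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int8_t> UnpackInt4(const std::vector<uint8_t>& packed, size_t num_elems) {
  std::vector<int8_t> out(num_elems);
  for (size_t i = 0; i < num_elems; ++i) {
    const uint8_t byte = packed[i / 2];
    const uint8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);  // element 0 in the low nibble
    out[i] = static_cast<int8_t>(nibble >= 8 ? static_cast<int>(nibble) - 16 : nibble);  // sign-extend 4 -> 8 bits
  }
  return out;
}

// QNN stores the negated ONNX zero-point as its offset (see UnpackZeroPoints above).
int32_t ToQnnOffset(int8_t onnx_zero_point) { return -static_cast<int32_t>(onnx_zero_point); }
```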
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h
@@ -216,7 +216,9 @@ class QnnModelWrapper {
Status UnpackScales(const std::string& initializer_name, std::vector<float>& scales) const;

// Unpack zero-points from initializer and convert to int32_t (1 zero-point for per-tensor, > 1 for per-channel).
Status UnpackZeroPoints(const std::string& initializer_name, std::vector<int32_t>& zero_points) const;
Status UnpackZeroPoints(const std::string& initializer_name,
/*out*/ std::vector<int32_t>& zero_points,
/*out*/ int32_t& onnx_data_type) const;

// Checks if a tensor in the ONNX graph is per-channel quantized.
Status IsPerChannelQuantized(const onnxruntime::NodeUnitIODef& io_def,
140 changes: 124 additions & 16 deletions onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc
@@ -9,6 +9,9 @@
#include "QnnTypes.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"

#define ALIGN_PTR_UP(ptr, align, type) \
reinterpret_cast<type>((reinterpret_cast<std::uintptr_t>(ptr) + (align)-1) & ~((align)-1))
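// Example with hypothetical values: ALIGN_PTR_UP(p, 8, char*) with p == 0x1005
// computes (0x1005 + 7) & ~7 == 0x1008, the next 8-byte-aligned address;
// an already-aligned pointer (e.g. 0x1008) is returned unchanged.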

namespace onnxruntime {
namespace qnn {

@@ -38,9 +41,10 @@ QnnQuantParamsWrapper QnnQuantParamsWrapper::Copy() const {
return QnnQuantParamsWrapper(*this);
}

// Initializes by copying from a Qnn_QuantizeParams_t.
Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) {
if (scale_offset_data_) {
scale_offset_data_.reset(nullptr);
if (per_channel_data_) {
per_channel_data_.reset(nullptr);
params_ = QNN_QUANTIZE_PARAMS_INIT;
}

@@ -51,6 +55,7 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) {

switch (params.quantizationEncoding) {
case QNN_QUANTIZATION_ENCODING_SCALE_OFFSET:
case QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET:
params_ = params;
break;
case QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: {
@@ -63,27 +68,63 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params) {
const uint32_t num_elems = params.axisScaleOffsetEncoding.numScaleOffsets;

if (num_elems > 0) {
scale_offset_data_ = std::make_unique<Qnn_ScaleOffset_t[]>(num_elems);
gsl::span<Qnn_ScaleOffset_t> src_span(params.axisScaleOffsetEncoding.scaleOffset, num_elems);
std::copy(src_span.begin(), src_span.end(), scale_offset_data_.get());
params_.axisScaleOffsetEncoding.scaleOffset = scale_offset_data_.get();
const size_t num_bytes = num_elems * sizeof(Qnn_ScaleOffset_t);
constexpr std::uintptr_t align = alignof(Qnn_ScaleOffset_t);
per_channel_data_ = std::make_unique<char[]>(num_bytes + align);
Qnn_ScaleOffset_t* aligned_dst = ALIGN_PTR_UP(per_channel_data_.get(), align, Qnn_ScaleOffset_t*);

std::memcpy(aligned_dst, params.axisScaleOffsetEncoding.scaleOffset, num_bytes);
params_.axisScaleOffsetEncoding.scaleOffset = aligned_dst;
} else {
params_.axisScaleOffsetEncoding.scaleOffset = nullptr;
}
break;
}
case QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: {
const uint32_t num_elems = params.bwAxisScaleOffsetEncoding.numElements;

params_.encodingDefinition = params.encodingDefinition;
params_.quantizationEncoding = params.quantizationEncoding;
params_.bwAxisScaleOffsetEncoding.axis = params.bwAxisScaleOffsetEncoding.axis;
params_.bwAxisScaleOffsetEncoding.bitwidth = params.bwAxisScaleOffsetEncoding.bitwidth;
params_.bwAxisScaleOffsetEncoding.numElements = num_elems;

// Deep copy the scales[] and offsets[] arrays
if (num_elems > 0) {
const size_t num_scale_bytes = num_elems * sizeof(float);
const size_t num_zp_bytes = num_elems * sizeof(int32_t);
const size_t num_bytes = num_scale_bytes + num_zp_bytes;
constexpr std::uintptr_t align = alignof(float);
static_assert(alignof(float) == alignof(int32_t));

per_channel_data_ = std::make_unique<char[]>(num_bytes + align);
char* scales_begin = ALIGN_PTR_UP(per_channel_data_.get(), align, char*);
char* zps_begin = scales_begin + num_scale_bytes;

std::memcpy(scales_begin, params.bwAxisScaleOffsetEncoding.scales, num_scale_bytes);
std::memcpy(zps_begin, params.bwAxisScaleOffsetEncoding.offsets, num_zp_bytes);
params_.bwAxisScaleOffsetEncoding.scales = reinterpret_cast<float*>(scales_begin);
params_.bwAxisScaleOffsetEncoding.offsets = reinterpret_cast<int32_t*>(zps_begin);
} else {
params_.bwAxisScaleOffsetEncoding.scales = nullptr;
params_.bwAxisScaleOffsetEncoding.offsets = nullptr;
}
break;
}
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported QNN quantization encoding: ", params.quantizationEncoding);
}

return Status::OK();
}

// Initialize this object from a (potentially) quantized ONNX tensor.
// QnnModelWrapper provides utilities for unpacking scale and zero-point ONNX initializers.
Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& io_def) {
const std::optional<NodeUnitIODef::QuantParam>& ort_quant_params = io_def.quant_param;

if (scale_offset_data_) {
scale_offset_data_.reset(nullptr);
if (per_channel_data_) {
per_channel_data_.reset(nullptr);
params_ = QNN_QUANTIZE_PARAMS_INIT;
}

@@ -98,17 +139,25 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con

ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackScales(ort_quant_params->scale.Name(), scales));

bool is_int4_type = false;

if (ort_quant_params->zero_point != nullptr) {
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackZeroPoints(ort_quant_params->zero_point->Name(), zero_points));
int32_t onnx_tp_type = 0;
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackZeroPoints(ort_quant_params->zero_point->Name(), zero_points,
onnx_tp_type));

is_int4_type = (onnx_tp_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) ||
(onnx_tp_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4);
}

const bool is_per_tensor = scales.size() == 1;

if (is_per_tensor) {
// QNN uses different structs to represent quantization parameters depending on
// - per-tensor vs per-channel
// - int4 vs not int4
if (is_per_tensor && !is_int4_type) {
params_.encodingDefinition = QNN_DEFINITION_DEFINED;
params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;

// Parse scale & zero_point
params_.scaleOffsetEncoding.scale = scales[0];

if (ort_quant_params->zero_point != nullptr) {
@@ -117,8 +166,62 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con
} else {
params_.scaleOffsetEncoding.offset = 0;
}
} else {
// Per-channel quantization.
} else if (is_per_tensor && is_int4_type) {
params_.encodingDefinition = QNN_DEFINITION_DEFINED;
params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET;
params_.bwScaleOffsetEncoding.bitwidth = 4;
params_.bwScaleOffsetEncoding.scale = scales[0];

if (ort_quant_params->zero_point != nullptr) {
ORT_RETURN_IF_NOT(zero_points.size() == 1, "Expected one zero-point value");
params_.bwScaleOffsetEncoding.offset = zero_points[0];
} else {
params_.bwScaleOffsetEncoding.offset = 0;
}
} else if (!is_per_tensor && is_int4_type) {
const auto* io_shape = io_def.node_arg.Shape();
ORT_RETURN_IF(io_shape == nullptr, "Input/output tensor proto must have a shape");
const int32_t io_rank = io_shape->dim_size();

constexpr int64_t DEFAULT_QDQ_AXIS = 1;
int64_t axis = ort_quant_params->axis.value_or(DEFAULT_QDQ_AXIS);
if (axis < 0) {
axis += io_rank;
}
ORT_RETURN_IF_NOT(axis >= 0 && axis < io_rank,
"Quantization axis must be within the range [0, rank - 1]");

const size_t num_elems = scales.size();
const bool no_zero_points = zero_points.empty();
ORT_RETURN_IF_NOT(num_elems > 1, "Expected more than one scale value");
ORT_RETURN_IF_NOT(no_zero_points || zero_points.size() == num_elems,
"Expected the same number of zero-points and scales for per-channel quantization");

params_.encodingDefinition = QNN_DEFINITION_DEFINED;
params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET;
params_.bwAxisScaleOffsetEncoding.axis = static_cast<int32_t>(*(ort_quant_params->axis));
params_.bwAxisScaleOffsetEncoding.bitwidth = 4;
params_.bwAxisScaleOffsetEncoding.numElements = static_cast<uint32_t>(num_elems);

const size_t num_scale_bytes = num_elems * sizeof(float);
const size_t num_zp_bytes = num_elems * sizeof(int32_t);
const size_t num_bytes = num_scale_bytes + num_zp_bytes;
constexpr std::uintptr_t align = alignof(float);
per_channel_data_ = std::make_unique<char[]>(num_bytes + align);

char* scales_begin = ALIGN_PTR_UP(per_channel_data_.get(), align, char*);
char* zps_begin = scales_begin + num_scale_bytes;
gsl::span<float> scales_span(reinterpret_cast<float*>(scales_begin), num_elems);
gsl::span<int32_t> zps_span(reinterpret_cast<int32_t*>(zps_begin), num_elems);

for (size_t i = 0; i < num_elems; i++) {
scales_span[i] = scales[i];
zps_span[i] = no_zero_points ? 0 : zero_points[i];
}

params_.bwAxisScaleOffsetEncoding.scales = scales_span.data();
params_.bwAxisScaleOffsetEncoding.offsets = zps_span.data();
} else if (!is_per_tensor && !is_int4_type) {
const auto* io_shape = io_def.node_arg.Shape();
ORT_RETURN_IF(io_shape == nullptr, "Input/output tensor proto must have a shape");
const int32_t io_rank = io_shape->dim_size();
@@ -140,8 +243,11 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con
ORT_RETURN_IF_NOT(no_zero_points || zero_points.size() == num_elems,
"Expected the same number of zero-points and scales for per-channel quantization");

scale_offset_data_ = std::make_unique<Qnn_ScaleOffset_t[]>(num_elems);
gsl::span<Qnn_ScaleOffset_t> data_span(scale_offset_data_.get(), num_elems);
const size_t num_bytes = num_elems * sizeof(Qnn_ScaleOffset_t);
constexpr std::uintptr_t align = alignof(Qnn_ScaleOffset_t);
per_channel_data_ = std::make_unique<char[]>(num_bytes + align);
Qnn_ScaleOffset_t* aligned_dst = ALIGN_PTR_UP(per_channel_data_.get(), align, Qnn_ScaleOffset_t*);
gsl::span<Qnn_ScaleOffset_t> data_span(aligned_dst, num_elems);

for (size_t i = 0; i < num_elems; i++) {
data_span[i].scale = scales[i];
@@ -151,6 +257,8 @@ Status QnnQuantParamsWrapper::Init(const QnnModelWrapper& qnn_model_wrapper, con
params_.axisScaleOffsetEncoding.axis = static_cast<int32_t>(axis);
params_.axisScaleOffsetEncoding.numScaleOffsets = static_cast<uint32_t>(num_elems);
params_.axisScaleOffsetEncoding.scaleOffset = data_span.data();
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected tensor kind for QuantParamsWrapper::Init()");
}

return Status::OK();
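Taken together, `Init()` now selects one of four QNN encodings based on per-tensor vs. per-channel and 4-bit vs. wider element types. A hedged caller sketch summarizing the mapping (`AddQuantizedInput` and `input_def` are hypothetical names, and the usual ORT/QNN EP headers are assumed; this is not part of the commit):

```cpp
// Hypothetical caller sketch; in the EP this logic lives inside the op builders.
Status AddQuantizedInput(QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& input_def) {
  QnnQuantParamsWrapper quant_params;
  ORT_RETURN_IF_ERROR(quant_params.Init(qnn_model_wrapper, input_def));
  // Encoding selected by Init() (per the diff above):
  //   per-tensor,  8/16-bit -> QNN_QUANTIZATION_ENCODING_SCALE_OFFSET
  //   per-tensor,  4-bit    -> QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET      (bitwidth = 4)
  //   per-channel, 8/16-bit -> QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET
  //   per-channel, 4-bit    -> QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET (bitwidth = 4)
  return Status::OK();
}
```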
16 changes: 11 additions & 5 deletions onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h
@@ -48,17 +48,17 @@ class QnnQuantParamsWrapper {
(include_bw && params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET));
}

bool IsPerChannel(bool include_bw = false) const {
bool IsPerChannel() const {
return params_.encodingDefinition == QNN_DEFINITION_DEFINED &&
(params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET ||
(include_bw && params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET));
(params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET));
}

// Handle transposing of a per-channel quantized tensor. The quantization parameter's axis
// must be transposed using the inverse permutation of the Transpose.
template <typename IntType>
Status HandleTranspose(gsl::span<const IntType> perm) {
if (!IsPerChannel(true)) {
if (!IsPerChannel()) {
return Status::OK();
}

Expand All @@ -82,7 +82,7 @@ class QnnQuantParamsWrapper {
template <typename IntType>
Status HandleUnsqueeze(gsl::span<const IntType> orig_shape,
gsl::span<const IntType> new_shape) {
if (!IsPerChannel(true)) {
if (!IsPerChannel()) {
return Status::OK();
}

@@ -134,7 +134,13 @@ class QnnQuantParamsWrapper {

private:
Qnn_QuantizeParams_t params_;
std::unique_ptr<Qnn_ScaleOffset_t[]> scale_offset_data_; // Stores per-channel scales and offsets

// Stores arrays of per-channel scales and offsets. Fields in params_ point to this data.
//
// Use an opaque array of bytes because QNN uses different data layouts depending on the quantization encoding:
// - QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: array of scale/zp pairs [{scale0, zp0}, {scale1, zp1}, ...]
// - QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: parallel arrays for scales and zps [scale0, ...] [zp0, zp1, ...]
std::unique_ptr<char[]> per_channel_data_;
};

} // namespace qnn
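For a concrete picture of the two per-channel layouts described in that comment, here is a sketch for a hypothetical 3-channel tensor (values are illustrative; `ScaleOffset` stands in for `Qnn_ScaleOffset_t`):

```cpp
// Hypothetical 3-channel example of the two layouts held by per_channel_data_
// (offsets are negated ONNX zero-points).
#include <cstdint>

struct ScaleOffset { float scale; int32_t offset; };

// QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET: interleaved {scale, offset} pairs.
static const ScaleOffset kAxisLayout[3] = {{0.10f, -1}, {0.20f, 0}, {0.05f, 2}};

// QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET: parallel arrays, scales then offsets.
static const float kBwScales[3] = {0.10f, 0.20f, 0.05f};
static const int32_t kBwOffsets[3] = {-1, 0, 2};
```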

