From 060f1b822a6b71e1db07b94f68fdfb53fe47a7d0 Mon Sep 17 00:00:00 2001 From: Weiwen Xia Date: Mon, 14 Mar 2022 18:23:08 -0700 Subject: [PATCH] Add onednn quant backend (#74137) Summary: Resolve the conflicts in https://github.com/pytorch/pytorch/pull/69820 jerryzh168 Please review. Thanks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/74137 Reviewed By: samdow Differential Revision: D34840477 Pulled By: jerryzh168 fbshipit-source-id: 8aa60981ff7be211a1609644f273b16d18efd425 (cherry picked from commit de76bb808b315e9a2e45d8c5f1c1233a47d669c4) --- aten/src/ATen/Context.cpp | 4 + .../native/quantized/cpu/conv_serialization.h | 15 ++ .../native/quantized/cpu/fbgemm_utils.cpp | 11 ++ .../ATen/native/quantized/cpu/onednn_utils.h | 151 +++++++++++++++ aten/src/ATen/native/quantized/cpu/qconv.cpp | 173 +++++++++++++++++ .../native/quantized/cpu/qconv_dynamic.cpp | 52 +++++ .../native/quantized/cpu/qconv_prepack.cpp | 179 +++++++++++++++++- .../native/quantized/cpu/qconv_unpack.cpp | 30 +++ .../src/ATen/native/quantized/cpu/qlinear.cpp | 76 ++++++++ .../native/quantized/cpu/qlinear_dynamic.cpp | 94 +++++++++ .../native/quantized/cpu/qlinear_prepack.cpp | 106 +++++++++++ .../native/quantized/cpu/qlinear_unpack.cpp | 8 + c10/core/QEngine.h | 4 + test/ao/sparsity/test_kernels.py | 8 + .../core/test_quantized_module.py | 26 +-- test/quantization/core/test_quantized_op.py | 53 ++++-- torch/ao/quantization/qconfig.py | 21 +- torch/backends/quantized/__init__.py | 4 +- torch/testing/_internal/common_quantized.py | 29 ++- 19 files changed, 1007 insertions(+), 37 deletions(-) create mode 100644 aten/src/ATen/native/quantized/cpu/onednn_utils.h diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 98590b266be40..6bf5e982ded8d 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -236,6 +236,10 @@ const std::vector& Context::supportedQEngines() { engines.push_back(at::kNoQEngine); #endif // C10_MOBILE +#if AT_MKLDNN_ENABLED() + engines.push_back(at::kONEDNN); +#endif + #ifdef USE_FBGEMM if (fbgemm::fbgemmSupportedCPU()) { engines.push_back(at::kFBGEMM); diff --git a/aten/src/ATen/native/quantized/cpu/conv_serialization.h b/aten/src/ATen/native/quantized/cpu/conv_serialization.h index cf5c04977b6a1..369f54b439614 100644 --- a/aten/src/ATen/native/quantized/cpu/conv_serialization.h +++ b/aten/src/ATen/native/quantized/cpu/conv_serialization.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -358,6 +359,20 @@ c10::intrusive_ptr> deserialize_conv( ); } #endif // USE_PYTORCH_QNNPACK +#if AT_MKLDNN_ENABLED() + if (ctx.qEngine() == at::QEngine::ONEDNN) { + return PackedConvWeightsOnednn::prepack( + weight.value(), + bias, + stride, + padding, + output_padding, + dilation, + groups, + transpose + ); + } +#endif // AT_MKLDNN_ENABLED() TORCH_CHECK( false, "Didn't find engine for when deserializing ConvPackedParams: ", diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp index 8d4bfd5a74da4..369ee6744624c 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -470,6 +471,16 @@ int register_linear_params() { std::move(weight), std::move(bias)); } #endif // USE_PYTORCH_QNNPACK +#if AT_MKLDNN_ENABLED() + if (at::globalContext().qEngine() == at::QEngine::ONEDNN) { + TORCH_CHECK( + weight.scalar_type() == at::kQInt8, + "ONEDNN 
only supports INT8 bit width currently. Got ", + c10::toString(weight.scalar_type())); + return PackedLinearWeightsOnednn::prepack( + std::move(weight), std::move(bias)); + } +#endif // #if AT_MKLDNN_ENABLED() TORCH_CHECK(false, "Unknown qengine"); }) .def("bias", [](const c10::intrusive_ptr& self) { diff --git a/aten/src/ATen/native/quantized/cpu/onednn_utils.h b/aten/src/ATen/native/quantized/cpu/onednn_utils.h new file mode 100644 index 0000000000000..4ee8e8737fb22 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/onednn_utils.h @@ -0,0 +1,151 @@ +#pragma once + +#include +#if AT_MKLDNN_ENABLED() +#include +#include +#include +#include + +struct PackedLinearWeightsOnednn : public LinearPackedParamsBase { + PackedLinearWeightsOnednn( + std::unique_ptr weight, + c10::optional bias, + at::Tensor orig_weight, + c10::optional orig_bias) + : weight_(std::move(weight)), + bias_(std::move(bias)), + orig_weight_(std::move(orig_weight)), + orig_bias_(std::move(orig_bias)) {} + std::unique_ptr weight_; + c10::optional bias_; + at::Tensor orig_weight_; + c10::optional orig_bias_; + + at::Tensor apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + at::Tensor apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override; + + std::tuple> unpack() override; + + c10::optional bias() override { + return orig_bias_; + } + + static c10::intrusive_ptr prepack( + at::Tensor weight, + c10::optional bias); + + private: + template + at::Tensor apply_impl( + at::Tensor input, + double output_scale, + int64_t output_zero_point); + + template + at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false); +}; + +template +struct PackedConvWeightsOnednn : public ConvPackedParamsBase { + PackedConvWeightsOnednn( + std::unique_ptr weight, + c10::optional bias, + at::Tensor orig_weight, + c10::optional orig_bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + uint8_t transpose) + : weight_(std::move(weight)), + bias_(std::move(bias)), + orig_weight_(std::move(orig_weight)), + orig_bias_(std::move(orig_bias)), + stride_(std::move(stride)), + padding_(std::move(padding)), + output_padding_(std::move(output_padding)), + dilation_(std::move(dilation)), + groups_(groups), + transpose_(transpose) {} + + std::unique_ptr weight_; + c10::optional bias_; + at::Tensor orig_weight_; + c10::optional orig_bias_; + torch::List stride_; + torch::List padding_; + torch::List output_padding_; + torch::List dilation_; + int64_t groups_; + uint8_t transpose_; + + at::Tensor apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) override; + + at::Tensor apply_dynamic( + const at::Tensor& input, + bool reduce_range) override; + + std::tuple> unpack() override; + + static c10::intrusive_ptr> prepack( + at::Tensor weight, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose); + + torch::List stride() const override { + return stride_; + } + + torch::List padding() const override { + return padding_; + } + + torch::List output_padding() const override { + return output_padding_; 
+ } + + torch::List dilation() const override { + return dilation_; + } + + int64_t groups() const override { + return groups_; + } + + bool transpose() const override { + return (bool)transpose_; + } + + private: + template + at::Tensor apply_impl( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point); +}; + +#endif // #if AT_MKLDNN_ENABLED() diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 6948b4a526dcc..b8143676b605f 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include #include #include @@ -1148,6 +1150,177 @@ template at::Tensor PackedConvWeightsQnnp<3>::apply_impl( #endif // USE_PYTORCH_QNNPACK +#if AT_MKLDNN_ENABLED() +template +at::Tensor PackedConvWeightsOnednn::apply( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) { + return apply_impl(input, output_scale, output_zero_point); +} + +template +at::Tensor PackedConvWeightsOnednn::apply_relu( + const at::Tensor& input, + double output_scale, + int64_t output_zero_point) { + return apply_impl(input, output_scale, output_zero_point); +} + +template +template +at::Tensor PackedConvWeightsOnednn::apply_impl( + const at::Tensor& act, + double output_scale, + int64_t output_zero_point) { + std::string func_name = "quantized::conv"; + if (transpose()) { + func_name += "_transpose"; + } + func_name += std::to_string(kSpatialDim) + "d"; + if (kReluFused) { + func_name += "_relu"; + } + ConvDimChecks( + act.ndimension(), stride().size(), padding().size(), + output_padding().size(), dilation().size(), func_name, transpose()); + TORCH_CHECK(act.scalar_type() == c10::ScalarType::QUInt8, + func_name, " (ONEDNN): data type of input should be QUint8."); + + // src + auto act_contig = act.contiguous(kSpatialDim == 2 ? c10::MemoryFormat::ChannelsLast : c10::MemoryFormat::ChannelsLast3d); + auto src_dims = act_contig.sizes().vec(); + auto src_data_type = dnnl::memory::data_type::u8; + auto src_desc = ideep::tensor::desc(src_dims, src_data_type, + kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc); + ideep::tensor src; + src.init(src_desc, act_contig.data_ptr()); + // weights & bias + ideep::tensor& weights = *(weight_.get()); + bool with_bias = bias_.has_value(); + const auto& kernel_size = weights.get_dims(); + // dst + const std::vector& input_size = src.get_dims(); + std::vector output_sizes; + if (transpose()) { + // Prepacked weight format: [o, i, ...] + const int N = act.size(0); // batch size + const int C = act.size(1); // input channels + const int M = weights.get_dim(0); // output channels + const int D = kSpatialDim == 2 ? 1 : act.size(2); // input depth + const int H = act.size(kSpatialDim); // input height + const int W = act.size(kSpatialDim + 1); // input width + const int KH = weights.get_dim(kSpatialDim); // kernel height + const int KW = weights.get_dim(kSpatialDim + 1); // kernel width + const int KD = kSpatialDim == 2 ? 1 : weights.get_dim(2); // kernel depth + TORCH_CHECK(C == groups() * weights.get_dim(1), // weight: [o, i, ...] + func_name, " (ONEDNN): input channel number should be ", + groups() * weights.get_dim(1), ", but got ", C); + auto output_shape = MakeDeConvOutputShape( + N, + M, + kSpatialDim == 2 ? std::vector{H, W} : std::vector{D, H, W}, + kSpatialDim == 2 ? 
std::vector{KH, KW} : std::vector{KD, KH, KW}, + stride(), + padding(), + output_padding(), + dilation()); + output_sizes = c10::IntArrayRef(output_shape).vec(); + } else { + output_sizes = at::native::conv_output_size(input_size, kernel_size, padding().vec(), stride().vec(), dilation().vec()); + } + ideep::dims dst_dims = ideep::dims({output_sizes.cbegin(), output_sizes.cend()}); + at::Tensor output = at::_empty_affine_quantized( + dst_dims, + device(c10::kCPU) + .dtype(c10::kQUInt8) + .memory_format(kSpatialDim == 2 ? + c10::MemoryFormat::ChannelsLast : + c10::MemoryFormat::ChannelsLast3d), + output_scale, + output_zero_point, + c10::nullopt); + if (output.numel() == 0) { + return output; + } + ideep::tensor dst({dst_dims, ideep::tensor::data_type::u8, {output.strides().cbegin(), output.strides().cend()}}, + output.data_ptr()); + // Parameters + const ideep::dims& strides = stride().vec(); + const ideep::dims& dilates = dilation().vec(); + const ideep::dims& padding_l = padding().vec(); + const ideep::dims& padding_r = padding().vec(); + const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/act.q_scale()); // Scales of ONEDNN and PyTorch are reciprocal + const ideep::scale_t& weights_scales = weights.get_scale(); + const ideep::scale_t& dst_scales = ideep::scale_t(weights_scales.size(), 1.0/output_scale); // Scales of ONEDNN and PyTorch are reciprocal + const ideep::zero_point_t src_zero_points = ideep::zero_point_t(1, act.q_zero_point()); + const ideep::zero_point_t dst_zero_points = ideep::zero_point_t(1, output_zero_point); + ideep::attr_t op_attr = kReluFused ? ideep::attr_t::fuse_relu() : ideep::attr_t(); + op_attr.set_zero_points(DNNL_ARG_SRC, ideep::utils::tensor_zp_mask(1), {DNNL_RUNTIME_S32_VAL}); // runtime src zero point + if (with_bias) { + // Bias might be modified outside (e.g. by quantization bias correction). + // If so, update the prepacked bias as well. 
+ if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) { + bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr()); + } + const auto& b = bias_.value(); + if (transpose()) { + ideep::convolution_transpose_forward::compute_v2( + src, weights, b, dst_dims, dst, + strides, padding_l, padding_r, dilates, + groups(), src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points, + op_attr, dnnl::algorithm::deconvolution_direct, dnnl::prop_kind::forward_inference, + ideep::u8s8, ideep::engine::cpu_engine()); + } else { + ideep::convolution_forward::compute_v2( + src, weights, b, dst_dims, dst, + strides, dilates, padding_l, padding_r, groups(), + src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points, + op_attr, dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference, + ideep::u8s8, ideep::engine::cpu_engine()); + } + } else { + if (transpose()) { + ideep::convolution_transpose_forward::compute_v2( + src, weights, dst_dims, dst, + strides, padding_l, padding_r, dilates, + groups(), src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points, + op_attr, dnnl::algorithm::deconvolution_direct, dnnl::prop_kind::forward_inference, + ideep::u8s8, ideep::engine::cpu_engine()); + } else { + ideep::convolution_forward::compute_v2( + src, weights, dst_dims, dst, + strides, dilates, padding_l, padding_r, groups(), + src_scales, weights_scales, dst_scales, src_zero_points, dst_zero_points, + op_attr, dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference, + ideep::u8s8, ideep::engine::cpu_engine()); + } + } + return output; +} + +template at::Tensor PackedConvWeightsOnednn<2>::apply( + const at::Tensor& act, + double output_scale, + int64_t output_zero_point); + +template at::Tensor PackedConvWeightsOnednn<2>::apply_relu( + const at::Tensor& act, + double output_scale, + int64_t output_zero_point); + +template at::Tensor PackedConvWeightsOnednn<3>::apply( + const at::Tensor& act, + double output_scale, + int64_t output_zero_point); + +template at::Tensor PackedConvWeightsOnednn<3>::apply_relu( + const at::Tensor& act, + double output_scale, + int64_t output_zero_point); + +#endif // #if AT_MKLDNN_ENABLED() + namespace at { namespace native { namespace { diff --git a/aten/src/ATen/native/quantized/cpu/qconv_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qconv_dynamic.cpp index 1fd077eeaed65..2f3a6ed8f3cdb 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_dynamic.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -118,6 +119,57 @@ template at::Tensor PackedConvWeightsQnnp<3>::apply_dynamic( #endif // USE_PYTORCH_QNNPACK +#if AT_MKLDNN_ENABLED() + +template +at::Tensor PackedConvWeightsOnednn::apply_dynamic( + const at::Tensor& input, + bool reduce_range) { + + // Find min/max of input + float x_max = 0, x_min = 0; + if (input.numel() > 0) { + x_min = input.min().item(); + x_max = input.max().item(); + } + + // Input tensor is quantized as 8-bit unsigned values + static constexpr int precision = 8; + static constexpr bool is_signed = false; + + // Calculate scale and zero point for quantization of input tensor + auto q_params = quant_utils::ChooseQuantizationParams( + /*min=*/x_min, + /*max=*/x_max, + /*qmin=*/is_signed ? -(1 << (precision - 1)) : 0, + /*qmax=*/ + is_signed ? 
((1 << (precision - 1)) - 1) : (1 << precision) - 1, + /*preserve_sparsity=*/false, + /*force_scale_power_of_two=*/false, + /*reduce_range=*/reduce_range); + + // Quantize input + at::Tensor q_input = at::quantize_per_tensor( + input, q_params.scale, q_params.zero_point, c10::kQUInt8); + + at::Tensor out = + apply_impl(q_input, q_params.scale, q_params.zero_point); + + // TODO: Modify ideep to allow fp32 input & output + // to avoid explicit `quantize - dequantize` + return at::dequantize(out); +} + +template at::Tensor PackedConvWeightsOnednn<2>::apply_dynamic( + const at::Tensor& input, + bool reduce_range); + +template at::Tensor PackedConvWeightsOnednn<3>::apply_dynamic( + const at::Tensor& input, + bool reduce_range); + +#endif // AT_MKLDNN_ENABLED() + namespace at { namespace native { namespace { diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index 9a4762340e368..85edffef25b98 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -314,6 +315,165 @@ c10::intrusive_ptr> PackedConvWeightsQnnp< bool transpose); #endif // USE_PYTORCH_QNNPACK +#if AT_MKLDNN_ENABLED() +template +c10::intrusive_ptr> PackedConvWeightsOnednn< + kSpatialDim>:: + prepack( + at::Tensor weight, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List output_padding, + torch::List dilation, + int64_t groups, + bool transpose) { + TORCH_CHECK( + weight.ndimension() == kSpatialDim + 2, + "Weights are expected to have ", kSpatialDim + 2, " dimensions"); + TORCH_CHECK( + stride.size() == kSpatialDim, + "stride should contain ", kSpatialDim, " elements for ", + kSpatialDim, "D convolution."); + TORCH_CHECK( + padding.size() == kSpatialDim, + "Specify front/top/left padding only. " + "end/bottom/right padding assumed to be equal to front/top/left"); + TORCH_CHECK( + !transpose || output_padding.size() == kSpatialDim, + "quantized::conv_prepack: Specify top/left output padding " + "only. bottom/right padding assumed to be equal to top/left"); + TORCH_CHECK( + dilation.size() == kSpatialDim, + "dilation should contain ", kSpatialDim, " elements for ", + kSpatialDim, "D convolution."); + TORCH_CHECK( + !transpose || std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; }), + "quantized::conv_prepack: ONEDNN only supports zero output_padding."); + + // Weight + // Format: [OC IC//group KH KW] for conv; [IC OC//group KH KW] for deconv + auto dims = weight.sizes().vec(); + auto strides = stride.vec(); + auto padding_l = padding.vec(); + auto padding_r = padding.vec(); + auto dilates = dilation.vec(); + auto op_attr = ideep::attr_t(); + std::vector wgt_zero_points; + ideep::scale_t wgt_scales; + const int output_channels = transpose ? 
weight.size(1) * groups + : weight.size(0); + const auto qtype = weight.qscheme(); + if (qtype == c10::kPerTensorAffine) { + TORCH_CHECK( + weight.q_zero_point()==0, + "quantized::qconv_prepack: ONEDNN only supports symmetric quantization of weight," + " whose zero point must be 0."); + wgt_zero_points = std::vector(1, weight.q_zero_point()); + wgt_scales = ideep::scale_t(1, 1.0/weight.q_scale()); // Scales of ONEDNN and PyTorch are reciprocal + } else if (qtype == c10::kPerChannelAffine) { + TORCH_CHECK( + !transpose, + "Per Channel Quantization is currently disabled for transposed conv"); + wgt_zero_points.resize(output_channels); + wgt_scales.resize(output_channels); + for (int i = 0; i < output_channels; ++i) { + wgt_zero_points[i] = weight.q_per_channel_zero_points()[i].item(); + TORCH_CHECK( + wgt_zero_points[i]==0, + "quantized::qconv_prepack: ONEDNN only supports symmetric quantization of weight," + " whose zero point must be 0."); + wgt_scales[i] = 1.0f / weight.q_per_channel_scales()[i].item(); // Scales of ONEDNN and PyTorch are reciprocal + } + } else { + TORCH_CHECK(false, "Unsupported qscheme: ", toString(qtype)); + } + + // Set runtime src zero point + auto src_zero_point = {DNNL_RUNTIME_S32_VAL}; + op_attr.set_zero_points(DNNL_ARG_SRC, + ideep::utils::tensor_zp_mask(src_zero_point.size()), + src_zero_point); + at::Tensor weight_copy; + ideep::tensor::desc w_desc; + ideep::dims dims_iohw, dims_giohw; + ideep::tag w_tag = ideep::tag::any; + const bool with_groups = groups > 1; + if (transpose) { + w_desc = ideep::convolution_transpose_forward::expected_weights_desc( + dims, dnnl::memory::data_type::s8, + strides, padding_l, padding_r, dilates, groups, + dnnl::algorithm::deconvolution_direct, dnnl::prop_kind::forward_inference, + ideep::dims(), op_attr); + // convolution_transpose_forward::expected_weights_desc() gives format [i, o, ...], + // but ONEDNN requires [o, i, ...] for computation + dims_iohw = w_desc.get_dims(); + dims_giohw = with_groups ? ideep::utils::group_dims(dims_iohw, groups) : dims_iohw; + std::vector perms(dims_giohw.size(), 0); // for permutation of weight + std::iota(perms.begin(), perms.end(), 0); + w_desc = w_desc.transpose(with_groups, with_groups + 1); + std::swap(perms[with_groups], perms[with_groups + 1]); + weight_copy = weight.reshape(dims_giohw).permute(c10::IntArrayRef(perms)).clone(); + } else { + w_desc = ideep::convolution_forward::expected_weights_desc( + dims, dnnl::memory::data_type::s8, + strides, padding_l, padding_r, dilates, groups, + dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference, + dnnl::memory::data_type::u8, ideep::dims(), op_attr); + weight_copy = weight.clone(); + } + if (with_groups) { + w_tag = kSpatialDim == 2 ? ideep::tag::goihw : ideep::tag::goidhw; + } else { + w_tag = kSpatialDim == 2 ? ideep::tag::oihw : ideep::tag::oidhw; + } + ideep::dims w_dims = with_groups ? ideep::utils::group_dims(w_desc.get_dims(), groups) + : w_desc.get_dims(); + ideep::tensor wgt = ideep::tensor( + ideep::tensor::desc({w_dims, dnnl::memory::data_type::s8, w_tag}, groups), + weight_copy.data_ptr()); + wgt.set_scale(wgt_scales); // Scales are needed for feed_from(). 
+ ideep::tensor exp_wgt; + exp_wgt.init(w_desc); + exp_wgt.set_scale(wgt_scales); // Also for feed_from() + exp_wgt.feed_from(wgt, transpose); // expect wgt to be in [OC IC KH KW] format + ideep::tensor * packed_weight_p = new ideep::tensor(exp_wgt); + packed_weight_p->set_scale(wgt_scales); + packed_weight_p->set_zero_point(wgt_zero_points); + std::unique_ptr weight_ptr(packed_weight_p); + // Bias + c10::optional onednn_bias{c10::nullopt}; + if (bias.has_value()) { + at::Tensor bias_vec = bias.value(); + TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)"); + TORCH_CHECK( + bias_vec.size(0) == output_channels, + "bias should have K elements: " + std::to_string(output_channels)); + auto bias_desc = ideep::tensor::desc(bias.value().sizes().vec(), dnnl::memory::data_type::f32); + ideep::tensor packed_bias; + packed_bias.init(bias_desc, bias.value().data_ptr()); + onednn_bias = c10::optional(packed_bias); + } + auto ret_ptr = c10::make_intrusive>( + PackedConvWeightsOnednn{ + std::move(weight_ptr), + onednn_bias, + weight, + bias, + stride, + padding, + output_padding, + dilation, + groups, + transpose + }); + return ret_ptr; +} + +template struct PackedConvWeightsOnednn<2>; +template struct PackedConvWeightsOnednn<3>; +#endif // #if AT_MKLDNN_ENABLED() + namespace at { namespace native { namespace { @@ -377,6 +537,14 @@ class QConvPackWeightInt8 final { } #endif +#if AT_MKLDNN_ENABLED() + if (ctx.qEngine() == at::QEngine::ONEDNN) { + return PackedConvWeightsOnednn::prepack( + weight, bias, stride, padding, output_padding, dilation, groups, + transpose); + } +#endif + TORCH_CHECK( false, "Didn't find engine for operation quantized::conv2d_prepack ", @@ -438,8 +606,6 @@ class QConv1dPackWeightInt8 final { } #endif - - #ifdef USE_PYTORCH_QNNPACK if (ctx.qEngine() == at::QEngine::QNNPACK) { return PackedConvWeightsQnnp<2>::prepack( @@ -447,6 +613,15 @@ class QConv1dPackWeightInt8 final { transpose); } #endif + +#if AT_MKLDNN_ENABLED() + if (ctx.qEngine() == at::QEngine::ONEDNN) { + return PackedConvWeightsOnednn<2>::prepack( + weight, bias, stride, padding, output_padding, dilation, groups, + transpose); + } +#endif + TORCH_CHECK( false, "Didn't find engine for operation quantized::conv1d_prepack ", diff --git a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp index 5c9d964389f7b..987513848e170 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -120,6 +121,20 @@ template std::tuple> PackedConvWeightsQnnp 3>::unpack(); #endif // USE_PYTORCH_QNNPACK +#if AT_MKLDNN_ENABLED() +template +std::tuple> PackedConvWeightsOnednn< + kSpatialDim>::unpack() { + return std::tuple>( + orig_weight_, orig_bias_); +} + +template std::tuple> PackedConvWeightsOnednn< + 2>::unpack(); +template std::tuple> PackedConvWeightsOnednn< + 3>::unpack(); +#endif // #if AT_MKLDNN_ENABLED() + namespace at { namespace native { namespace { @@ -154,6 +169,12 @@ class QConvUnpackWeightsInt8 final { } #endif +#if AT_MKLDNN_ENABLED() + if (ctx.qEngine() == at::QEngine::ONEDNN) { + return packed_weight->unpack(); + } +#endif + TORCH_CHECK( false, "Didn't find engine for operation quantized::conv2d_unpack ", @@ -185,6 +206,15 @@ class QConv1dUnpackWeightsInt8 final { } #endif +#if AT_MKLDNN_ENABLED() + if (ctx.qEngine() == at::QEngine::ONEDNN) { + std::tie(weight, bias) = packed_weight->unpack(); + at::Tensor 
new_weight = weight.clone(); + new_weight.squeeze_(quant_utils::kConv1dSqueezeDim + 2); + return std::tuple>(new_weight, bias); + } +#endif + TORCH_CHECK( false, "Didn't find engine for operation quantized::conv1d_unpack ", diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index 0e71423226685..d358f23c6af36 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -617,6 +618,81 @@ at::Tensor PackedLinearWeightsQnnp::apply_relu( #endif // USE_PYTORCH_QNNPACK +#if AT_MKLDNN_ENABLED() +template +at::Tensor PackedLinearWeightsOnednn::apply_impl( + at::Tensor input, + double output_scale, + int64_t output_zero_point) { + const int64_t dim = input.dim(); + TORCH_CHECK( + dim != 0, + "qlinear (ONEDNN): input dim should be at least 1, but got 0"); + TORCH_CHECK(input.scalar_type() == c10::ScalarType::QUInt8, + "qlinear (ONEDNN): data type of input should be QUint8."); + + auto input_contig = input.expect_contiguous(); + auto& w = *(weight_.get()); + auto K = input.size(dim - 1), M = input.numel() / K, N = w.get_dim(1); + auto input_dims = {M, K}; + auto input_data_type = dnnl::memory::data_type::u8; + auto input_desc = ideep::tensor::desc(input_dims, input_data_type); + ideep::attr_t op_attr = ReluFused ? ideep::attr_t::fuse_relu() : ideep::attr_t(); + ideep::tensor x(input_desc, input_contig->data_ptr()); + auto dst_dims = {M, N}; + const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/input.q_scale()); + const ideep::scale_t& weights_scales = w.get_scale(); + const ideep::scale_t& dst_scales = ideep::scale_t(1, 1.0/output_scale); // Scales of ONEDNN and PyTorch are reciprocal + const ideep::zero_point_t& src_zero_point = ideep::zero_point_t(1, input.q_zero_point()); + const ideep::zero_point_t& dst_zero_point = ideep::zero_point_t(1, output_zero_point); + // Compute: Use ideep::matmul_forward to support asymmetric quantization + // Allocate output Tensor + at::Tensor output = at::_empty_affine_quantized( + dst_dims, + at::device(c10::kCPU).dtype(c10::kQUInt8), + output_scale, + output_zero_point); + if (output.numel() == 0) { + return output; + } + ideep::tensor y({dst_dims, ideep::tensor::data_type::u8, {output.strides().cbegin(), output.strides().cend()}}, + output.data_ptr()); + if (bias_.has_value()) { + // Bias might be modified outside (e.g. by quantization bias correction). + // If so, update the prepacked bias as well. 
+ if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) { + bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr()); + } + const auto& b = bias_.value(); + ideep::matmul_forward::compute_v2(x, w, b, y, 1.0f, 1.0f, src_scales, weights_scales, dst_scales, + src_zero_point, dst_zero_point, op_attr); + } else { + ideep::matmul_forward::compute_v2(x, w, y, 1.0f, 1.0f, src_scales, weights_scales, dst_scales, + src_zero_point, dst_zero_point, op_attr); + } + auto out_sizes = input.sizes().vec(); + out_sizes.back() = N; + if (output.sizes().vec() == out_sizes) + return output; + return output.reshape(out_sizes); +} + +at::Tensor PackedLinearWeightsOnednn::apply( + at::Tensor input, + double output_scale, + int64_t output_zero_point) { + return apply_impl(std::move(input), output_scale, output_zero_point); +} + +at::Tensor PackedLinearWeightsOnednn::apply_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point) { + return apply_impl(std::move(input), output_scale, output_zero_point); +} + +#endif // #if AT_MKLDNN_ENABLED() + namespace at { namespace native { namespace { diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index 81e5e1b721042..111255726dcf8 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -463,6 +464,99 @@ void PackedLinearWeightFp16::set_bias(c10::optional bias) { #endif // USE_FBGEMM +#if AT_MKLDNN_ENABLED() +template +at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl( + at::Tensor input, + bool reduce_range) { + // Dynamic: fp32 * int8 -> fp32 + using at::Tensor; + + TORCH_CHECK( + input.dim() >= 2, + "The dimension of input tensor should be larger than or equal to 2"); + TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float, + "qlinear_dynamic (ONEDNN): data type of input should be float."); + + // Input -> uint8 + auto input_contig = input.contiguous(); + const int64_t dim = input.dim(); + auto input_reshaped = + dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)}); + auto input_dims = input_reshaped.sizes().vec(); + auto input_data_type = dnnl::memory::data_type::f32; + auto input_desc = ideep::tensor::desc(input_dims, input_data_type); + ideep::attr_t op_attr = ReluFused ? 
ideep::attr_t::fuse_relu() : ideep::attr_t(); + ideep::tensor x; + x.init(input_desc, input_contig.data_ptr()); + // Find quantization parameters + float x_max = 0, x_min = 0; + if (input.numel() > 0) { + x_min = input_contig.min().item(); + x_max = input_contig.max().item(); + } + const int precision = 8; + auto q_params = quant_utils::ChooseQuantizationParams( + /*min=*/x_min, + /*max=*/x_max, + /*qmin=*/0, + /*qmax=*/(1 << precision) - 1, + /*preserve_sparsity=*/false, + /*force_scale_power_of_two=*/false, + /*reduce_range=*/reduce_range); + const std::vector& src_zero_point = std::vector(1, q_params.zero_point); + // weights, dst + auto w = *(weight_.get()); + auto dst_dims = {x.get_dim(0), w.get_dim(1)}; + const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/q_params.scale); + const ideep::scale_t& weights_scales = w.get_scale(); + // Compute -> f32 + // Use ideep::matmul_forward instead of ideep::inner_product_forward, + // since the latter does not support asymmetric quantization + // Allocate output Tensor + at::Tensor output = at::empty(dst_dims, input.options().dtype(at::kFloat)); + if (output.numel() == 0) return output; + ideep::tensor y({dst_dims, ideep::tensor::data_type::f32, + {output.strides().cbegin(), output.strides().cend()}}, + output.data_ptr()); + if (bias_.has_value()) { + // Bias might be modified outside (e.g. by quantization bias correction). + // If so, update the prepacked bias as well. + if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) { + bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr()); + } + const ideep::tensor b = bias_.value(); + ideep::matmul_forward::compute_v2(x, w, b, y, 1.0f, 1.0f, + src_scales, weights_scales, ideep::scale_t(), + src_zero_point, ideep::zero_point_t(), op_attr); + } else { + ideep::matmul_forward::compute_v2(x, w, y, 1.0f, 1.0f, + src_scales, weights_scales, ideep::scale_t(), + src_zero_point, ideep::zero_point_t(), op_attr); + } + auto out_sizes = input.sizes().vec(); + out_sizes.back() = w.get_dim(1); + if (output.sizes().vec() == out_sizes) + return output; + return output.reshape(out_sizes); +} + +at::Tensor PackedLinearWeightsOnednn::apply_dynamic( + at::Tensor input, + bool reduce_range) { + return apply_dynamic_impl( + std::move(input), reduce_range); +} + +at::Tensor PackedLinearWeightsOnednn::apply_dynamic_relu( + at::Tensor input, + bool reduce_range) { + return apply_dynamic_impl( + std::move(input), reduce_range); +} + +#endif // #if AT_MKLDNN_ENABLED() + namespace at { namespace native { namespace { diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index e88fb9d7009d9..22f0e3d595005 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -194,6 +195,80 @@ c10::intrusive_ptr PackedLinearWeightFp16::prepack( } #endif // USE_FBGEMM +#if AT_MKLDNN_ENABLED() +c10::intrusive_ptr PackedLinearWeightsOnednn::prepack( + at::Tensor weight, + c10::optional bias) { + TORCH_CHECK( + weight.dim() == 2, + "The weight tensor for quantized::linear_prepack (onednn) should" + " be 2-dimensional."); + // Weight + std::vector dims = weight.sizes().vec(); + auto N = weight.size(0); + std::vector wgt_zero_points; + ideep::scale_t wgt_scales; + const auto qtype = weight.qscheme(); + if (qtype == c10::kPerTensorAffine) { + TORCH_CHECK( + weight.q_zero_point() == 0, + 
"quantized::linear_prepack: ONEDNN only supports symmetric quantization of weight," + " whose zero point must be 0, but got ", weight.q_zero_point()); + wgt_zero_points = std::vector(1, weight.q_zero_point()); + wgt_scales = ideep::scale_t(1, 1.0/weight.q_scale()); // Scales of ONEDNN and PyTorch are reciprocal + } else if (qtype == c10::kPerChannelAffine) { + wgt_zero_points.resize(N); + wgt_scales.resize(N); + for (int i = 0; i < N; ++i) { + wgt_zero_points[i] = weight.q_per_channel_zero_points()[i].item(); + TORCH_CHECK( + wgt_zero_points[i] == 0, + "quantized::linear_prepack: ONEDNN only supports symmetric quantization of weight," + " whose zero point must be 0, but got ", wgt_zero_points[i], ", at index ", i); + wgt_scales[i] = 1.0f / weight.q_per_channel_scales()[i].item(); // Scales of ONEDNN and PyTorch are reciprocal + } + } else { + TORCH_CHECK(false, "Unsupported qscheme: ", toString(qtype)); + } + + // Prepack weight + auto weight_copy = weight.clone(); + ideep::tensor wgt = ideep::tensor({dims, dnnl::memory::data_type::s8}, weight_copy.data_ptr()); + wgt.transpose_(0, 1); // ONEDNN requires transposed weight + auto w_desc = ideep::matmul_forward::expected_weights_desc(wgt.get_dims(), dnnl::memory::data_type::s8, + dnnl::memory::data_type::u8); + ideep::tensor exp_wgt(w_desc); + exp_wgt.feed_from(wgt); + ideep::tensor * packed_weight_p = new ideep::tensor(exp_wgt); + packed_weight_p->set_scale(wgt_scales); + packed_weight_p->set_zero_point(wgt_zero_points); + std::unique_ptr weight_ptr(packed_weight_p); + // Bias + c10::optional onednn_bias{c10::nullopt}; + if (bias.has_value()) { + auto& b = bias.value(); + auto bias_size = b.sizes().vec(); + bias_size.insert(bias_size.begin(), 1); + TORCH_CHECK( + bias_size[1] == weight_ptr->get_dim(1), + "bias should have N elements: ", + std::to_string(weight_ptr->get_dim(1)), + ", but got ", bias_size[1]); + auto bias_desc = ideep::tensor::desc(bias_size, dnnl::memory::data_type::f32); + ideep::tensor packed_bias; + packed_bias.init(bias_desc, b.data_ptr()); + onednn_bias = c10::optional(packed_bias); + } + auto ret_ptr = c10::make_intrusive( + PackedLinearWeightsOnednn{ + std::move(weight_ptr), + onednn_bias, + weight, + bias}); + return ret_ptr; +} +#endif // #if AT_MKLDNN_ENABLED() + namespace at { namespace native { @@ -224,6 +299,11 @@ class QLinearPackWeightInt8 final { std::move(weight), std::move(bias)); } #endif +#if AT_MKLDNN_ENABLED() + if (ctx.qEngine() == at::QEngine::ONEDNN) { + return PackedLinearWeightsOnednn::prepack(std::move(weight), std::move(bias)); + } +#endif // #if AT_MKLDNN_ENABLED() TORCH_CHECK( false, "Didn't find engine for operation quantized::linear_prepack ", @@ -254,6 +334,14 @@ class QLinearPackWeightFp16 final { "not supported by QNNPACK"); } #endif // USE_PYTORCH_QNNPACK +#if AT_MKLDNN_ENABLED() + if (ctx.qEngine() == at::QEngine::ONEDNN) { + TORCH_CHECK( + false, + "quantized::linear_prepack_fp16 is currently " + "not supported by ONEDNN"); + } +#endif // #if AT_MKLDNN_ENABLED() TORCH_CHECK( false, "Didn't find engine for operation quantized::linear_prepack_fp16 ", @@ -287,6 +375,16 @@ class QLinearPackWeightInt8Legacy final { return cpp_custom_type_hack::create(std::move(wrapped), options); } #endif // USE_PYTORCH_QNNPACK +#if AT_MKLDNN_ENABLED() + if (ctx.qEngine() == at::QEngine::ONEDNN) { + auto prepacked = + PackedLinearWeightsOnednn::prepack(std::move(weight), std::move(bias)); + auto wrapped = + std::make_unique>( + std::move(prepacked)); + return cpp_custom_type_hack::create(std::move(wrapped), 
options); + } +#endif // #if AT_MKLDNN_ENABLED() TORCH_CHECK( false, "Didn't find engine for operation quantized::linear_prepack ", @@ -317,6 +415,14 @@ class QLinearPackWeightFp16Legacy final { "not supported by QNNPACK"); } #endif // USE_PYTORCH_QNNPACK +#if AT_MKLDNN_ENABLED() + if (ctx.qEngine() == at::QEngine::ONEDNN) { + TORCH_CHECK( + false, + "quantized::linear_prepack_fp16 is currently " + "not supported by ONEDNN"); + } +#endif // #if AT_MKLDNN_ENABLED() TORCH_CHECK( false, "Didn't find engine for operation quantized::linear_prepack_fp16 ", diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp index 00baaf6fa75ed..dd257b80ea763 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -74,6 +75,13 @@ std::tuple> PackedLinearWeightFp16:: } #endif // USE_FBGEMM +#if AT_MKLDNN_ENABLED() +std::tuple> PackedLinearWeightsOnednn::unpack() { + return std::tuple>( + orig_weight_, orig_bias_); +} +#endif // #if AT_MKLDNN_ENABLED() + namespace at { namespace native { namespace { diff --git a/c10/core/QEngine.h b/c10/core/QEngine.h index ac092193d9213..60c21361f15f0 100644 --- a/c10/core/QEngine.h +++ b/c10/core/QEngine.h @@ -15,11 +15,13 @@ enum class QEngine : uint8_t { NoQEngine = 0, FBGEMM = 1, QNNPACK = 2, + ONEDNN = 3, }; constexpr auto kNoQEngine = QEngine::NoQEngine; constexpr auto kFBGEMM = QEngine::FBGEMM; constexpr auto kQNNPACK = QEngine::QNNPACK; +constexpr auto kONEDNN = QEngine::ONEDNN; inline std::string toString(QEngine qengine) { switch (qengine) { @@ -29,6 +31,8 @@ inline std::string toString(QEngine qengine) { return "FBGEMM"; case kQNNPACK: return "QNNPACK"; + case kONEDNN: + return "ONEDNN"; default: TORCH_CHECK( false, "Unrecognized Quantized Engine: ", static_cast(qengine)); diff --git a/test/ao/sparsity/test_kernels.py b/test/ao/sparsity/test_kernels.py index 8deec46b4188c..04a9343459997 100644 --- a/test/ao/sparsity/test_kernels.py +++ b/test/ao/sparsity/test_kernels.py @@ -22,6 +22,7 @@ override_qengines, qengine_is_qnnpack, qengine_is_fbgemm, + qengine_is_onednn, ) # TODO: Once more test files are created, move the contents to a ao folder. @@ -48,6 +49,9 @@ def test_sparse_qlinear(self): # to other higher priority works. 
if qengine_is_qnnpack() and not (row_block_size == 1 and col_block_size == 4): return + # ONEDNN does not support this yet + if qengine_is_onednn(): + return dense_prepack = torch.ops.quantized.linear_prepack dense_qlinear = torch.ops.quantized.linear @@ -215,6 +219,10 @@ def test_sparse_qlinear(self): Y_hat = sqmodel(X_fp32) self.assertEqual(Y_ref, Y_hat) + # ONEDNN does not support this yet + elif qengine_is_onednn(): + return + row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[2:] assert row_block_size == 1 and col_block_size == 4 diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py index c1c409d319779..7cbab3be475e1 100644 --- a/test/quantization/core/test_quantized_module.py +++ b/test/quantization/core/test_quantized_module.py @@ -27,6 +27,7 @@ override_quantized_engine, override_qengines, qengine_is_qnnpack, + qengine_is_onednn, ) from hypothesis import assume, given from hypothesis import strategies as st @@ -99,7 +100,9 @@ def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias, zero_points=zero_point_tensor, axis=0, dtype=torch.qint8) else: - W_q = torch.quantize_per_tensor(W, 0.1, 4, torch.qint8) + # ONEDNN only supports symmetric quantization of weight + W_zp = 0 if qengine_is_onednn() else 4 + W_q = torch.quantize_per_tensor(W, 0.1, W_zp, torch.qint8) X = torch.rand(batch_size, in_features).float() X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8) @@ -434,7 +437,7 @@ def test_conv1d_api(self): X_scale = 1.3 X_zero_point = 2 W_scale = [0.5] - W_zero_point = [3] + W_zero_point = [0] if qengine_is_onednn() else [3] Y_scale = 5.0 Y_zero_point = 4 if torch.backends.quantized.engine == 'qnnpack': @@ -501,7 +504,7 @@ def test_conv2d_api(self): X_scale = 1.3 X_zero_point = 2 W_scale = [0.5] - W_zero_point = [3] + W_zero_point = [0] if qengine_is_onednn() else [3] Y_scale = 5.0 Y_zero_point = 4 # use_fused -> quantized class @@ -570,7 +573,7 @@ def test_conv3d_api(self): X_scale = 1.3 X_zero_point = 2 W_scale = [0.5] - W_zero_point = [3] + W_zero_point = [0] if qengine_is_onednn() else [3] Y_scale = 5.0 Y_zero_point = 4 # use_fused -> quantized class @@ -1200,7 +1203,8 @@ def test_dynamic_convtranspose3d(self): def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_default_observer): """test API functionality for nn.quantized.dynamic.Linear""" W = torch.rand(out_features, in_features).float() - W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8) + qscheme = torch.per_tensor_symmetric if qengine_is_onednn() else torch.per_tensor_affine + W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8, qscheme=qscheme) W_q = torch.quantize_per_tensor(W, W_scale, W_zp, torch.qint8) X = torch.rand(batch_size, in_features).float() B = torch.rand(out_features).float() if use_bias else None @@ -1311,8 +1315,8 @@ def test_lstm_api(self, dtype, bidirectional): bias_keys.append(key_name1) bias_keys.append(key_name2) - if not (dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack"): - # fp16 dynamic quant is not supported for qnnpack + if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")): + # fp16 dynamic quant is not supported for qnnpack or onednn x = torch.randn(seq_len, batch, input_size) h = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size) c = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size) @@ -1362,8 +1366,8 @@ def test_gru_api(self): # instantiated for all 
engines and dtypes for dtype in [torch.qint8, torch.float16]: - if dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack": - # fp16 dynamic quant is not supported for qnnpack + if dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn"): + # fp16 dynamic quant is not supported for qnnpack or onednn continue # Test default instantiation seq_len = 4 @@ -1435,8 +1439,8 @@ def test_cell_api(self, dtype): 'RNNReLU': torch.ops.quantized.quantized_rnn_relu_cell_dynamic} for rnn_type in cell_dict.keys(): - if not (dtype == torch.float16 and torch.backends.quantized.engine == "qnnpack"): - # fp16 dynamic quant is not supported for qnnpack + if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")): + # fp16 dynamic quant is not supported for qnnpack or onednn kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias, 'dtype': dtype} if rnn_type == 'RNNReLU': kwargs['nonlinearity'] = "relu" diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index b4ff538786133..6c32cec66d049 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -26,7 +26,10 @@ from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \ override_quantized_engine, supported_qengines, override_qengines, _snr -from torch.testing._internal.common_quantized import qengine_is_qnnpack +from torch.testing._internal.common_quantized import ( + qengine_is_qnnpack, + qengine_is_onednn, +) from torch.ao.quantization import PerChannelMinMaxObserver from torch.testing._internal.common_cuda import TEST_CUDNN import torch.backends.xnnpack @@ -2658,7 +2661,7 @@ def forward( ] q_data = [] - reduce_range = (qengine == 'fbgemm') + reduce_range = (qengine in ('fbgemm', 'onednn')) for idx, x in enumerate(fp_data): scale, zero_point = _calculate_dynamic_qparams( x, dtype=dtype, reduce_range=reduce_range) @@ -2679,7 +2682,13 @@ def forward( mha.eval() # Prepare - mha.qconfig = torch.ao.quantization.get_default_qconfig(qengine) + if qengine_is_onednn(): + # `reduce_range` is False by default for ONEDNN backend + # but the test fails on earlier CPUs without VNNI. 
+ # So we use a default qconfig with `reduce_range=True` here + mha.qconfig = torch.ao.quantization.get_default_qconfig() + else: + mha.qconfig = torch.ao.quantization.get_default_qconfig(qengine) mha_prepared = torch.ao.quantization.prepare( mha, prepare_custom_config_dict=custom_module_config) @@ -2772,7 +2781,7 @@ def test_qlinear(self, batch_size, input_channels, output_channels, (b_value_max - b_value_min) + b_value_min ).astype(np.int32) if use_bias else None - if torch.backends.quantized.engine == 'fbgemm': + if torch.backends.quantized.engine in ('fbgemm', 'onednn'): avoid_vpmaddubsw_overflow_linear( batch_size, input_channels, @@ -3009,8 +3018,8 @@ def test_qlstmGRU(self, num_batches, input_size, hidden_size, for rnn_type in ['LSTM', 'GRU']: for dtype in [torch.qint8, torch.float16]: - # Fp16 quantization is not supported for qnnpack - if torch.backends.quantized.engine == 'qnnpack' and dtype == torch.float16: + # Fp16 quantization is not supported for qnnpack or onednn + if torch.backends.quantized.engine in ('qnnpack', 'onednn') and dtype == torch.float16: continue if torch.backends.quantized.engine == 'qnnpack': @@ -3143,8 +3152,8 @@ def test_qrnncell(self, num_batches, input_size, hidden_size, per_channel_quant) for rnn_type in ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']: for dtype in [torch.qint8, torch.float16]: - # Fp16 quantization is not supported for qnnpack - if torch.backends.quantized.engine == 'qnnpack' and dtype == torch.float16: + # Fp16 quantization is not supported for qnnpack or onednn + if torch.backends.quantized.engine in ('qnnpack', 'onednn') and dtype == torch.float16: continue if torch.backends.quantized.engine == 'qnnpack': @@ -3299,7 +3308,8 @@ def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, for dtype in dtypes: # No support for channelwise in xnnpack (int8) - if dtype == torch.qint8 and use_channelwise: + # ONEDNN does not support qint8 + if dtype == torch.qint8 and (use_channelwise or qengine_is_onednn()): return nptype = np_dtype[dtype] @@ -3322,7 +3332,8 @@ def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, W_scales = np.random.rand(output_channels) # xnnpack forces W_zp to 0 when using symmetric quantization - if dtype == torch.qint8: + # ONEDNN only supports symmetric quantization of weight + if dtype == torch.qint8 or qengine_is_onednn(): W_zps = np.zeros(output_channels).astype(np.int) else: W_zps = np.round(np.random.rand(output_channels) * 100 - 50).astype(np.int) @@ -3342,7 +3353,7 @@ def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, np.random.rand(output_channels) * (b_value_max - b_value_min) + b_value_min ).astype(np.int32) if use_bias else None - if torch.backends.quantized.engine == 'fbgemm': + if torch.backends.quantized.engine in ('fbgemm', 'onednn'): avoid_vpmaddubsw_overflow_linear( batch_size, input_channels, @@ -3429,6 +3440,13 @@ def test_qlinear_unpack(self, W, use_channelwise): qlinear_prepack = torch.ops.quantized.linear_prepack qlinear_unpack = torch.ops.quantized.linear_unpack + # ONEDNN only supports symmetric quantization of weight + if qengine_is_onednn(): + if use_channelwise: + W_zps = torch.zeros(output_channels).to(torch.int64) + else: + W_zp = 0 + W = torch.from_numpy(W) if use_channelwise: W_q = torch.quantize_per_channel( @@ -3892,6 +3910,10 @@ def _test_qconv_unpack_impl(self, qconv_prepack_fn, qconv_unpack_fn, inputs, if channelwise and transposed: # currently transposed conv and per-channel per quantization does not work return 
+ # ONEDNN only supports symmetric quantization of weight and zero output padding + if qengine_is_onednn(): + W_zero_point = 0 + o_pads = len(o_pads) * [0] if o_pads is not None else None if channelwise: if transposed: output_channels = W.shape[1] # IC OC/G @@ -4030,6 +4052,9 @@ def _test_qconv_impl( weight_dtype=torch.qint8, output_dtype=torch.quint8, ): + # ONEDNN only supports symmetric quantization of weight + if qengine_is_onednn() and W_zero_point is not None: + W_zero_point = len(W_zero_point) * [0] (X, W), (X_q, W_q), bias_float = self._make_qconv_tensors( batch_size, input_channels_per_group, input_feature_map_shape, output_channels_per_group, groups, kernels, @@ -4512,6 +4537,9 @@ def test_qconv_transpose2d( use_bias): if qengine_is_qnnpack() and (IS_PPC or TEST_WITH_UBSAN): return # QNNPACK doesn't support these + # ONEDNN does not support output paddings + if qengine_is_onednn() and (o_pad_h, o_pad_w) != (0, 0): + return assume(o_pad_h < stride_h and o_pad_h < dilation) assume(o_pad_w < stride_w and o_pad_w < dilation) @@ -4641,6 +4669,9 @@ def test_qconv_transpose3d( use_bias): if qengine_is_qnnpack(): return # QNNPACK doesn't support this + # ONEDNN doesn't support output paddings + if qengine_is_onednn() and (o_pad_t, o_pad_h, o_pad_w) != (0, 0, 0): + return assume(o_pad_t < stride_t or o_pad_t < dilation) assume(o_pad_h < stride_h or o_pad_h < dilation) assume(o_pad_w < stride_w or o_pad_w < dilation) diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index de912675adac0..512c8f10f851b 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -184,8 +184,8 @@ def get_default_qconfig(backend='fbgemm', version=0): Returns the default PTQ qconfig for the specified backend. Args: - * `backend`: a string representing the target backend. Currently supports `fbgemm` - and `qnnpack`. + * `backend`: a string representing the target backend. Currently supports `fbgemm`, + `qnnpack` and `onednn`. Return: qconfig @@ -197,6 +197,9 @@ def get_default_qconfig(backend='fbgemm', version=0): elif backend == 'qnnpack': qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False), weight=default_weight_observer) + elif backend == 'onednn': + qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False), + weight=default_per_channel_weight_observer) else: qconfig = default_qconfig else: @@ -216,8 +219,8 @@ def get_default_qat_qconfig(backend='fbgemm', version=1): Returns the default QAT qconfig for the specified backend. Args: - * `backend`: a string representing the target backend. Currently supports `fbgemm` - and `qnnpack`. + * `backend`: a string representing the target backend. Currently supports `fbgemm`, + `qnnpack` and `onednn`. * `version`: version, for backwards compatibility. Can be `None` or `1`. Return: @@ -237,6 +240,11 @@ def get_default_qat_qconfig(backend='fbgemm', version=1): quant_max=255, reduce_range=False), weight=default_weight_fake_quant) + elif backend == 'onednn': + qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255), + weight=default_per_channel_weight_fake_quant) else: qconfig = default_qat_qconfig # Use the fused observe + fake_quant modules for doing QAT. 
@@ -253,6 +261,11 @@ def get_default_qat_qconfig(backend='fbgemm', version=1): quant_max=255, reduce_range=False), weight=default_fused_wt_fake_quant) + elif backend == 'onednn': + qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255), + weight=default_fused_per_channel_wt_fake_quant) else: qconfig = default_qat_qconfig_v2 else: diff --git a/torch/backends/quantized/__init__.py b/torch/backends/quantized/__init__.py index a24d88bcc6e6d..6f7d479e90c4a 100644 --- a/torch/backends/quantized/__init__.py +++ b/torch/backends/quantized/__init__.py @@ -11,6 +11,8 @@ def _get_qengine_id(qengine: str) -> int: ret = 1 elif qengine == 'qnnpack': ret = 2 + elif qengine == 'onednn': + ret = 3 else: ret = -1 raise RuntimeError("{} is not a valid value for quantized engine".format(qengine)) @@ -18,7 +20,7 @@ def _get_qengine_id(qengine: str) -> int: # This function should correspond to the enums present in c10/core/QEngine.h def _get_qengine_str(qengine: int) -> str: - all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack'} + all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack', 3 : 'onednn'} return all_engines.get(qengine, '*undefined') class _QEngineProp(object): diff --git a/torch/testing/_internal/common_quantized.py b/torch/testing/_internal/common_quantized.py index 9440e825e411d..597fd774e3299 100644 --- a/torch/testing/_internal/common_quantized.py +++ b/torch/testing/_internal/common_quantized.py @@ -46,9 +46,12 @@ def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8): qx = np.clip(qx, qmin, qmax).astype(qtype) return qx -def _calculate_dynamic_qparams(X, dtype, reduce_range=False): +def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine): """Calculate the dynamic quantization parameters (scale, zero_point) according to the min and max element of the tensor""" + assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric) + if qscheme == torch.per_tensor_symmetric: + assert dtype == torch.qint8 if isinstance(X, torch.Tensor): X = X.numpy() if dtype == torch.qint8: @@ -63,17 +66,25 @@ def _calculate_dynamic_qparams(X, dtype, reduce_range=False): qmin, qmax = 0, 255 min_val = X.min() max_val = X.max() + is_symmetric = (qscheme == torch.per_tensor_symmetric) if min_val == max_val: scale = 1.0 zero_point = 0 else: - max_val = max(max_val, 0.0) - min_val = min(min_val, 0.0) - scale = (max_val - min_val) / (qmax - qmin) - scale = max(scale, np.finfo(np.float32).eps) - zero_point = qmin - round(min_val / scale) - zero_point = max(qmin, zero_point) - zero_point = min(qmax, zero_point) + if is_symmetric: + max_val = max(max_val, -min_val) + min_val = -max_val + scale = (max_val - min_val) / (qmax - qmin) + scale = max(scale, np.finfo(np.float32).eps) + zero_point = 0 + else: + max_val = max(max_val, 0.0) + min_val = min(min_val, 0.0) + scale = (max_val - min_val) / (qmax - qmin) + scale = max(scale, np.finfo(np.float32).eps) + zero_point = qmin - round(min_val / scale) + zero_point = max(qmin, zero_point) + zero_point = min(qmax, zero_point) return [float(scale), int(zero_point)] def _calculate_dynamic_per_channel_qparams(X, dtype): @@ -165,6 +176,8 @@ def qengine_is_fbgemm(): return torch.backends.quantized.engine == 'fbgemm' def qengine_is_qnnpack(): return torch.backends.quantized.engine == 'qnnpack' +def qengine_is_onednn(): + return torch.backends.quantized.engine == 'onednn' # Helper function used to simulate per-channel fake-quant against 
any axis def _permute_to_axis_zero(X, axis):
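
Usage sketch (not part of the patch): a minimal example of how the 'onednn' qengine introduced above might be selected and exercised end to end. The module `M`, the tensor shapes, and the calibration data below are arbitrary illustrations, and the snippet assumes a PyTorch build where AT_MKLDNN_ENABLED() holds so that 'onednn' appears among the supported engines. As the test changes above note, the onednn default qconfig keeps reduce_range=False for activations, which can overflow on pre-VNNI CPUs.

import torch
import torch.ao.quantization as tq

# The new engine is only registered when MKL-DNN is enabled in the build.
assert 'onednn' in torch.backends.quantized.supported_engines
torch.backends.quantized.engine = 'onednn'

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.fc = torch.nn.Linear(16, 8)
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

m = M().eval()
# Per this patch, get_default_qconfig('onednn') pairs
# HistogramObserver(reduce_range=False) for activations with a
# per-channel symmetric weight observer (weight zero point must be 0).
m.qconfig = tq.get_default_qconfig('onednn')
tq.prepare(m, inplace=True)
m(torch.randn(4, 16))           # calibrate observers
tq.convert(m, inplace=True)     # weight is packed via PackedLinearWeightsOnednn
print(m(torch.randn(4, 16)).shape)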