Skip to content

Commit

Permalink
Add onednn quant backend (pytorch#74137)
Browse files Browse the repository at this point in the history
Summary:
Resolve the conflicts in pytorch#69820
jerryzh168, please review. Thanks.

Pull Request resolved: pytorch#74137

Reviewed By: samdow

Differential Revision: D34840477

Pulled By: jerryzh168

fbshipit-source-id: 8aa60981ff7be211a1609644f273b16d18efd425
(cherry picked from commit de76bb8)
  • Loading branch information
Xia-Weiwen authored and pytorchmergebot committed Mar 15, 2022
1 parent deae595 commit 060f1b8
Show file tree
Hide file tree
Showing 19 changed files with 1,007 additions and 37 deletions.
4 changes: 4 additions & 0 deletions aten/src/ATen/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,10 @@ const std::vector<at::QEngine>& Context::supportedQEngines() {
engines.push_back(at::kNoQEngine);
#endif // C10_MOBILE

#if AT_MKLDNN_ENABLED()
engines.push_back(at::kONEDNN);
#endif

#ifdef USE_FBGEMM
if (fbgemm::fbgemmSupportedCPU()) {
engines.push_back(at::kFBGEMM);
Expand Down
15 changes: 15 additions & 0 deletions aten/src/ATen/native/quantized/cpu/conv_serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <ATen/core/List.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <c10/util/irange.h>

#include <tuple>
Expand Down Expand Up @@ -358,6 +359,20 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
);
}
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
if (ctx.qEngine() == at::QEngine::ONEDNN) {
return PackedConvWeightsOnednn<kSpatialDim>::prepack(
weight.value(),
bias,
stride,
padding,
output_padding,
dilation,
groups,
transpose
);
}
#endif // AT_MKLDNN_ENABLED()
TORCH_CHECK(
false,
"Didn't find engine for when deserializing ConvPackedParams: ",
Expand Down
11 changes: 11 additions & 0 deletions aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <ATen/native/quantized/cpu/embedding_packed_params.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/quantized/Quantizer.h>
Expand Down Expand Up @@ -470,6 +471,16 @@ int register_linear_params() {
std::move(weight), std::move(bias));
}
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
if (at::globalContext().qEngine() == at::QEngine::ONEDNN) {
TORCH_CHECK(
weight.scalar_type() == at::kQInt8,
"ONEDNN only supports INT8 bit width currently. Got ",
c10::toString(weight.scalar_type()));
return PackedLinearWeightsOnednn::prepack(
std::move(weight), std::move(bias));
}
#endif // #if AT_MKLDNN_ENABLED()
TORCH_CHECK(false, "Unknown qengine");
})
.def("bias", [](const c10::intrusive_ptr<LinearPackedParamsBase>& self) {
Expand Down
151 changes: 151 additions & 0 deletions aten/src/ATen/native/quantized/cpu/onednn_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#pragma once

#include <ATen/Config.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/Tensor.h>
#include <ATen/native/quantized/packed_params.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>

// Packed parameters for a quantized linear (fully-connected) op on the
// oneDNN backend. Holds both the oneDNN (ideep) tensors used at run time
// and the original at::Tensors so the params can be unpacked later.
struct PackedLinearWeightsOnednn : public LinearPackedParamsBase {
  PackedLinearWeightsOnednn(
      std::unique_ptr<ideep::tensor> weight,
      c10::optional<ideep::tensor> bias,
      at::Tensor orig_weight,
      c10::optional<at::Tensor> orig_bias)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        orig_weight_(std::move(orig_weight)),
        orig_bias_(std::move(orig_bias)) {}
  std::unique_ptr<ideep::tensor> weight_;  // oneDNN-side weight
  c10::optional<ideep::tensor> bias_;      // oneDNN-side bias, if any
  at::Tensor orig_weight_;                 // original weight kept for unpack()
  c10::optional<at::Tensor> orig_bias_;    // original bias, returned by bias()

  // Static quantization: quantized input in, output quantized with
  // (output_scale, output_zero_point).
  at::Tensor apply(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;
  // Same as apply(), with ReLU fused into the output.
  at::Tensor apply_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;

  // Dynamic quantization: fp32 input/output; reduce_range narrows the
  // quantization range of the on-the-fly quantized input.
  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;

  // Recovers the (weight, bias) pair — presumably the stored orig_weight_ /
  // orig_bias_; definition lives in the backend .cpp, confirm there.
  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

  c10::optional<at::Tensor> bias() override {
    return orig_bias_;
  }

  // Packs an at::Tensor weight (+ optional bias) into oneDNN form.
  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      at::Tensor weight,
      c10::optional<at::Tensor> bias);

 private:
  // Shared implementation behind apply() / apply_relu().
  template <bool ReluFused>
  at::Tensor apply_impl(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point);

  // Shared implementation behind apply_dynamic() / apply_dynamic_relu().
  template <bool ReluFused>
  at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false);
};

// Packed parameters for a quantized (possibly transposed) convolution on the
// oneDNN backend, templated on the number of spatial dimensions (2 = conv2d).
// Stores the oneDNN (ideep) tensors used at run time plus the original
// at::Tensors and conv geometry so the params can be serialized/unpacked.
template <int kSpatialDim = 2>
struct PackedConvWeightsOnednn : public ConvPackedParamsBase<kSpatialDim> {
  PackedConvWeightsOnednn(
      std::unique_ptr<ideep::tensor> weight,
      c10::optional<ideep::tensor> bias,
      at::Tensor orig_weight,
      c10::optional<at::Tensor> orig_bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      uint8_t transpose)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        orig_weight_(std::move(orig_weight)),
        orig_bias_(std::move(orig_bias)),
        stride_(std::move(stride)),
        padding_(std::move(padding)),
        output_padding_(std::move(output_padding)),
        dilation_(std::move(dilation)),
        groups_(groups),
        transpose_(transpose) {}

  std::unique_ptr<ideep::tensor> weight_;  // oneDNN-side weight
  c10::optional<ideep::tensor> bias_;      // oneDNN-side bias, if any
  at::Tensor orig_weight_;                 // original weight kept for unpack()
  c10::optional<at::Tensor> orig_bias_;
  torch::List<int64_t> stride_;
  torch::List<int64_t> padding_;
  torch::List<int64_t> output_padding_;    // only meaningful for transposed conv
  torch::List<int64_t> dilation_;
  int64_t groups_;
  // NOTE: stored as uint8_t although prepack() takes bool; the transpose()
  // accessor casts back to bool.
  uint8_t transpose_;

  // Static quantization: quantized input in, output quantized with
  // (output_scale, output_zero_point).
  at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  // Same as apply(), with ReLU fused into the output.
  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  // Dynamic quantization path; fp32 input/output.
  at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range) override;

  // Recovers the (weight, bias) pair — presumably the stored orig_weight_ /
  // orig_bias_; definition lives in the backend .cpp, confirm there.
  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

  // Packs an at::Tensor weight (+ optional bias) and conv geometry into
  // oneDNN form.
  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
      at::Tensor weight,
      c10::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose);

  // Geometry accessors required by the ConvPackedParamsBase interface.
  torch::List<int64_t> stride() const override {
    return stride_;
  }

  torch::List<int64_t> padding() const override {
    return padding_;
  }

  torch::List<int64_t> output_padding() const override {
    return output_padding_;
  }

  torch::List<int64_t> dilation() const override {
    return dilation_;
  }

  int64_t groups() const override {
    return groups_;
  }

  bool transpose() const override {
    return (bool)transpose_;
  }

 private:
  // Shared implementation behind apply() / apply_relu().
  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point);
};

#endif // #if AT_MKLDNN_ENABLED()
Loading

0 comments on commit 060f1b8

Please sign in to comment.