use MLAS for QGEMM in matmulInteger and convInteger #1692

Merged · 3 commits · Aug 27, 2019
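This change introduces a thin dispatch layer for quantized u8×u8→s32 GEMM: a new `QGemmu8u8_s32` wrapper (declared in `core/util/qmath.h`, defined in `core/util/qmath.cc`) that calls MLAS by default and falls back to gemmlowp when `USE_GEMMLOWP` is defined. `MatMulInteger` and `ConvInteger` are rewritten against the wrapper, the gemmlowp helpers switch from returning `Status` to `void` and take a `concurrency::ThreadPool*` to match the new interface, a `--use_gemmlowp` flag is plumbed through `build.py` and CMake, and a zero-zero-point test is added for `MatMulInteger`.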
5 changes: 5 additions & 0 deletions cmake/CMakeLists.txt
```diff
@@ -52,6 +52,7 @@ option(onnxruntime_USE_EIGEN_FOR_BLAS "Use eign for blas" ON)
 option(onnxruntime_USE_NNAPI "Build with DNNLibrary for Android NNAPI support" OFF)
 option(onnxruntime_USE_MKLDNN "Build with MKL-DNN support" OFF)
 option(onnxruntime_USE_MKLML "Build MKL-DNN with MKL-ML binary dependency" OFF)
+option(onnxruntime_USE_GEMMLOWP "Build with gemmlowp for quantized gemm" OFF)
 option(onnxruntime_USE_AUTOML "Build AutoML support" ON)
 option(onnxruntime_USE_NGRAPH "Build with nGraph support" OFF)
 option(onnxruntime_USE_OPENBLAS "Use openblas" OFF)
@@ -472,6 +473,10 @@ if (onnxruntime_USE_MKLDNN OR onnxruntime_USE_MKLML)
   include(mkldnn)
 endif()
 
+if(onnxruntime_USE_GEMMLOWP)
+  add_definitions(-DUSE_GEMMLOWP=1)
+endif()
+
 if (onnxruntime_USE_MKLML)
   add_definitions(-DUSE_MKLML=1 -DUSE_MKLML_FOR_BLAS=1)
   if (WIN32 OR APPLE)
```
32 changes: 18 additions & 14 deletions onnxruntime/core/providers/cpu/math/matmul_integer.cc
```diff
@@ -3,7 +3,7 @@
 
 #include "core/providers/cpu/math/matmul_integer.h"
 #include "core/providers/cpu/math/matmul_helper.h"
-#include "core/util/gemmlowp_common.h"
+#include "core/util/qmath.h"
 #include "core/providers/common.h"
 
 namespace onnxruntime {
@@ -31,30 +31,34 @@ Status MatMulInteger<uint8_t, uint8_t, int32_t>::Compute(OpKernelContext* ctx) const {
   Tensor* y = ctx->Output(0, helper.OutputShape());
 
   // validate zero points
-  int32_t a_offset = 0;
-  int32_t b_offset = 0;
+  uint8_t a_offset = 0;
+  uint8_t b_offset = 0;
   if (has_a_zero_point_) {
     auto a_zero_point = ctx->Input<Tensor>(2);
-    ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point),
-        "MatmulInteger : input1 zero point must be a scalar or 1D tensor of size 1");
+    ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point),
+                "MatmulInteger : input1 zero point must be a scalar or 1D tensor of size 1");
     a_offset = static_cast<int32_t>(*a_zero_point->template Data<uint8_t>());
   }
   if (has_b_zero_point_) {
     auto b_zero_point = ctx->Input<Tensor>(3);
     ORT_ENFORCE(IsScalarOr1ElementVector(b_zero_point),
-        "MatmulInteger : input2 zero point must be a scalar or 1D tensor of size 1");
+                "MatmulInteger : input2 zero point must be a scalar or 1D tensor of size 1");
     b_offset = static_cast<int32_t>(*b_zero_point->template Data<uint8_t>());
   }
 
   for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
-    GemmlowpMultiplyu8u8_s32(a->template Data<uint8_t>() + helper.LeftOffsets()[i],
-                             b->template Data<uint8_t>() + helper.RightOffsets()[i],
-                             y->template MutableData<int32_t>() + helper.OutputOffsets()[i],
-                             a_offset,
-                             b_offset,
-                             static_cast<int>(helper.M()),
-                             static_cast<int>(helper.N()),
-                             static_cast<int>(helper.K()));
+    QGemmu8u8_s32(static_cast<int>(helper.M()),
+                  static_cast<int>(helper.N()),
+                  static_cast<int>(helper.K()),
+                  a->template Data<uint8_t>() + helper.LeftOffsets()[i],
+                  static_cast<int>(helper.K()),
+                  a_offset,
+                  b->template Data<uint8_t>() + helper.RightOffsets()[i],
+                  static_cast<int>(helper.N()),
+                  b_offset,
+                  y->template MutableData<int32_t>() + helper.OutputOffsets()[i],
+                  static_cast<int>(helper.N()),
+                  nullptr);
   }
 
   return Status::OK();
```
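For reference, the QGEMM convention used by the new call: each batch slice multiplies an M×K uint8 matrix by a K×N uint8 matrix into an int32 M×N result, subtracting the zero points before the widening multiply. Below is a minimal sketch of that arithmetic; `QGemmu8u8_s32_reference` is a hypothetical name for illustration, not the optimized MLAS or gemmlowp kernel:

```cpp
#include <cstdint>

// Reference semantics only: C[m][n] = sum_k (A[m][k] - a_offset) * (B[k][n] - b_offset),
// with A (M x K, leading dimension lda), B (K x N, ldb), C (M x N, ldc), all row-major.
void QGemmu8u8_s32_reference(int M, int N, int K,
                             const uint8_t* A, int lda, uint8_t a_offset,
                             const uint8_t* B, int ldb, uint8_t b_offset,
                             int32_t* C, int ldc) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      int32_t acc = 0;
      for (int k = 0; k < K; ++k) {
        // Zero points are subtracted in int32, so uint8 underflow cannot occur.
        acc += (static_cast<int32_t>(A[m * lda + k]) - a_offset) *
               (static_cast<int32_t>(B[k * ldb + n]) - b_offset);
      }
      C[m * ldc + n] = acc;
    }
  }
}
```

MatMulInteger passes lda = K and ldb = ldc = N, i.e. fully packed row-major slices.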
28 changes: 16 additions & 12 deletions onnxruntime/core/providers/cpu/nn/conv_integer.cc
```diff
@@ -1,11 +1,10 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-
 #include "core/providers/cpu/nn/conv_integer.h"
 #include "core/util/math.h"
 #include "core/util/math_cpuonly.h"
-#include "core/util/gemmlowp_common.h"
+#include "core/util/qmath.h"
 #include "core/providers/common.h"
 
 namespace onnxruntime {
@@ -22,11 +21,12 @@
     ConvInteger);
 
 Status ConvInteger::Compute(OpKernelContext* context) const {
+
   size_t num_inputs = OpKernel::Node().InputDefs().size();
   const auto* X = context->Input<Tensor>(0);
   const auto* W = context->Input<Tensor>(1);
   uint8_t input_offset = 0;
-  uint8_t filter_offset = 0;
+  uint8_t filter_offset = 0;
   if (num_inputs >= 3) {
     const auto* X_Zero_Point = context->Input<Tensor>(2);
     ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1.");
@@ -35,7 +35,7 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
   if (num_inputs >= 4) {
     const auto* W_Zero_Point = context->Input<Tensor>(3);
     ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now.");
-    filter_offset = *(W_Zero_Point->Data<uint8_t>());
+    filter_offset = *(W_Zero_Point->Data<uint8_t>());
   }
 
   const int64_t N = X->Shape()[0];
@@ -108,14 +108,18 @@ Status ConvInteger::Compute(OpKernelContext* context) const {
                      false,
                      input_offset);
 
-    GemmlowpMultiplyu8u8_s32(W->template Data<uint8_t>() + group_id * W_offset,
-                             col_buffer_data,
-                             Ydata + group_id * Y_offset,
-                             filter_offset,
-                             input_offset,
-                             static_cast<int>(M / group_),
-                             static_cast<int>(output_image_size),
-                             static_cast<int>(kernel_dim));
+    QGemmu8u8_s32(static_cast<int>(M / group_),
+                  static_cast<int>(output_image_size),
+                  static_cast<int>(kernel_dim),
+                  W->template Data<uint8_t>() + group_id * W_offset,
+                  static_cast<int>(kernel_dim),
+                  filter_offset,
+                  col_buffer_data,
+                  static_cast<int>(output_image_size),
+                  input_offset,
+                  Ydata + group_id * Y_offset,
+                  static_cast<int>(output_image_size),
+                  nullptr);
   }
 
   Xdata += X_offset * group_;
```
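After im2col, each group reuses exactly the same wrapper: the weight slice `W + group_id * W_offset` is the left matrix of shape `(M / group_) × kernel_dim` with `filter_offset` as its zero point, the column buffer is the right matrix of shape `kernel_dim × output_image_size` with `input_offset`, and the result is `(M / group_) × output_image_size` int32 values. All three leading dimensions equal the packed row-major ones (`kernel_dim`, `output_image_size`, `output_image_size`), so the gemmlowp fallback's RowMajor-only restriction is satisfied. As an invented sizing example, a 3×3 kernel over 16 input channels per group gives an inner dimension of kernel_dim = 16 · 3 · 3 = 144.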
11 changes: 4 additions & 7 deletions onnxruntime/core/util/gemmlowp_common.cc
```diff
@@ -6,7 +6,7 @@
 
 namespace onnxruntime {
 
-Status GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data,
+void GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data,
                              const int lhs_offset, const int rhs_offset, const int result_offset,
                              int m, int n, int k, int32_t int_multiplier, int32_t right_shift, const int32_t* bias) {
   // TODO exp ColMajor order for rhs and result. That may be faster
@@ -28,13 +28,11 @@ Status GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data,
         gemmlowp::DefaultL8R8BitDepthParams>(
         &gemm_context, lhs, rhs, &result, -lhs_offset, -rhs_offset, output_pipeline);
   }
-
-  return Status::OK();
 }
 
-Status GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data, int32_t* result_data,
-                                const int lhs_offset, const int rhs_offset, int m, int n, int k) {
-  // TODO exp ColMajor order for rhs and result. That may be faster
+void GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data, int32_t* result_data,
+                              const int lhs_offset, const int rhs_offset, int m, int n, int k, concurrency::ThreadPool*) {
+
   const auto matOrder = gemmlowp::MapOrder::RowMajor;
   gemmlowp::MatrixMap<const uint8_t, matOrder> lhs(lhs_data, m, k);
   gemmlowp::MatrixMap<const uint8_t, matOrder> rhs(rhs_data, k, n);
@@ -47,6 +45,5 @@ Status GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data,
   gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t, gemmlowp::DefaultL8R8BitDepthParams>(
       &gemm_context, lhs, rhs, &result, -lhs_offset, -rhs_offset, empty_pipeline);
 
-  return Status::OK();
 }
 }  // namespace onnxruntime
```
9 changes: 4 additions & 5 deletions onnxruntime/core/util/gemmlowp_common.h
```diff
@@ -2,8 +2,7 @@
 // Licensed under the MIT License.
 #pragma once
 #include "core/util/gemmlowp_common_wrapper.h"
-#include "core/util/math.h"
-
+#include "core/platform/threadpool.h"
 
 namespace onnxruntime {
 
@@ -56,11 +55,11 @@ MakeOutputPipelineWithOutBias(std::int32_t result_offset,
   return std::make_tuple(quantize_down_stage, saturating_cast_stage);
 }
 
-Status GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data,
+void GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data,
                              const int lhs_offset, const int rhs_offset, const int result_offset,
                              int m, int n, int k, int32_t int_multiplier, int32_t right_shift, const int32_t* bias = nullptr);
 
-Status GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data, int32_t* result_data,
-                                const int lhs_offset, const int rhs_offset, int m, int n, int k);
+void GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data, int32_t* result_data,
+                              const int lhs_offset, const int rhs_offset, int m, int n, int k, concurrency::ThreadPool*);
 
 }  // namespace onnxruntime
```
33 changes: 33 additions & 0 deletions onnxruntime/core/util/qmath.cc
```diff
@@ -0,0 +1,33 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/util/qmath.h"
+#include "core/common/common.h"
+
+namespace onnxruntime {
+
+void QGemmu8u8_s32(
+    int M,
+    int N,
+    int K,
+    const uint8_t* lhs_data,
+    int lda,
+    const uint8_t lhs_offset,
+    const uint8_t* rhs_data,
+    int ldb,
+    const uint8_t rhs_offset,
+    int32_t* result_data,
+    int ldc,
+    concurrency::ThreadPool* thread_pool) {
+#ifdef USE_GEMMLOWP
+
+  ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For gemmlowp only RowMajor*RowMajor=RowMajor format is supported");
+
+  GemmlowpMultiplyu8u8_s32(lhs_data, rhs_data, result_data, lhs_offset, rhs_offset, M, N, K, thread_pool);
+
+#else
+  MlasQgemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool);
+
+#endif
+}
+}  // namespace onnxruntime
```
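To make the calling convention concrete, here is a small usage sketch against the signature above; the matrices and expected results are worked out by hand, and `Example` is an invented harness rather than code from this PR:

```cpp
#include <cstdint>

#include "core/util/qmath.h"

void Example() {
  // A is 2x3, B is 3x2, both packed row-major; zero points 1 and 0.
  const uint8_t A[6] = {2, 3, 4,
                        5, 6, 7};
  const uint8_t B[6] = {1, 2,
                        3, 4,
                        5, 6};
  int32_t C[4] = {};
  // M=2, N=2, K=3; lda=K, ldb=ldc=N; a nullptr pool runs single-threaded.
  onnxruntime::QGemmu8u8_s32(2, 2, 3,
                             A, 3, /*lhs_offset*/ 1,
                             B, 2, /*rhs_offset*/ 0,
                             C, 2, nullptr);
  // (A - 1) = {1,2,3; 4,5,6}, so C = {22, 28, 49, 64}.
}
```

On the gemmlowp path the ORT_ENFORCE above rejects any call whose matrices are not fully packed (lda != K, and so on), while the MLAS path passes the leading dimensions through unchanged.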
29 changes: 29 additions & 0 deletions onnxruntime/core/util/qmath.h
```diff
@@ -0,0 +1,29 @@
+#pragma once
+//#define USE_GEMMLOWP
+
+#ifdef USE_GEMMLOWP
+#include "core/util/gemmlowp_common.h"
+#else
+#include "core/mlas/inc/mlas.h"
+#endif
+#include "core/platform/threadpool.h"
+#include <mutex>
+#include <thread>
+
+namespace onnxruntime {
+
+void QGemmu8u8_s32(
+    int M,
+    int N,
+    int K,
+    const uint8_t* lhs_data,
+    int lda,
+    const uint8_t lhs_offset,
+    const uint8_t* rhs_data,
+    int ldb,
+    const uint8_t rhs_offset,
+    int32_t* result_data,
+    int ldc,
+    concurrency::ThreadPool* thread_pool);
+
+}  // namespace onnxruntime
```
14 changes: 12 additions & 2 deletions onnxruntime/test/providers/cpu/math/matmul_integer_test.cc
```diff
@@ -11,9 +11,9 @@
 namespace onnxruntime {
 namespace test {
 
-TEST(MatmulIntegerOpTest, MatMulInteger1) {
+TEST(MatmulIntegerOpTest, MatMulInteger_2D) {
   OpTester test("MatMulInteger", 10);
-  test.AddInput<uint8_t>("T1", {4, 3}, {11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0});
+  test.AddInput<uint8_t>("T1", {4, 3}, {11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0});
   test.AddInput<uint8_t>("T2", {3, 2}, {1, 4, 2, 5, 3, 6});
   test.AddInput<uint8_t>("a_zero_point", {}, {12});
   test.AddInput<uint8_t>("b_zero_point", {}, {0});
@@ -30,5 +30,15 @@ TEST(MatmulIntegerOpTest, MatMulInteger) {
   test.AddOutput<int32_t>("T3", {1, 1}, {-1});
   test.Run();
 }
+TEST(MatmulIntegerOpTest, MatMulInteger_WithZero_ZeroPoint) {
+  OpTester test("MatMulInteger", 10);
+  test.AddInput<uint8_t>("T1", {4, 3}, {11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0});
+  test.AddInput<uint8_t>("T2", {3, 2}, {1, 4, 2, 5, 3, 6});
+  test.AddInput<uint8_t>("a_zero_point", {}, {0});
+  test.AddInput<uint8_t>("b_zero_point", {}, {0});
+  test.AddOutput<int32_t>("T3", {4, 2}, {34, 97, 28, 82, 22, 67, 16, 52});
+  test.Run();
+}
+
 }  // namespace test
 }  // namespace onnxruntime
```
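The expected values in the new zero-zero-point test check out by hand: with both zero points at 0 the operator is a plain integer matmul, so the first row of T1, [11, 7, 3], against the columns of T2, [1, 2, 3] and [4, 5, 6], gives 11·1 + 7·2 + 3·3 = 34 and 11·4 + 7·5 + 3·6 = 97, matching the first row of T3.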
2 changes: 2 additions & 0 deletions tools/ci_build/build.py
```diff
@@ -127,6 +127,7 @@ def parse_arguments():
     parser.add_argument("--use_openblas", action='store_true', help="Build with OpenBLAS.")
     parser.add_argument("--use_mkldnn", action='store_true', help="Build with MKLDNN.")
     parser.add_argument("--use_mklml", action='store_true', help="Build with MKLML.")
+    parser.add_argument("--use_gemmlowp", action='store_true', help="Build with gemmlowp for quantized gemm.")
     parser.add_argument("--use_automl", action='store_true', help="Build with AutoML support.")
     parser.add_argument("--use_ngraph", action='store_true', help="Build with nGraph.")
     parser.add_argument("--use_openvino", nargs="?", const="CPU_FP32",
@@ -334,6 +335,7 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home
         "-Donnxruntime_USE_OPENBLAS=" + ("ON" if args.use_openblas else "OFF"),
         "-Donnxruntime_USE_MKLDNN=" + ("ON" if args.use_mkldnn else "OFF"),
         "-Donnxruntime_USE_MKLML=" + ("ON" if args.use_mklml else "OFF"),
+        "-Donnxruntime_USE_GEMMLOWP=" + ("ON" if args.use_gemmlowp else "OFF"),
         "-Donnxruntime_USE_NGRAPH=" + ("ON" if args.use_ngraph else "OFF"),
         "-Donnxruntime_USE_OPENVINO=" + ("ON" if args.use_openvino else "OFF"),
        "-Donnxruntime_USE_OPENVINO_BINARY=" + ("ON" if args.use_openvino else "OFF"),
```
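With these two lines the flag flows end to end: `--use_gemmlowp` sets `-Donnxruntime_USE_GEMMLOWP=ON`, CMake then emits `-DUSE_GEMMLOWP=1`, and `qmath.h` selects the gemmlowp header and implementation; the default (`OFF`) keeps the new MLAS path.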