From 0916c757619015686a43c98504e533ec24069294 Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Mon, 26 Aug 2019 11:42:37 -0700 Subject: [PATCH 1/2] use mlas qgemm for u8u8_s32 gemms --- cmake/CMakeLists.txt | 5 +++ .../core/providers/cpu/math/matmul_integer.cc | 32 ++++++++++-------- .../core/providers/cpu/nn/conv_integer.cc | 28 +++++++++------- onnxruntime/core/util/gemmlowp_common.cc | 11 +++---- onnxruntime/core/util/gemmlowp_common.h | 9 +++-- onnxruntime/core/util/qmath.cc | 33 +++++++++++++++++++ onnxruntime/core/util/qmath.h | 29 ++++++++++++++++ .../providers/cpu/math/matmul_integer_test.cc | 21 +++++++++++- tools/ci_build/build.py | 2 ++ 9 files changed, 131 insertions(+), 39 deletions(-) create mode 100644 onnxruntime/core/util/qmath.cc create mode 100644 onnxruntime/core/util/qmath.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 9b084286b4c6d..b75b22967b026 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -52,6 +52,7 @@ option(onnxruntime_USE_EIGEN_FOR_BLAS "Use eign for blas" ON) option(onnxruntime_USE_NNAPI "Build with DNNLibrary for Android NNAPI support" OFF) option(onnxruntime_USE_MKLDNN "Build with MKL-DNN support" OFF) option(onnxruntime_USE_MKLML "Build MKL-DNN with MKL-ML binary dependency" OFF) +option(onnxruntime_USE_GEMMLOWP "Build with gemmlowp for quantized gemm" OFF) option(onnxruntime_USE_AUTOML "Build AutoML support" ON) option(onnxruntime_USE_NGRAPH "Build with nGraph support" OFF) option(onnxruntime_USE_OPENBLAS "Use openblas" OFF) @@ -472,6 +473,10 @@ if (onnxruntime_USE_MKLDNN OR onnxruntime_USE_MKLML) include(mkldnn) endif() +if(onnxruntime_USE_GEMMLOWP) + add_definitions(-DUSE_GEMMLOWP=1) +endif() + if (onnxruntime_USE_MKLML) add_definitions(-DUSE_MKLML=1 -DUSE_MKLML_FOR_BLAS=1) if (WIN32 OR APPLE) diff --git a/onnxruntime/core/providers/cpu/math/matmul_integer.cc b/onnxruntime/core/providers/cpu/math/matmul_integer.cc index 6e8464d4d5a94..5aed5b8c082d1 100644 --- a/onnxruntime/core/providers/cpu/math/matmul_integer.cc +++ b/onnxruntime/core/providers/cpu/math/matmul_integer.cc @@ -3,7 +3,7 @@ #include "core/providers/cpu/math/matmul_integer.h" #include "core/providers/cpu/math/matmul_helper.h" -#include "core/util/gemmlowp_common.h" +#include "core/util/qmath.h" #include "core/providers/common.h" namespace onnxruntime { @@ -31,30 +31,34 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) c Tensor* y = ctx->Output(0, helper.OutputShape()); // validate zero points - int32_t a_offset = 0; - int32_t b_offset = 0; + uint8_t a_offset = 0; + uint8_t b_offset = 0; if (has_a_zero_point_) { auto a_zero_point = ctx->Input(2); - ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point), - "MatmulInteger : input1 zero point must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point), + "MatmulInteger : input1 zero point must be a scalar or 1D tensor of size 1"); a_offset = static_cast(*a_zero_point->template Data()); } if (has_b_zero_point_) { auto b_zero_point = ctx->Input(3); ORT_ENFORCE(IsScalarOr1ElementVector(b_zero_point), - "MatmulInteger : input2 zero point must be a scalar or 1D tensor of size 1"); + "MatmulInteger : input2 zero point must be a scalar or 1D tensor of size 1"); b_offset = static_cast(*b_zero_point->template Data()); } for (size_t i = 0; i < helper.OutputOffsets().size(); i++) { - GemmlowpMultiplyu8u8_s32(a->template Data() + helper.LeftOffsets()[i], - b->template Data() + helper.RightOffsets()[i], - y->template MutableData() + helper.OutputOffsets()[i], - a_offset, - b_offset, - static_cast(helper.M()), - static_cast(helper.N()), - static_cast(helper.K())); + QGemmu8u8_s32(static_cast(helper.M()), + static_cast(helper.N()), + static_cast(helper.K()), + a->template Data() + helper.LeftOffsets()[i], + static_cast(helper.K()), + a_offset, + b->template Data() + helper.RightOffsets()[i], + static_cast(helper.N()), + b_offset, + y->template MutableData() + helper.OutputOffsets()[i], + static_cast(helper.N()), + nullptr); } return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/nn/conv_integer.cc b/onnxruntime/core/providers/cpu/nn/conv_integer.cc index 49d2a3c575950..534cb75a6e840 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_integer.cc +++ b/onnxruntime/core/providers/cpu/nn/conv_integer.cc @@ -1,11 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #include "core/providers/cpu/nn/conv_integer.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" -#include "core/util/gemmlowp_common.h" +#include "core/util/qmath.h" #include "core/providers/common.h" namespace onnxruntime { @@ -22,11 +21,12 @@ ONNX_OPERATOR_KERNEL_EX( ConvInteger); Status ConvInteger::Compute(OpKernelContext* context) const { + size_t num_inputs = OpKernel::Node().InputDefs().size(); const auto* X = context->Input(0); const auto* W = context->Input(1); uint8_t input_offset = 0; - uint8_t filter_offset = 0; + uint8_t filter_offset = 0; if (num_inputs >= 3) { const auto* X_Zero_Point = context->Input(2); ORT_ENFORCE(IsScalarOr1ElementVector(X_Zero_Point), "Must be a scalar or 1D tensor or size 1."); @@ -35,7 +35,7 @@ Status ConvInteger::Compute(OpKernelContext* context) const { if (num_inputs >= 4) { const auto* W_Zero_Point = context->Input(3); ORT_ENFORCE(IsScalarOr1ElementVector(W_Zero_Point), "Non per-tensor quantization is not supported now."); - filter_offset = *(W_Zero_Point->Data()); + filter_offset = *(W_Zero_Point->Data()); } const int64_t N = X->Shape()[0]; @@ -108,14 +108,18 @@ Status ConvInteger::Compute(OpKernelContext* context) const { false, input_offset); - GemmlowpMultiplyu8u8_s32(W->template Data() + group_id * W_offset, - col_buffer_data, - Ydata + group_id * Y_offset, - filter_offset, - input_offset, - static_cast(M / group_), - static_cast(output_image_size), - static_cast(kernel_dim)); + QGemmu8u8_s32(static_cast(M / group_), + static_cast(output_image_size), + static_cast(kernel_dim), + W->template Data() + group_id * W_offset, + static_cast(kernel_dim), + filter_offset, + col_buffer_data, + static_cast(output_image_size), + input_offset, + Ydata + group_id * Y_offset, + static_cast(output_image_size), + nullptr); } Xdata += X_offset * group_; diff --git a/onnxruntime/core/util/gemmlowp_common.cc b/onnxruntime/core/util/gemmlowp_common.cc index 7754e717a7dc5..7b14a09b510c5 100644 --- a/onnxruntime/core/util/gemmlowp_common.cc +++ b/onnxruntime/core/util/gemmlowp_common.cc @@ -6,7 +6,7 @@ namespace onnxruntime { -Status GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data, +void GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data, const int lhs_offset, const int rhs_offset, const int result_offset, int m, int n, int k, int32_t int_multiplier, int32_t right_shift, const int32_t* bias) { // TODO exp ColMajor order for rhs and result. That may be faster @@ -28,13 +28,11 @@ Status GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, gemmlowp::DefaultL8R8BitDepthParams>( &gemm_context, lhs, rhs, &result, -lhs_offset, -rhs_offset, output_pipeline); } - - return Status::OK(); } -Status GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data, int32_t* result_data, - const int lhs_offset, const int rhs_offset, int m, int n, int k) { - // TODO exp ColMajor order for rhs and result. That may be faster +void GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data, int32_t* result_data, + const int lhs_offset, const int rhs_offset, int m, int n, int k, concurrency::ThreadPool* ) { + const auto matOrder = gemmlowp::MapOrder::RowMajor; gemmlowp::MatrixMap lhs(lhs_data, m, k); gemmlowp::MatrixMap rhs(rhs_data, k, n); @@ -47,6 +45,5 @@ Status GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data gemmlowp::GemmWithOutputPipeline( &gemm_context, lhs, rhs, &result, -lhs_offset, -rhs_offset, empty_pipeline); - return Status::OK(); } } \ No newline at end of file diff --git a/onnxruntime/core/util/gemmlowp_common.h b/onnxruntime/core/util/gemmlowp_common.h index 63fb187bdd9db..1744058d46942 100644 --- a/onnxruntime/core/util/gemmlowp_common.h +++ b/onnxruntime/core/util/gemmlowp_common.h @@ -2,8 +2,7 @@ // Licensed under the MIT License. #pragma once #include "core/util/gemmlowp_common_wrapper.h" -#include "core/util/math.h" - +#include "core/platform/threadpool.h" namespace onnxruntime { @@ -56,11 +55,11 @@ MakeOutputPipelineWithOutBias(std::int32_t result_offset, return std::make_tuple(quantize_down_stage, saturating_cast_stage); } -Status GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data, +void GemmlowpMultiplyu8u8_u8(const uint8_t* lhs_data, const uint8_t* rhs_data, uint8_t* result_data, const int lhs_offset, const int rhs_offset, const int result_offset, int m, int n, int k, int32_t int_multiplier, int32_t right_shift, const int32_t* bias = nullptr); -Status GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data, int32_t* result_data, - const int lhs_offset, const int rhs_offset, int m, int n, int k); +void GemmlowpMultiplyu8u8_s32(const uint8_t* lhs_data, const uint8_t* rhs_data, int32_t* result_data, + const int lhs_offset, const int rhs_offset, int m, int n, int k, concurrency::ThreadPool*); } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/util/qmath.cc b/onnxruntime/core/util/qmath.cc new file mode 100644 index 0000000000000..9372ad29d89ff --- /dev/null +++ b/onnxruntime/core/util/qmath.cc @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/util/qmath.h" +#include "core/common/common.h" + +namespace onnxruntime { + +void QGemmu8u8_s32( + int M, + int N, + int K, + const uint8_t* lhs_data, + int lda, + const uint8_t lhs_offset, + const uint8_t* rhs_data, + int ldb, + const uint8_t rhs_offset, + int32_t* result_data, + int ldc, + concurrency::ThreadPool* thread_pool) { +#ifdef USE_GEMMLOWP + + ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For gemmlowp only RowMajor*RowMajor=RowMajor format is supported"); + + GemmlowpMultiplyu8u8_s32(lhs_data, rhs_data, result_data, lhs_offset, rhs_offset, M, N, K, thread_pool); + +#else + MlasQgemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool); + +#endif +} +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h new file mode 100644 index 0000000000000..3cec9047f3abe --- /dev/null +++ b/onnxruntime/core/util/qmath.h @@ -0,0 +1,29 @@ +#pragma once +//#define USE_GEMMLOWP + +#ifdef USE_GEMMLOWP +#include "core/util/gemmlowp_common.h" +#else +#include "core/mlas/inc/mlas.h" +#endif +#include "core/platform/threadpool.h" +#include +#include + +namespace onnxruntime { + +void QGemmu8u8_s32( + int M, + int N, + int K, + const uint8_t* lhs_data, + int lda, + const uint8_t lhs_offset, + const uint8_t* rhs_data, + int ldb, + const uint8_t rhs_offset, + int32_t* result_data, + int ldc, + concurrency::ThreadPool* thread_pool); + +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc index c197a7511b967..d84ecb8562b6c 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc @@ -13,7 +13,7 @@ namespace test { TEST(MatmulIntegerOpTest, MatMulInteger1) { OpTester test("MatMulInteger", 10); - test.AddInput("T1", {4, 3}, {11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0}); + test.AddInput("T1", {4, 3}, {11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0}); test.AddInput("T2", {3, 2}, {1, 4, 2, 5, 3, 6}); test.AddInput("a_zero_point", {}, {12}); test.AddInput("b_zero_point", {}, {0}); @@ -30,5 +30,24 @@ TEST(MatmulIntegerOpTest, MatMulInteger) { test.AddOutput("T3", {1, 1}, {-1}); test.Run(); } +TEST(MatmulIntegerOpTest, MatMulInteger12) { + OpTester test("MatMulInteger", 10); + test.AddInput("T1", {4, 3}, {11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0}); + test.AddInput("T2", {3, 2}, {1, 4, 2, 5, 3, 6}); + test.AddInput("a_zero_point", {}, {0}); + test.AddInput("b_zero_point", {}, {0}); + test.AddOutput("T3", {4, 2}, {34, 97, 28, 82, 22, 67, 16, 52}); + test.Run(); +} + +TEST(MatmulIntegerOpTest, MatMulInteger13) { + OpTester test("MatMulInteger", 10); + test.AddInput("T1", {1, 1}, {11}); + test.AddInput("T2", {1, 1}, {13}); + test.AddInput("a_zero_point", {}, {0}); + test.AddInput("b_zero_point", {}, {0}); + test.AddOutput("T3", {1, 1}, {143}); + test.Run(); +} } // namespace test } // namespace onnxruntime \ No newline at end of file diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 0b25afcfac495..7278e6e5b14a8 100755 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -127,6 +127,7 @@ def parse_arguments(): parser.add_argument("--use_openblas", action='store_true', help="Build with OpenBLAS.") parser.add_argument("--use_mkldnn", action='store_true', help="Build with MKLDNN.") parser.add_argument("--use_mklml", action='store_true', help="Build with MKLML.") + parser.add_argument("--use_gemmlowp", action='store_true', help="Build with gemmlowp for quantized gemm.") parser.add_argument("--use_automl", action='store_true', help="Build with AutoML support.") parser.add_argument("--use_ngraph", action='store_true', help="Build with nGraph.") parser.add_argument("--use_openvino", nargs="?", const="CPU_FP32", @@ -334,6 +335,7 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home "-Donnxruntime_USE_OPENBLAS=" + ("ON" if args.use_openblas else "OFF"), "-Donnxruntime_USE_MKLDNN=" + ("ON" if args.use_mkldnn else "OFF"), "-Donnxruntime_USE_MKLML=" + ("ON" if args.use_mklml else "OFF"), + "-Donnxruntime_USE_GEMMLOWP=" + ("ON" if args.use_gemmlowp else "OFF"), "-Donnxruntime_USE_NGRAPH=" + ("ON" if args.use_ngraph else "OFF"), "-Donnxruntime_USE_OPENVINO=" + ("ON" if args.use_openvino else "OFF"), "-Donnxruntime_USE_OPENVINO_BINARY=" + ("ON" if args.use_openvino else "OFF"), From fd3fe5167448a2f22e794cf2b99d000b350abec9 Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Mon, 26 Aug 2019 13:21:12 -0700 Subject: [PATCH 2/2] update test --- .../test/providers/cpu/math/matmul_integer_test.cc | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc index d84ecb8562b6c..4a128d4570f23 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc @@ -11,7 +11,7 @@ namespace onnxruntime { namespace test { -TEST(MatmulIntegerOpTest, MatMulInteger1) { +TEST(MatmulIntegerOpTest, MatMulInteger_2D) { OpTester test("MatMulInteger", 10); test.AddInput("T1", {4, 3}, {11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0}); test.AddInput("T2", {3, 2}, {1, 4, 2, 5, 3, 6}); @@ -30,7 +30,7 @@ TEST(MatmulIntegerOpTest, MatMulInteger) { test.AddOutput("T3", {1, 1}, {-1}); test.Run(); } -TEST(MatmulIntegerOpTest, MatMulInteger12) { +TEST(MatmulIntegerOpTest, MatMulInteger_WithZero_ZeroPoint) { OpTester test("MatMulInteger", 10); test.AddInput("T1", {4, 3}, {11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0}); test.AddInput("T2", {3, 2}, {1, 4, 2, 5, 3, 6}); @@ -40,14 +40,5 @@ TEST(MatmulIntegerOpTest, MatMulInteger12) { test.Run(); } -TEST(MatmulIntegerOpTest, MatMulInteger13) { - OpTester test("MatMulInteger", 10); - test.AddInput("T1", {1, 1}, {11}); - test.AddInput("T2", {1, 1}, {13}); - test.AddInput("a_zero_point", {}, {0}); - test.AddInput("b_zero_point", {}, {0}); - test.AddOutput("T3", {1, 1}, {143}); - test.Run(); -} } // namespace test } // namespace onnxruntime \ No newline at end of file