From 0b41f9779ebb108f9939644ffa4dc2fcaf573f71 Mon Sep 17 00:00:00 2001
From: Zhigao Tong <tongzhigao@pingcap.com>
Date: Tue, 12 Jul 2022 16:10:11 +0800
Subject: [PATCH] Optimize comparison for collation `UTF8_BIN` and
 `UTF8MB4_BIN` (#5299) (#5354)

ref pingcap/tiflash#5294
---
 dbms/src/Columns/ColumnConst.h                |   3 +-
 .../Functions/CollationOperatorOptimized.h    | 210 ++++++++++++++++++
 dbms/src/Functions/FunctionsComparison.h      |  54 ++++-
 dbms/src/Storages/Transaction/Collator.cpp    |  36 ++-
 .../tidb-ci/new_collation_fullstack/expr.test |  14 ++
 5 files changed, 287 insertions(+), 30 deletions(-)
 create mode 100644 dbms/src/Functions/CollationOperatorOptimized.h

diff --git a/dbms/src/Columns/ColumnConst.h b/dbms/src/Columns/ColumnConst.h
index 27283c0f24a..da071507a72 100644
--- a/dbms/src/Columns/ColumnConst.h
+++ b/dbms/src/Columns/ColumnConst.h
@@ -233,7 +233,8 @@ class ColumnConst final : public COWPtrHelper<IColumn, ColumnConst>
     template <typename T>
     T getValue() const
     {
-        return getField().safeGet<typename NearestFieldType<T>::Type>();
+        auto && tmp = getField(); // NOTE(review): presumably binds the Field produced by getField() so its stored value can be moved out below instead of copied — confirm getField() returns by value
+        return std::move(tmp.safeGet<typename NearestFieldType<T>::Type>());
     }
 };
 
diff --git a/dbms/src/Functions/CollationOperatorOptimized.h b/dbms/src/Functions/CollationOperatorOptimized.h
new file mode 100644
index 00000000000..395ecc5b9eb
--- /dev/null
+++ b/dbms/src/Functions/CollationOperatorOptimized.h
@@ -0,0 +1,210 @@
+// Copyright 2022 PingCAP, Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <Columns/ColumnString.h>
+#include <Core/AccurateComparison.h>
+#include <Functions/StringUtil.h>
+#include <common/StringRef.h>
+#include <common/defines.h>
+
+#include <cstddef>
+#include <string_view>
+
+
+namespace DB
+{
+
+template <typename T>
+ALWAYS_INLINE inline int signum(T val)
+{
+    return (0 < val) - (val < 0); // +1 for positive val, -1 for negative val, otherwise 0
+}
+
+// Equality-only comparison: checking equality is much faster than a full three-way comparison.
+// - delegates to `StringRef` equality (defined elsewhere; expected to check size first)
+// - return 0 if equal else 1, so callers can pass the result to `Op::apply(result, 0)`
+__attribute__((flatten, always_inline, pure)) inline uint8_t RawStrEqualCompare(const std::string_view & lhs, const std::string_view & rhs)
+{
+    return StringRef(lhs) == StringRef(rhs) ? 0 : 1;
+}
+
+// Three-way compare two str views via `std::string_view::compare`; the result is normalized by `signum` to -1, 0 or 1.
+__attribute__((flatten, always_inline, pure)) inline int RawStrCompare(const std::string_view & v1, const std::string_view & v2)
+{
+    return signum(v1.compare(v2));
+}
+
+constexpr char SPACE = ' '; // the only character RightTrim treats as trailing padding
+
+// Remove tail space: returns a view into the same buffer as `v`, shortened to exclude trailing SPACE characters.
+__attribute__((flatten, always_inline, pure)) inline std::string_view RightTrim(const std::string_view & v)
+{
+    if (likely(v.empty() || v.back() != SPACE)) // fast path: nothing to trim
+        return v;
+    size_t end = v.find_last_not_of(SPACE); // index of the last non-space char, or npos when `v` is all spaces
+    return end == std::string_view::npos ? std::string_view{} : std::string_view(v.data(), end + 1);
+}
+
+__attribute__((flatten, always_inline, pure)) inline int RtrimStrCompare(const std::string_view & va, const std::string_view & vb)
+{
+    return RawStrCompare(RightTrim(va), RightTrim(vb)); // three-way compare (-1/0/1) that ignores trailing spaces on both sides
+}
+
+// Type trait on the comparison operator: if `value` is true, only need to check equal or not.
+template <typename T>
+struct IsEqualRelated
+{
+    static constexpr const bool value = false; // default: the operator needs a full three-way result
+};
+
+// For `EqualsOp` and `NotEqualsOp`, value is true.
+template <typename... A>
+struct IsEqualRelated<DB::EqualsOp<A...>>
+{
+    static constexpr const bool value = true;
+};
+template <typename... A>
+struct IsEqualRelated<DB::NotEqualsOp<A...>>
+{
+    static constexpr const bool value = true;
+};
+
+// Loop over the first `size` rows of two string columns and invoke `func(a_view, b_view, i)` for each pair of row values.
+template <typename F>
+__attribute__((flatten, always_inline)) inline void LoopTwoColumns(
+    const ColumnString::Chars_t & a_data,
+    const ColumnString::Offsets & a_offsets,
+    const ColumnString::Chars_t & b_data,
+    const ColumnString::Offsets & b_offsets,
+    size_t size,
+    F && func)
+{
+    for (size_t i = 0; i < size; ++i)
+    {
+        size_t a_size = StringUtil::sizeAt(a_offsets, i) - 1; // NOTE(review): the -1 presumably drops a per-row terminator byte counted by sizeAt — confirm in StringUtil
+        size_t b_size = StringUtil::sizeAt(b_offsets, i) - 1;
+        const auto * a_ptr = reinterpret_cast<const char *>(&a_data[StringUtil::offsetAt(a_offsets, i)]);
+        const auto * b_ptr = reinterpret_cast<const char *>(&b_data[StringUtil::offsetAt(b_offsets, i)]);
+
+        func({a_ptr, a_size}, {b_ptr, b_size}, i); // the braced pairs are built from pointer + length only
+    }
+}
+
+// Loop over the first `size` rows of one string column and invoke `func(view, i)` for each row value.
+template <typename F>
+__attribute__((flatten, always_inline)) inline void LoopOneColumn(
+    const ColumnString::Chars_t & a_data,
+    const ColumnString::Offsets & a_offsets,
+    size_t size,
+    F && func)
+{
+    for (size_t i = 0; i < size; ++i)
+    {
+        size_t a_size = StringUtil::sizeAt(a_offsets, i) - 1; // NOTE(review): the -1 presumably drops a per-row terminator byte counted by sizeAt — confirm in StringUtil
+        const auto * a_ptr = reinterpret_cast<const char *>(&a_data[StringUtil::offsetAt(a_offsets, i)]);
+
+        func({a_ptr, a_size}, i); // the braced pair is built from pointer + length only
+    }
+}
+
+// Handle str-column compare str-column; returns true iff `c` was filled here (otherwise the caller must fall back).
+// - Optimize UTF8_BIN and UTF8MB4_BIN
+//   - Right trim both values of every row, so trailing spaces are ignored by the comparison
+//   - If Op is `EqualsOp` or `NotEqualsOp`, optimize comparison by faster way
+template <typename Op, typename Result>
+ALWAYS_INLINE inline bool StringVectorStringVector(
+    const ColumnString::Chars_t & a_data,
+    const ColumnString::Offsets & a_offsets,
+    const ColumnString::Chars_t & b_data,
+    const ColumnString::Offsets & b_offsets,
+    const TiDB::TiDBCollatorPtr & collator,
+    Result & c)
+{
+    bool use_optimized_path = false;
+
+    switch (collator->getCollatorId()) // `collator` is dereferenced unconditionally, so it must be non-null
+    {
+    case TiDB::ITiDBCollator::UTF8MB4_BIN:
+    case TiDB::ITiDBCollator::UTF8_BIN:
+    {
+        size_t size = a_offsets.size(); // row count is taken from column `a`; NOTE(review): assumes `b_offsets` and `c` hold at least as many rows
+
+        LoopTwoColumns(a_data, a_offsets, b_data, b_offsets, size, [&c](const std::string_view & va, const std::string_view & vb, size_t i) {
+            if constexpr (IsEqualRelated<Op>::value)
+            {
+                c[i] = Op::apply(RawStrEqualCompare(RightTrim(va), RightTrim(vb)), 0); // equality-only path: 0 means equal, 1 means different
+            }
+            else
+            {
+                c[i] = Op::apply(RtrimStrCompare(va, vb), 0); // ordering path: -1/0/1 compared against 0
+            }
+        });
+
+        use_optimized_path = true;
+
+        break;
+    }
+    default: // any other collation is not handled here
+        break;
+    }
+    return use_optimized_path;
+}
+
+// Handle str-column compare const-str; returns true iff `c` was filled here (otherwise the caller must fall back).
+// - Optimize UTF8_BIN and UTF8MB4_BIN
+//   - Right trim const-str first
+//   - Right trim every column value, so trailing spaces are ignored by the comparison
+//   - If Op is `EqualsOp` or `NotEqualsOp`, optimize comparison by faster way
+template <typename Op, typename Result>
+ALWAYS_INLINE inline bool StringVectorConstant(
+    const ColumnString::Chars_t & a_data,
+    const ColumnString::Offsets & a_offsets,
+    const std::string_view & b,
+    const TiDB::TiDBCollatorPtr & collator,
+    Result & c)
+{
+    bool use_optimized_path = false;
+
+    switch (collator->getCollatorId()) // `collator` is dereferenced unconditionally, so it must be non-null
+    {
+    case TiDB::ITiDBCollator::UTF8MB4_BIN:
+    case TiDB::ITiDBCollator::UTF8_BIN:
+    {
+        size_t size = a_offsets.size(); // row count of the column; NOTE(review): assumes `c` holds at least as many rows
+
+        std::string_view tar_str_view = RightTrim(b); // right trim const-str first (done once, outside the loop)
+
+        LoopOneColumn(a_data, a_offsets, size, [&c, &tar_str_view](const std::string_view & view, size_t i) {
+            if constexpr (IsEqualRelated<Op>::value)
+            {
+                c[i] = Op::apply(RawStrEqualCompare(RightTrim(view), tar_str_view), 0); // equality-only path: 0 means equal, 1 means different
+            }
+            else
+            {
+                c[i] = Op::apply(RawStrCompare(RightTrim(view), tar_str_view), 0); // ordering path: -1/0/1 compared against 0
+            }
+        });
+
+        use_optimized_path = true;
+        break;
+    }
+    default: // any other collation is not handled here
+        break;
+    }
+    return use_optimized_path;
+}
+
+} // namespace DB
diff --git a/dbms/src/Functions/FunctionsComparison.h b/dbms/src/Functions/FunctionsComparison.h
index 1c63a286452..8f7502fba85 100644
--- a/dbms/src/Functions/FunctionsComparison.h
+++ b/dbms/src/Functions/FunctionsComparison.h
@@ -33,6 +33,7 @@
 #include <DataTypes/DataTypeString.h>
 #include <DataTypes/DataTypeTuple.h>
 #include <DataTypes/DataTypesNumber.h>
+#include <Functions/CollationOperatorOptimized.h>
 #include <Functions/FunctionHelpers.h>
 #include <Functions/FunctionsLogical.h>
 #include <Functions/IFunction.h>
@@ -301,6 +302,12 @@ struct StringComparisonWithCollatorImpl
         const TiDB::TiDBCollatorPtr & collator,
         PaddedPODArray<ResultType> & c)
     {
+        bool optimized_path = StringVectorStringVector<Op>(a_data, a_offsets, b_data, b_offsets, collator, c);
+        if (optimized_path)
+        {
+            return;
+        }
+
         size_t size = a_offsets.size();
 
         for (size_t i = 0; i < size; ++i)
@@ -317,10 +324,17 @@ struct StringComparisonWithCollatorImpl
     static void NO_INLINE stringVectorConstant(
         const ColumnString::Chars_t & a_data,
         const ColumnString::Offsets & a_offsets,
-        const std::string & b,
+        const std::string_view & b,
         const TiDB::TiDBCollatorPtr & collator,
         PaddedPODArray<ResultType> & c)
     {
+        bool optimized_path = StringVectorConstant<Op>(a_data, a_offsets, b, collator, c);
+
+        if (optimized_path)
+        {
+            return;
+        }
+
         size_t size = a_offsets.size();
         ColumnString::Offset b_size = b.size();
         const char * b_data = reinterpret_cast<const char *>(b.data());
@@ -332,7 +346,7 @@ struct StringComparisonWithCollatorImpl
     }
 
     static void constantStringVector(
-        const std::string & a,
+        const std::string_view & a,
         const ColumnString::Chars_t & b_data,
         const ColumnString::Offsets & b_offsets,
         const TiDB::TiDBCollatorPtr & collator,
@@ -342,8 +356,8 @@ struct StringComparisonWithCollatorImpl
     }
 
     static void constantConstant(
-        const std::string & a,
-        const std::string & b,
+        const std::string_view & a,
+        const std::string_view & b,
         const TiDB::TiDBCollatorPtr & collator,
         ResultType & c)
     {
@@ -706,6 +720,25 @@ class FunctionComparison : public IFunction
         }
     }
 
+    static inline std::string_view genConstStrRef(const ColumnConst * c0_const) // returns a view of the constant's row 0, or an empty view when `c0_const` is null
+    {
+        std::string_view c0_const_str_ref{};
+        if (c0_const)
+        {
+            if (const auto * c0_const_string = checkAndGetColumn<ColumnString>(&c0_const->getDataColumn()); c0_const_string)
+            {
+                c0_const_str_ref = std::string_view(c0_const_string->getDataAt(0)); // NOTE(review): presumably a non-owning view into the column's storage — must not outlive the column; confirm
+            }
+            else if (const auto * c0_const_fixed_string = checkAndGetColumn<ColumnFixedString>(&c0_const->getDataColumn()); c0_const_fixed_string)
+            {
+                c0_const_str_ref = std::string_view(c0_const_fixed_string->getDataAt(0));
+            }
+            else // the nested column is neither ColumnString nor ColumnFixedString
+                throw Exception("Logical error: ColumnConst contains not String nor FixedString column", ErrorCodes::ILLEGAL_COLUMN);
+        }
+        return c0_const_str_ref;
+    }
+
     template <typename ResultColumnType>
     bool executeStringWithCollator(
         Block & block,
@@ -720,10 +753,13 @@ class FunctionComparison : public IFunction
         using ResultType = typename ResultColumnType::value_type;
         using StringImpl = StringComparisonWithCollatorImpl<Op<int, int>, ResultType>;
 
+        std::string_view c0_const_str_ref = genConstStrRef(c0_const);
+        std::string_view c1_const_str_ref = genConstStrRef(c1_const);
+
         if (c0_const && c1_const)
         {
             ResultType res = 0;
-            StringImpl::constantConstant(c0_const->getValue<String>(), c1_const->getValue<String>(), collator, res);
+            StringImpl::constantConstant(c0_const_str_ref, c1_const_str_ref, collator, res);
             block.getByPosition(result).column = block.getByPosition(result).type->createColumnConst(c0_const->size(), toField(res));
             return true;
         }
@@ -745,12 +781,12 @@ class FunctionComparison : public IFunction
                 StringImpl::stringVectorConstant(
                     c0_string->getChars(),
                     c0_string->getOffsets(),
-                    c1_const->getValue<String>(),
+                    c1_const_str_ref,
                     collator,
                     c_res->getData());
             else if (c0_const && c1_string)
                 StringImpl::constantStringVector(
-                    c0_const->getValue<String>(),
+                    c0_const_str_ref,
                     c1_string->getChars(),
                     c1_string->getOffsets(),
                     collator,
@@ -770,8 +806,8 @@ class FunctionComparison : public IFunction
     template <typename ReturnColumnType = ColumnUInt8>
     bool executeString(Block & block, size_t result, const IColumn * c0, const IColumn * c1) const
     {
-        const ColumnString * c0_string = checkAndGetColumn<ColumnString>(c0);
-        const ColumnString * c1_string = checkAndGetColumn<ColumnString>(c1);
+        const auto * c0_string = checkAndGetColumn<ColumnString>(c0);
+        const auto * c1_string = checkAndGetColumn<ColumnString>(c1);
         const ColumnConst * c0_const = checkAndGetColumnConstStringOrFixedString(c0);
         const ColumnConst * c1_const = checkAndGetColumnConstStringOrFixedString(c1);
 
diff --git a/dbms/src/Storages/Transaction/Collator.cpp b/dbms/src/Storages/Transaction/Collator.cpp
index a9b4d0784be..1b0221a6829 100644
--- a/dbms/src/Storages/Transaction/Collator.cpp
+++ b/dbms/src/Storages/Transaction/Collator.cpp
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include <Common/Exception.h>
+#include <Functions/CollationOperatorOptimized.h>
 #include <Poco/String.h>
 #include <Storages/Transaction/Collator.h>
 
@@ -29,17 +30,10 @@ TiDBCollators dummy_collators;
 std::vector<std::string> dummy_sort_key_contaners;
 std::string dummy_sort_key_contaner;
 
-std::string_view rtrim(const char * s, size_t length)
+ALWAYS_INLINE std::string_view rtrim(const char * s, size_t length)
 {
     auto v = std::string_view(s, length);
-    size_t end = v.find_last_not_of(' ');
-    return end == std::string_view::npos ? "" : v.substr(0, end + 1);
-}
-
-template <typename T>
-int signum(T val)
-{
-    return (0 < val) - (val < 0);
+    return DB::RightTrim(v);
 }
 
 using Rune = int32_t;
@@ -183,26 +177,26 @@ class Pattern : public ITiDBCollator::IPattern
 };
 
 template <typename T, bool padding = false>
-class BinCollator : public ITiDBCollator
+class BinCollator final : public ITiDBCollator
 {
 public:
     explicit BinCollator(int32_t id)
         : ITiDBCollator(id)
     {}
+
     int compare(const char * s1, size_t length1, const char * s2, size_t length2) const override
     {
         if constexpr (padding)
-            return signum(rtrim(s1, length1).compare(rtrim(s2, length2)));
+            return DB::RtrimStrCompare({s1, length1}, {s2, length2});
         else
-            return signum(std::string_view(s1, length1).compare(std::string_view(s2, length2)));
+            return DB::RawStrCompare({s1, length1}, {s2, length2});
     }
 
     StringRef sortKey(const char * s, size_t length, std::string &) const override
     {
         if constexpr (padding)
         {
-            auto v = rtrim(s, length);
-            return StringRef(v.data(), v.length());
+            return StringRef(rtrim(s, length));
         }
         else
         {
@@ -249,7 +243,7 @@ using WeightType = uint16_t;
 extern const std::array<WeightType, 256 * 256> weight_lut;
 } // namespace GeneralCI
 
-class GeneralCICollator : public ITiDBCollator
+class GeneralCICollator final : public ITiDBCollator
 {
 public:
     explicit GeneralCICollator(int32_t id)
@@ -270,7 +264,7 @@ class GeneralCICollator : public ITiDBCollator
             auto sk2 = weight(c2);
             auto cmp = sk1 - sk2;
             if (cmp != 0)
-                return signum(cmp);
+                return DB::signum(cmp);
         }
 
         return (offset1 < v1.length()) - (offset2 < v2.length());
@@ -365,7 +359,7 @@ const std::array<long_weight, 23> weight_lut_long = {
 
 } // namespace UnicodeCI
 
-class UnicodeCICollator : public ITiDBCollator
+class UnicodeCICollator final : public ITiDBCollator
 {
 public:
     explicit UnicodeCICollator(int32_t id)
@@ -420,7 +414,7 @@ class UnicodeCICollator : public ITiDBCollator
                 }
                 else
                 {
-                    return signum(static_cast<int>(s1_first & 0xFFFF) - static_cast<int>(s2_first & 0xFFFF));
+                    return DB::signum(static_cast<int>(s1_first & 0xFFFF) - static_cast<int>(s2_first & 0xFFFF));
                 }
             }
         }
@@ -593,6 +587,8 @@ class UnicodeCICollator : public ITiDBCollator
     friend class Pattern<UnicodeCICollator>;
 };
 
+using UTF8MB4_BIN_TYPE = BinCollator<Rune, true>;
+
 TiDBCollatorPtr ITiDBCollator::getCollator(int32_t id)
 {
     switch (id)
@@ -607,10 +603,10 @@ TiDBCollatorPtr ITiDBCollator::getCollator(int32_t id)
         static const auto latin1_collator = BinCollator<char, true>(LATIN1_BIN);
         return &latin1_collator;
     case ITiDBCollator::UTF8MB4_BIN:
-        static const auto utf8mb4_collator = BinCollator<Rune, true>(UTF8MB4_BIN);
+        static const auto utf8mb4_collator = UTF8MB4_BIN_TYPE(UTF8MB4_BIN);
         return &utf8mb4_collator;
     case ITiDBCollator::UTF8_BIN:
-        static const auto utf8_collator = BinCollator<Rune, true>(UTF8_BIN);
+        static const auto utf8_collator = UTF8MB4_BIN_TYPE(UTF8_BIN);
         return &utf8_collator;
     case ITiDBCollator::UTF8_GENERAL_CI:
         static const auto utf8_general_ci_collator = GeneralCICollator(UTF8_GENERAL_CI);
diff --git a/tests/tidb-ci/new_collation_fullstack/expr.test b/tests/tidb-ci/new_collation_fullstack/expr.test
index 15ada0f335c..1e2135c4f2d 100644
--- a/tests/tidb-ci/new_collation_fullstack/expr.test
+++ b/tests/tidb-ci/new_collation_fullstack/expr.test
@@ -35,6 +35,13 @@ mysql> set session tidb_isolation_read_engines='tiflash'; select /*+ read_from_s
 |    2 | abc   |
 +------+-------+
 
+mysql> set session tidb_isolation_read_engines='tiflash'; select /*+ read_from_storage(tiflash[t]) */ id, value1 from test.t where value1 = 'abc       ';
++------+-------+
+| id   | value1|
++------+-------+
+|    1 | abc   |
+|    2 | abc   |
++------+-------+
 
 mysql> set session tidb_isolation_read_engines='tiflash'; select /*+ read_from_storage(tiflash[t]) */ id, value from test.t where value like 'aB%';
 +------+-------+
@@ -62,6 +69,13 @@ mysql> set session tidb_isolation_read_engines='tiflash'; select /*+ read_from_s
 |    3 | def   |
 +------+-------+
 
+mysql> set session tidb_isolation_read_engines='tiflash'; select /*+ read_from_storage(tiflash[t]) */ id, value1 from test.t where value1 = 'def       ';
++------+-------+
+| id   | value1|
++------+-------+
+|    3 | def   |
++------+-------+
+
 mysql> set session tidb_isolation_read_engines='tiflash'; select /*+ read_from_storage(tiflash[t]) */ id, value1 from test.t where value1 in ('Abc','def');
 +------+-------+
 | id   | value1|