From 0d0463a4a72b1571d95f06270d937348aca01b84 Mon Sep 17 00:00:00 2001 From: JaySon Date: Wed, 21 Aug 2024 20:09:12 +0800 Subject: [PATCH] *: Vector Data types and Functions (#9341) ref pingcap/tiflash#9032 *: Vector Data types and Functions Support parsing vector data type written by TiDB Support basic functions for vector data type: CastVectorAsText, VecDims, VecL1Distance, VecL2Distance, VecL2Norm, VecCosineDistance, VecNegativeInnerProduct Signed-off-by: Lloyd-Pottiger Co-authored-by: Lloyd-Pottiger <60744015+Lloyd-Pottiger@users.noreply.github.com> Co-authored-by: JaySon-Huang --- dbms/src/Columns/ColumnArray.cpp | 71 ++- dbms/src/Columns/ColumnArray.h | 9 + dbms/src/Columns/ColumnNullable.cpp | 29 +- dbms/src/Columns/ColumnNullable.h | 4 +- dbms/src/Columns/IColumn.h | 5 + dbms/src/DataTypes/DataTypeArray.h | 4 +- dbms/src/Debug/MockExecutor/AstToPB.cpp | 7 + dbms/src/Debug/dbgTools.cpp | 4 + dbms/src/Flash/Coprocessor/ArrowColCodec.cpp | 86 ++++ dbms/src/Flash/Coprocessor/DAGCodec.cpp | 11 + dbms/src/Flash/Coprocessor/DAGCodec.h | 2 + .../Coprocessor/DAGExpressionAnalyzer.cpp | 6 + dbms/src/Flash/Coprocessor/DAGUtils.cpp | 62 +++ dbms/src/Flash/Coprocessor/DAGUtils.h | 1 + dbms/src/Flash/Coprocessor/TiDBChunk.cpp | 2 +- dbms/src/Flash/Coprocessor/TiDBColumn.cpp | 16 + dbms/src/Flash/Coprocessor/TiDBColumn.h | 8 +- .../tests/gtest_streaming_writer.cpp | 4 +- dbms/src/Functions/FunctionHelpers.h | 11 + dbms/src/Functions/FunctionsVector.cpp | 44 ++ dbms/src/Functions/FunctionsVector.h | 472 ++++++++++++++++++ dbms/src/Functions/registerFunctions.cpp | 2 + dbms/src/Functions/tests/gtest_vector.cpp | 361 ++++++++++++++ .../DeltaMerge/FilterParser/FilterParser.cpp | 1 + dbms/src/TiDB/Decode/DatumCodec.cpp | 48 ++ dbms/src/TiDB/Decode/DatumCodec.h | 4 + dbms/src/TiDB/Decode/RowCodec.cpp | 3 + dbms/src/TiDB/Decode/TypeMapping.cpp | 28 +- dbms/src/TiDB/Decode/Vector.cpp | 184 +++++++ dbms/src/TiDB/Decode/Vector.h | 68 +++ dbms/src/TiDB/Schema/TiDB.cpp | 3 + dbms/src/TiDB/Schema/TiDBTypes.h | 31 +- 32 files changed, 1562 insertions(+), 29 deletions(-) create mode 100644 dbms/src/Functions/FunctionsVector.cpp create mode 100644 dbms/src/Functions/FunctionsVector.h create mode 100644 dbms/src/Functions/tests/gtest_vector.cpp create mode 100644 dbms/src/TiDB/Decode/Vector.cpp create mode 100644 dbms/src/TiDB/Decode/Vector.h diff --git a/dbms/src/Columns/ColumnArray.cpp b/dbms/src/Columns/ColumnArray.cpp index 63bdc4df1df..6c5d57e3006 100644 --- a/dbms/src/Columns/ColumnArray.cpp +++ b/dbms/src/Columns/ColumnArray.cpp @@ -25,8 +25,13 @@ #include #include #include +#include +#include +#include #include // memcpy +#include + namespace DB { namespace ErrorCodes @@ -798,10 +803,44 @@ void ColumnArray::getPermutation(bool reverse, size_t limit, int nan_direction_h } } -ColumnPtr ColumnArray::replicateRange(size_t /*start_row*/, size_t /*end_row*/, const IColumn::Offsets & /*offsets*/) +ColumnPtr ColumnArray::replicateRange(size_t start_row, size_t end_row, const IColumn::Offsets & replicate_offsets) const { - throw Exception("not implement.", ErrorCodes::NOT_IMPLEMENTED); + size_t col_size = size(); + if (col_size != replicate_offsets.size()) + throw Exception("Size of offsets doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH); + + // We only support replicate to full column. + RUNTIME_CHECK(start_row == 0, start_row); + RUNTIME_CHECK(end_row == replicate_offsets.size(), end_row, replicate_offsets.size()); + + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateConst(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNullable(replicate_offsets); + return replicateGeneric(replicate_offsets); } @@ -1048,4 +1087,32 @@ void ColumnArray::gather(ColumnGathererStream & gatherer) gatherer.gather(*this); } +bool ColumnArray::decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t length, bool /* force_decode */) +{ + RUNTIME_CHECK(raw_value.size() >= cursor + length); + insertFromDatumData(raw_value.c_str() + cursor, length); + return true; +} + +void ColumnArray::insertFromDatumData(const char * data, size_t length) +{ + RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little); + + RUNTIME_CHECK(checkAndGetColumn>(&getData())); + RUNTIME_CHECK(getData().isFixedAndContiguous()); + + RUNTIME_CHECK(length >= sizeof(UInt32), length); + auto n = readLittleEndian(data); + data += sizeof(UInt32); + + auto precise_data_size = n * sizeof(Float32); + RUNTIME_CHECK(length >= sizeof(UInt32) + precise_data_size, n, length); + insertData(data, precise_data_size); +} + +std::pair ColumnArray::getElementRef(size_t element_idx) const +{ + return {static_cast(sizeAt(element_idx)), getDataAt(element_idx)}; +} + } // namespace DB diff --git a/dbms/src/Columns/ColumnArray.h b/dbms/src/Columns/ColumnArray.h index 564637f3595..852c15f6ada 100644 --- a/dbms/src/Columns/ColumnArray.h +++ b/dbms/src/Columns/ColumnArray.h @@ -167,6 +167,15 @@ class ColumnArray final : public COWPtrHelper callback(data); } + bool canBeInsideNullable() const override { return true; } + + bool decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t /* length */, bool /* force_decode */) + override; + + void insertFromDatumData(const char * data, size_t length) override; + + std::pair getElementRef(size_t element_idx) const; + private: ColumnPtr data; ColumnPtr offsets; diff --git a/dbms/src/Columns/ColumnNullable.cpp b/dbms/src/Columns/ColumnNullable.cpp index ecdfd60d57e..a18ff7bcaf1 100644 --- a/dbms/src/Columns/ColumnNullable.cpp +++ b/dbms/src/Columns/ColumnNullable.cpp @@ -204,14 +204,29 @@ void ColumnNullable::get(size_t n, Field & res) const getNestedColumn().get(n, res); } -StringRef ColumnNullable::getDataAt(size_t /*n*/) const +StringRef ColumnNullable::getDataAt(size_t n) const { - throw Exception(fmt::format("Method getDataAt is not supported for {}", getName()), ErrorCodes::NOT_IMPLEMENTED); + if (likely(!isNullAt(n))) + return getNestedColumn().getDataAt(n); + + throw Exception( + ErrorCodes::NOT_IMPLEMENTED, + "Method getDataAt is not supported for {} in case if value is NULL", + getName()); } -void ColumnNullable::insertData(const char * /*pos*/, size_t /*length*/) +void ColumnNullable::insertData(const char * pos, size_t length) { - throw Exception(fmt::format("Method insertData is not supported for {}", getName()), ErrorCodes::NOT_IMPLEMENTED); + if (pos == nullptr) + { + getNestedColumn().insertDefault(); + getNullMapData().push_back(1); + } + else + { + getNestedColumn().insertData(pos, length); + getNullMapData().push_back(0); + } } bool ColumnNullable::decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t length, bool force_decode) @@ -222,6 +237,12 @@ bool ColumnNullable::decodeTiDBRowV2Datum(size_t cursor, const String & raw_valu return true; } +void ColumnNullable::insertFromDatumData(const char * cursor, size_t len) +{ + getNestedColumn().insertFromDatumData(cursor, len); + getNullMapData().push_back(0); +} + StringRef ColumnNullable::serializeValueIntoArena( size_t n, Arena & arena, diff --git a/dbms/src/Columns/ColumnNullable.h b/dbms/src/Columns/ColumnNullable.h index 0622ce78a0c..f06b8c30b9b 100644 --- a/dbms/src/Columns/ColumnNullable.h +++ b/dbms/src/Columns/ColumnNullable.h @@ -65,9 +65,11 @@ class ColumnNullable final : public COWPtrHelper Field operator[](size_t n) const override; void get(size_t n, Field & res) const override; UInt64 get64(size_t n) const override { return nested_column->get64(n); } - StringRef getDataAt(size_t n) const override; + StringRef getDataAt(size_t) const override; + /// Will insert null value if pos=nullptr void insertData(const char * pos, size_t length) override; bool decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t length, bool force_decode) override; + void insertFromDatumData(const char *, size_t) override; StringRef serializeValueIntoArena( size_t n, Arena & arena, diff --git a/dbms/src/Columns/IColumn.h b/dbms/src/Columns/IColumn.h index 6e5db8abf2c..06a906fdb63 100644 --- a/dbms/src/Columns/IColumn.h +++ b/dbms/src/Columns/IColumn.h @@ -173,6 +173,11 @@ class IColumn : public COWPtr throw Exception("Method decodeTiDBRowV2Datum is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } + virtual void insertFromDatumData(const char *, size_t) + { + throw Exception("Method insertFromDatumData is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + } + /// Like getData, but has special behavior for columns that contain variable-length strings. /// In this special case inserting data should be zero-ending (i.e. length is 1 byte greater than real string size). virtual void insertDataWithTerminatingZero(const char * pos, size_t length) { insertData(pos, length); } diff --git a/dbms/src/DataTypes/DataTypeArray.h b/dbms/src/DataTypes/DataTypeArray.h index 4c7572fdc74..bd71e9233f3 100644 --- a/dbms/src/DataTypes/DataTypeArray.h +++ b/dbms/src/DataTypes/DataTypeArray.h @@ -34,7 +34,7 @@ class DataTypeArray final : public IDataType const char * getFamilyName() const override { return "Array"; } - bool canBeInsideNullable() const override { return false; } + bool canBeInsideNullable() const override { return true; } TypeIndex getTypeId() const override { return TypeIndex::Array; } @@ -98,7 +98,7 @@ class DataTypeArray final : public IDataType bool haveSubtypes() const override { return true; } bool cannotBeStoredInTables() const override { return nested->cannotBeStoredInTables(); } bool textCanContainOnlyValidUTF8() const override { return nested->textCanContainOnlyValidUTF8(); } - bool isComparable() const override { return nested->isComparable(); }; + bool isComparable() const override { return nested->isComparable(); } bool canBeComparedWithCollation() const override { return nested->canBeComparedWithCollation(); } bool isValueUnambiguouslyRepresentedInContiguousMemoryRegion() const override diff --git a/dbms/src/Debug/MockExecutor/AstToPB.cpp b/dbms/src/Debug/MockExecutor/AstToPB.cpp index 4756e046473..5a2a77a719c 100644 --- a/dbms/src/Debug/MockExecutor/AstToPB.cpp +++ b/dbms/src/Debug/MockExecutor/AstToPB.cpp @@ -110,6 +110,13 @@ void literalFieldToTiPBExpr(const TiDB::ColumnInfo & ci, const Field & val_field encodeDAGInt64(val, ss); break; } + case TiDB::TypeTiDBVectorFloat32: + { + expr->set_tp(tipb::ExprType::TiDBVectorFloat32); + const auto & val = val_field.safeGet(); + encodeDAGVectorFloat32(val, ss); + break; + } default: throw Exception(fmt::format( "Type {} does not support literal in function unit test", diff --git a/dbms/src/Debug/dbgTools.cpp b/dbms/src/Debug/dbgTools.cpp index 82d8b6e6801..f896ab6346d 100644 --- a/dbms/src/Debug/dbgTools.cpp +++ b/dbms/src/Debug/dbgTools.cpp @@ -469,6 +469,10 @@ struct BatchCtrl throw Exception( "Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagJson", ErrorCodes::LOGICAL_ERROR); + case TiDB::CodecFlagVectorFloat32: + throw Exception( + "Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagVectorFloat32", + ErrorCodes::LOGICAL_ERROR); case TiDB::CodecFlagMax: throw Exception("Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagMax", ErrorCodes::LOGICAL_ERROR); case TiDB::CodecFlagDuration: diff --git a/dbms/src/Flash/Coprocessor/ArrowColCodec.cpp b/dbms/src/Flash/Coprocessor/ArrowColCodec.cpp index 1496a9861c5..bd8f408923a 100644 --- a/dbms/src/Flash/Coprocessor/ArrowColCodec.cpp +++ b/dbms/src/Flash/Coprocessor/ArrowColCodec.cpp @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include #include #include +#include #include #include #include @@ -296,6 +298,37 @@ void flashStringColToArrowCol( } } +template +void flashArrayFloat32ColToArrowCol( + TiDBColumn & dag_column, + const IColumn * flash_col_untyped, + size_t start_index, + size_t end_index) +{ + // We only unwrap the NULLABLE() part. + const IColumn * nested_col = getNestedCol(flash_col_untyped); + const auto * flash_col = checkAndGetColumn(nested_col); + + RUNTIME_CHECK(checkAndGetColumn>(&flash_col->getData())); + RUNTIME_CHECK(flash_col->getData().isFixedAndContiguous()); + + for (size_t i = start_index; i < end_index; i++) + { + // todo check if we can convert flash_col to DAG col directly since the internal representation is almost the same + if constexpr (is_nullable) + { + if (flash_col_untyped->isNullAt(i)) + { + dag_column.appendNull(); + continue; + } + } + + auto [num_elems, elem_bytes] = flash_col->getElementRef(i); + dag_column.appendVectorF32(num_elems, elem_bytes); + } +} + template void flashBitColToArrowCol( TiDBColumn & dag_column, @@ -465,6 +498,20 @@ void flashColToArrowCol( else flashStringColToArrowCol(dag_column, col, start_index, end_index); break; + case TiDB::TypeTiDBVectorFloat32: + { + const auto * data_type = checkAndGetDataType(type); + if (!data_type || data_type->getNestedType()->getTypeId() != TypeIndex::Float32) + throw TiFlashException( + Errors::Coprocessor::Internal, + "Type un-matched during arrow encode, target col type is array and source column type is {}", + type->getName()); + if (tidb_column_info.hasNotNullFlag()) + flashArrayFloat32ColToArrowCol(dag_column, col, start_index, end_index); + else + flashArrayFloat32ColToArrowCol(dag_column, col, start_index, end_index); + break; + } case TiDB::TypeBit: if (!checkDataType(type)) throw TiFlashException( @@ -529,6 +576,35 @@ const char * arrowStringColToFlashCol( return pos + offsets[length]; } +const char * arrowArrayFloat32ColToFlashCol( + const char * pos, + UInt8, + UInt32 null_count, + const std::vector & null_bitmap, + const std::vector & offsets, + const ColumnWithTypeAndName & col, + const ColumnInfo &, + UInt32 length) +{ + const auto * data_type = checkAndGetDataType(&*col.type); + if (!data_type || data_type->getNestedType()->getTypeId() != TypeIndex::Float32) + throw TiFlashException( + Errors::Coprocessor::Internal, + "Type un-matched during arrow decode, target col type is array and source column type is {}", + col.type->getName()); + + for (UInt32 i = 0; i < length; i++) + { + if (checkNull(i, null_count, null_bitmap, col)) + continue; + + auto arrow_data_size = offsets[i + 1] - offsets[i]; + const auto * base_offset = pos + offsets[i]; + col.column->assumeMutable()->insertFromDatumData(base_offset, arrow_data_size); + } + return pos + offsets[length]; +} + const char * arrowEnumColToFlashCol( const char * pos, UInt8, @@ -823,6 +899,16 @@ const char * arrowColToFlashCol( length); case TiDB::TypeBit: return arrowBitColToFlashCol(pos, field_length, null_count, null_bitmap, offsets, flash_col, col_info, length); + case TiDB::TypeTiDBVectorFloat32: + return arrowArrayFloat32ColToFlashCol( + pos, + field_length, + null_count, + null_bitmap, + offsets, + flash_col, + col_info, + length); case TiDB::TypeEnum: return arrowEnumColToFlashCol(pos, field_length, null_count, null_bitmap, offsets, flash_col, col_info, length); default: diff --git a/dbms/src/Flash/Coprocessor/DAGCodec.cpp b/dbms/src/Flash/Coprocessor/DAGCodec.cpp index ef8dc4d7c2e..2b3e7ce10ec 100644 --- a/dbms/src/Flash/Coprocessor/DAGCodec.cpp +++ b/dbms/src/Flash/Coprocessor/DAGCodec.cpp @@ -53,6 +53,11 @@ void encodeDAGDecimal(const Field & field, WriteBuffer & ss) EncodeDecimal(field, ss); } +void encodeDAGVectorFloat32(const Array & v, WriteBuffer & ss) +{ + EncodeVectorFloat32(v, ss); +} + Int64 decodeDAGInt64(const String & s) { auto u = *(reinterpret_cast(s.data())); @@ -93,4 +98,10 @@ Field decodeDAGDecimal(const String & s) return DecodeDecimal(cursor, s); } +Field decodeDAGVectorFloat32(const String & s) +{ + size_t cursor = 0; + return DecodeVectorFloat32(cursor, s); +} + } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/DAGCodec.h b/dbms/src/Flash/Coprocessor/DAGCodec.h index 66cf4b83eda..e0fb33b703c 100644 --- a/dbms/src/Flash/Coprocessor/DAGCodec.h +++ b/dbms/src/Flash/Coprocessor/DAGCodec.h @@ -26,6 +26,7 @@ void encodeDAGFloat64(Float64, WriteBuffer &); void encodeDAGString(const String &, WriteBuffer &); void encodeDAGBytes(const String &, WriteBuffer &); void encodeDAGDecimal(const Field &, WriteBuffer &); +void encodeDAGVectorFloat32(const Array &, WriteBuffer &); Int64 decodeDAGInt64(const String &); UInt64 decodeDAGUInt64(const String &); @@ -34,5 +35,6 @@ Float64 decodeDAGFloat64(const String &); String decodeDAGString(const String &); String decodeDAGBytes(const String &); Field decodeDAGDecimal(const String &); +Field decodeDAGVectorFloat32(const String &); } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index 6420be8a69c..095d2fc8813 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -1050,6 +1050,12 @@ String DAGExpressionAnalyzer::convertToUInt8(const ExpressionActionsPtr & action auto const_expr_name = getActions(const_expr, actions); return applyFunction("notEquals", {column_name, const_expr_name}, actions, nullptr); } + else if (checkDataTypeArray(org_type.get())) + { + tipb::Expr const_expr = constructZeroVectorFloat32TiExpr(); + auto const_expr_name = getActions(const_expr, actions); + return applyFunction("notEquals", {column_name, const_expr_name}, actions, nullptr); + } throw TiFlashException( fmt::format("Filter on {} is not supported.", org_type->getName()), Errors::Coprocessor::Unimplemented); diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index 0c6001658c0..399ed870c7e 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -21,8 +21,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -132,6 +134,9 @@ const std::unordered_map scalar_func_map({ //{tipb::ScalarFuncSig::CastJsonAsDuration, "cast"}, {tipb::ScalarFuncSig::CastJsonAsJson, "cast_json_as_json"}, + {tipb::ScalarFuncSig::CastVectorFloat32AsString, "cast_vector_float32_as_string"}, + {tipb::ScalarFuncSig::CastVectorFloat32AsVectorFloat32, "cast_vector_float32_as_vector_float32"}, + {tipb::ScalarFuncSig::CoalesceInt, "coalesce"}, {tipb::ScalarFuncSig::CoalesceReal, "coalesce"}, {tipb::ScalarFuncSig::CoalesceString, "coalesce"}, @@ -147,6 +152,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::LTTime, "less"}, {tipb::ScalarFuncSig::LTDuration, "less"}, {tipb::ScalarFuncSig::LTJson, "less"}, + {tipb::ScalarFuncSig::LTVectorFloat32, "less"}, {tipb::ScalarFuncSig::LEInt, "lessOrEquals"}, {tipb::ScalarFuncSig::LEReal, "lessOrEquals"}, @@ -155,6 +161,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::LETime, "lessOrEquals"}, {tipb::ScalarFuncSig::LEDuration, "lessOrEquals"}, {tipb::ScalarFuncSig::LEJson, "lessOrEquals"}, + {tipb::ScalarFuncSig::LEVectorFloat32, "lessOrEquals"}, {tipb::ScalarFuncSig::GTInt, "greater"}, {tipb::ScalarFuncSig::GTReal, "greater"}, @@ -163,6 +170,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::GTTime, "greater"}, {tipb::ScalarFuncSig::GTDuration, "greater"}, {tipb::ScalarFuncSig::GTJson, "greater"}, + {tipb::ScalarFuncSig::GTVectorFloat32, "greater"}, {tipb::ScalarFuncSig::GreatestInt, "tidbGreatest"}, {tipb::ScalarFuncSig::GreatestReal, "tidbGreatest"}, @@ -186,6 +194,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::GETime, "greaterOrEquals"}, {tipb::ScalarFuncSig::GEDuration, "greaterOrEquals"}, {tipb::ScalarFuncSig::GEJson, "greaterOrEquals"}, + {tipb::ScalarFuncSig::GEVectorFloat32, "greaterOrEquals"}, {tipb::ScalarFuncSig::EQInt, "equals"}, {tipb::ScalarFuncSig::EQReal, "equals"}, @@ -194,6 +203,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::EQTime, "equals"}, {tipb::ScalarFuncSig::EQDuration, "equals"}, {tipb::ScalarFuncSig::EQJson, "equals"}, + {tipb::ScalarFuncSig::EQVectorFloat32, "equals"}, {tipb::ScalarFuncSig::NEInt, "notEquals"}, {tipb::ScalarFuncSig::NEReal, "notEquals"}, @@ -202,6 +212,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::NETime, "notEquals"}, {tipb::ScalarFuncSig::NEDuration, "notEquals"}, {tipb::ScalarFuncSig::NEJson, "notEquals"}, + {tipb::ScalarFuncSig::NEVectorFloat32, "notEquals"}, //{tipb::ScalarFuncSig::NullEQInt, "cast"}, //{tipb::ScalarFuncSig::NullEQReal, "cast"}, @@ -319,6 +330,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::TimeIsNull, "isNull"}, {tipb::ScalarFuncSig::IntIsNull, "isNull"}, {tipb::ScalarFuncSig::JsonIsNull, "isNull"}, + {tipb::ScalarFuncSig::VectorFloat32IsNull, "isNull"}, {tipb::ScalarFuncSig::BitAndSig, "bitAnd"}, {tipb::ScalarFuncSig::BitOrSig, "bitOr"}, @@ -689,6 +701,14 @@ const std::unordered_map scalar_func_map({ //{tipb::ScalarFuncSig::CharLength, "upper"}, {tipb::ScalarFuncSig::GroupingSig, "grouping"}, + + {tipb::ScalarFuncSig::VecAsTextSig, "vecAsText"}, + {tipb::ScalarFuncSig::VecDimsSig, "vecDims"}, + {tipb::ScalarFuncSig::VecL1DistanceSig, "vecL1Distance"}, + {tipb::ScalarFuncSig::VecL2DistanceSig, "vecL2Distance"}, + {tipb::ScalarFuncSig::VecNegativeInnerProductSig, "vecNegativeInnerProduct"}, + {tipb::ScalarFuncSig::VecCosineDistanceSig, "vecCosineDistance"}, + {tipb::ScalarFuncSig::VecL2NormSig, "vecL2Norm"}, }); template @@ -951,6 +971,24 @@ String exprToString(const tipb::Expr & expr, const std::vector = std::to_string(TiDB::DatumFlat(t, static_cast(expr.field_type().tp())).field().get()); return ret; } + case tipb::ExprType::TiDBVectorFloat32: + { + if (!expr.has_field_type()) + throw TiFlashException( + "MySQL Duration literal without field_type" + expr.DebugString(), + Errors::Coprocessor::BadRequest); + auto t = decodeDAGVectorFloat32(expr.val()); + auto arr = t.safeGet(); + String ret = "["; + for (size_t i = 0; i < arr.size(); ++i) + { + if (i > 0) + ret += ","; + ret += std::to_string(arr[i].safeGet::Type>()); + } + ret += "]"; + return ret; + } case tipb::ExprType::ColumnRef: return getColumnNameForColumnExpr(expr, input_col); case tipb::ExprType::Count: @@ -1088,6 +1126,7 @@ bool isLiteralExpr(const tipb::Expr & expr) case tipb::ExprType::MysqlTime: case tipb::ExprType::MysqlJson: case tipb::ExprType::ValueList: + case tipb::ExprType::TiDBVectorFloat32: return true; default: return false; @@ -1137,6 +1176,14 @@ Field decodeLiteral(const tipb::Expr & expr) auto t = decodeDAGInt64(expr.val()); return TiDB::DatumFlat(t, static_cast(expr.field_type().tp())).field(); } + case tipb::ExprType::TiDBVectorFloat32: + { + if (!expr.has_field_type()) + throw TiFlashException( + "MySQL Duration literal without field_type" + expr.DebugString(), + Errors::Coprocessor::BadRequest); + return decodeDAGVectorFloat32(expr.val()); + } case tipb::ExprType::MysqlBit: case tipb::ExprType::MysqlEnum: case tipb::ExprType::MysqlHex: @@ -1327,6 +1374,7 @@ UInt8 getFieldLengthForArrowEncode(Int32 tp) case TiDB::TypeBit: case TiDB::TypeEnum: case TiDB::TypeJSON: + case TiDB::TypeTiDBVectorFloat32: return VAR_SIZE; default: throw TiFlashException( @@ -1381,6 +1429,20 @@ tipb::Expr constructNULLLiteralTiExpr() return expr; } +tipb::Expr constructZeroVectorFloat32TiExpr() +{ + RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little); + tipb::Expr expr; + expr.set_tp(tipb::ExprType::TiDBVectorFloat32); + WriteBufferFromOwnString ss; + writeIntBinary(static_cast(0), ss); + expr.set_val(ss.releaseStr()); + auto * field_type = expr.mutable_field_type(); + field_type->set_tp(TiDB::TypeTiDBVectorFloat32); + field_type->set_flag(TiDB::ColumnFlagNotNull); + return expr; +} + TiDB::TiDBCollatorPtr getCollatorFromExpr(const tipb::Expr & expr) { if (expr.has_field_type()) diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.h b/dbms/src/Flash/Coprocessor/DAGUtils.h index e26f9f99481..249dae8c504 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.h +++ b/dbms/src/Flash/Coprocessor/DAGUtils.h @@ -60,6 +60,7 @@ tipb::Expr constructStringLiteralTiExpr(const String & value); tipb::Expr constructInt64LiteralTiExpr(Int64 value); tipb::Expr constructDateTimeLiteralTiExpr(UInt64 packed_value); tipb::Expr constructNULLLiteralTiExpr(); +tipb::Expr constructZeroVectorFloat32TiExpr(); DataTypePtr inferDataType4Literal(const tipb::Expr & expr); SortDescription getSortDescription( const std::vector & order_columns, diff --git a/dbms/src/Flash/Coprocessor/TiDBChunk.cpp b/dbms/src/Flash/Coprocessor/TiDBChunk.cpp index f3a59350bdb..30d7aaea537 100644 --- a/dbms/src/Flash/Coprocessor/TiDBChunk.cpp +++ b/dbms/src/Flash/Coprocessor/TiDBChunk.cpp @@ -31,7 +31,7 @@ namespace DB { TiDBChunk::TiDBChunk(const std::vector & field_types) { - for (auto & type : field_types) + for (const auto & type : field_types) { columns.emplace_back(getFieldLengthForArrowEncode(type.tp())); } diff --git a/dbms/src/Flash/Coprocessor/TiDBColumn.cpp b/dbms/src/Flash/Coprocessor/TiDBColumn.cpp index 373e2a7320a..169a264b1ba 100644 --- a/dbms/src/Flash/Coprocessor/TiDBColumn.cpp +++ b/dbms/src/Flash/Coprocessor/TiDBColumn.cpp @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include +#include #include +#include namespace DB { @@ -118,6 +121,19 @@ void TiDBColumn::append(const TiDBEnum & ti_enum) finishAppendVar(size); } +void TiDBColumn::appendVectorF32(UInt32 num_elem, StringRef elem_bytes) +{ + encodeLittleEndian(num_elem, *data); + size_t encoded_size = sizeof(UInt32); + + RUNTIME_CHECK(elem_bytes.size == num_elem * sizeof(Float32)); + data->write(elem_bytes.data, elem_bytes.size); + encoded_size += elem_bytes.size; + + RUNTIME_CHECK(encoded_size > 0); + finishAppendVar(encoded_size); +} + void TiDBColumn::append(const TiDBBit & bit) { data->write(bit.val.data, bit.val.size); diff --git a/dbms/src/Flash/Coprocessor/TiDBColumn.h b/dbms/src/Flash/Coprocessor/TiDBColumn.h index 98bc56bad26..534c09ecbbc 100644 --- a/dbms/src/Flash/Coprocessor/TiDBColumn.h +++ b/dbms/src/Flash/Coprocessor/TiDBColumn.h @@ -41,6 +41,7 @@ class TiDBColumn void append(const TiDBDecimal & decimal); void append(const TiDBBit & bit); void append(const TiDBEnum & ti_enum); + void appendVectorF32(UInt32 num_elem, StringRef elem_bytes); void encodeColumn(WriteBuffer & ss); void clear(); @@ -48,14 +49,17 @@ class TiDBColumn bool isFixed() const { return fixed_size != VAR_SIZE; } void finishAppendFixed(); void finishAppendVar(UInt32 size); + void appendNullBitMap(bool value); + // WriteBufferFromOwnString is not moveable. + std::unique_ptr data; + UInt32 length; UInt32 null_cnt; std::vector null_bitmap; std::vector var_offsets; - // WriteBufferFromOwnString is not moveable. - std::unique_ptr data; + std::string default_value; UInt64 current_data_size; Int8 fixed_size; diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp index 240eea2aed0..6ce2821d659 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp @@ -41,7 +41,7 @@ class TestStreamingWriter : public testing::Test } public: - TestStreamingWriter() {} + TestStreamingWriter() = default; // Return 10 Int64 column. static std::vector makeFields() @@ -93,7 +93,7 @@ struct MockStreamWriter {} void write(tipb::SelectResponse & response) { checker(response); } - bool isWritable() const { throw Exception("Unsupport async write"); } + static bool isWritable() { throw Exception("Unsupport async write"); } private: MockStreamWriterChecker checker; diff --git a/dbms/src/Functions/FunctionHelpers.h b/dbms/src/Functions/FunctionHelpers.h index 77c5f790a79..7b9b89685de 100644 --- a/dbms/src/Functions/FunctionHelpers.h +++ b/dbms/src/Functions/FunctionHelpers.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -42,6 +43,16 @@ bool checkDataType(const IDataType * data_type) return checkAndGetDataType(data_type); } +template +bool checkDataTypeArray(const IDataType * data_type) +{ + const auto * array_type = checkAndGetDataType(data_type); + if unlikely (!array_type) + return false; + + const DataTypePtr & inner_type = array_type->getNestedType(); + return checkDataType(inner_type.get()); +} template const Type * checkAndGetColumn(const IColumn * column) diff --git a/dbms/src/Functions/FunctionsVector.cpp b/dbms/src/Functions/FunctionsVector.cpp new file mode 100644 index 00000000000..795b679cf79 --- /dev/null +++ b/dbms/src/Functions/FunctionsVector.cpp @@ -0,0 +1,44 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace DB +{ + +void registerFunctionsVector(FunctionFactory & factory) +{ + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); +} + +} // namespace DB diff --git a/dbms/src/Functions/FunctionsVector.h b/dbms/src/Functions/FunctionsVector.h new file mode 100644 index 00000000000..2e830338952 --- /dev/null +++ b/dbms/src/Functions/FunctionsVector.h @@ -0,0 +1,472 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ +extern const int ILLEGAL_COLUMN; +} + +class FunctionsCastVectorFloat32AsString : public IFunction +{ +public: + static constexpr auto name = "cast_vector_float32_as_string"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 1; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(arguments[0].get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return std::make_shared(); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_result = ColumnString::create(); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + RUNTIME_CHECK(!col_a->isNullAt(i)); + auto data = col_a->getDataAt(i); + auto v = VectorFloat32Ref(data); + col_result->insert(v.toString()); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsCastVectorFloat32AsVectorFloat32 : public IFunction +{ +public: + static constexpr auto name = "cast_vector_float32_as_vector_float32"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 1; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(arguments[0].get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return std::make_shared(std::make_shared()); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_result = ColumnArray::create(ColumnFloat32::create()); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + RUNTIME_CHECK(!col_a->isNullAt(i)); + auto data = col_a->getDataAt(i); + auto v = VectorFloat32Ref(data); // Still construct a VectorFloat32Ref to do sanity checks + UNUSED(v); + col_result->insertData(data.data, data.size); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecAsText : public IFunction +{ +public: + static constexpr auto name = "vecAsText"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 1; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(arguments[0].get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return std::make_shared(); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_result = ColumnString::create(); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + RUNTIME_CHECK(!col_a->isNullAt(i)); + auto data = col_a->getDataAt(i); + auto v = VectorFloat32Ref(data); + col_result->insert(v.toString()); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecDims : public IFunction +{ +public: + static constexpr auto name = "vecDims"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 1; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(arguments[0].get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return std::make_shared(); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_result = ColumnUInt32::create(); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + RUNTIME_CHECK(!col_a->isNullAt(i)); + auto data = col_a->getDataAt(i); + auto v = VectorFloat32Ref(data); + col_result->insert(v.size()); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecL1Distance : public IFunction +{ +public: + static constexpr auto name = "vecL1Distance"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 2; } + + bool useDefaultImplementationForConstants() const override { return true; } + + // Calculating vectors with different dimensions is disallowed, so that we cannot use the default impl. + bool useDefaultImplementationForNulls() const override { return false; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(removeNullable(arguments[0]).get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if unlikely (!checkDataTypeArray(removeNullable(arguments[1]).get())) + throw Exception( + "Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return makeNullable(std::make_shared()); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_b = block.safeGetByPosition(arguments[1]).column; + auto col_result = ColumnNullable::create(ColumnFloat64::create(), ColumnUInt8::create()); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + if (col_a->isNullAt(i) || col_b->isNullAt(i)) + { + col_result->insertDefault(); + continue; + } + + auto v1 = VectorFloat32Ref(col_a->getDataAt(i)); + auto v2 = VectorFloat32Ref(col_b->getDataAt(i)); + auto d = v1.l1Distance(v2); + if (std::isnan(d)) + col_result->insertDefault(); + else + col_result->insert(d); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecL2Distance : public IFunction +{ +public: + static constexpr auto name = "vecL2Distance"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 2; } + + bool useDefaultImplementationForConstants() const override { return true; } + + // Calculating vectors with different dimensions is disallowed, so that we cannot use the default impl. + bool useDefaultImplementationForNulls() const override { return false; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(removeNullable(arguments[0]).get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if unlikely (!checkDataTypeArray(removeNullable(arguments[1]).get())) + throw Exception( + "Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return makeNullable(std::make_shared()); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_b = block.safeGetByPosition(arguments[1]).column; + auto col_result = ColumnNullable::create(ColumnFloat64::create(), ColumnUInt8::create()); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + if (col_a->isNullAt(i) || col_b->isNullAt(i)) + { + col_result->insertDefault(); + continue; + } + + auto v1 = VectorFloat32Ref(col_a->getDataAt(i)); + auto v2 = VectorFloat32Ref(col_b->getDataAt(i)); + auto d = v1.l2Distance(v2); + + if (std::isnan(d)) + col_result->insertDefault(); + else + col_result->insert(d); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecCosineDistance : public IFunction +{ +public: + static constexpr auto name = "vecCosineDistance"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 2; } + + bool useDefaultImplementationForConstants() const override { return true; } + + // Calculating vectors with different dimensions is disallowed, so that we cannot use the default impl. + bool useDefaultImplementationForNulls() const override { return false; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(removeNullable(arguments[0]).get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if unlikely (!checkDataTypeArray(removeNullable(arguments[1]).get())) + throw Exception( + "Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return makeNullable(std::make_shared()); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_b = block.safeGetByPosition(arguments[1]).column; + auto col_result = ColumnNullable::create(ColumnFloat64::create(), ColumnUInt8::create()); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + if (col_a->isNullAt(i) || col_b->isNullAt(i)) + { + col_result->insertDefault(); + continue; + } + + auto v1 = VectorFloat32Ref(col_a->getDataAt(i)); + auto v2 = VectorFloat32Ref(col_b->getDataAt(i)); + auto d = v1.cosineDistance(v2); + if (std::isnan(d)) + col_result->insertDefault(); + else + col_result->insert(d); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecNegativeInnerProduct : public IFunction +{ +public: + static constexpr auto name = "vecNegativeInnerProduct"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 2; } + + bool useDefaultImplementationForConstants() const override { return true; } + + // Calculating vectors with different dimensions is disallowed, so that we cannot use the default impl. + bool useDefaultImplementationForNulls() const override { return false; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(removeNullable(arguments[0]).get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if unlikely (!checkDataTypeArray(removeNullable(arguments[1]).get())) + throw Exception( + "Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return makeNullable(std::make_shared()); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_b = block.safeGetByPosition(arguments[1]).column; + auto col_result = ColumnNullable::create(ColumnFloat64::create(), ColumnUInt8::create()); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + if (col_a->isNullAt(i) || col_b->isNullAt(i)) + { + col_result->insertDefault(); + continue; + } + + auto v1 = VectorFloat32Ref(col_a->getDataAt(i)); + auto v2 = VectorFloat32Ref(col_b->getDataAt(i)); + auto d = v1.innerProduct(v2) * -1; + if (std::isnan(d)) + col_result->insertDefault(); + else + col_result->insert(d); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecL2Norm : public IFunction +{ +public: + static constexpr auto name = "vecL2Norm"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 1; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(arguments[0].get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return makeNullable(std::make_shared()); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_result = ColumnNullable::create(ColumnFloat64::create(), ColumnUInt8::create()); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + RUNTIME_CHECK(!col_a->isNullAt(i)); + auto v1 = VectorFloat32Ref(col_a->getDataAt(i)); + auto d = v1.l2Norm(); + if (std::isnan(d)) + col_result->insertDefault(); + else + col_result->insert(d); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +} // namespace DB diff --git a/dbms/src/Functions/registerFunctions.cpp b/dbms/src/Functions/registerFunctions.cpp index 6957f887805..cb43ae7760d 100644 --- a/dbms/src/Functions/registerFunctions.cpp +++ b/dbms/src/Functions/registerFunctions.cpp @@ -50,6 +50,7 @@ void registerFunctionsRegexpInstr(FunctionFactory &); void registerFunctionsRegexpSubstr(FunctionFactory &); void registerFunctionsRegexpReplace(FunctionFactory &); void registerFunctionsGrouping(FunctionFactory &); +void registerFunctionsVector(FunctionFactory &); void registerFunctions() { @@ -83,6 +84,7 @@ void registerFunctions() registerFunctionsJson(factory); registerFunctionsIsIPAddr(factory); registerFunctionsGrouping(factory); + registerFunctionsVector(factory); } } // namespace DB diff --git a/dbms/src/Functions/tests/gtest_vector.cpp b/dbms/src/Functions/tests/gtest_vector.cpp new file mode 100644 index 00000000000..d67eb683540 --- /dev/null +++ b/dbms/src/Functions/tests/gtest_vector.cpp @@ -0,0 +1,361 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include + +namespace DB +{ +namespace tests +{ +class Vector : public DB::tests::FunctionTest +{ +}; + +TEST_F(Vector, Dims) +try +{ + // Fn(Column) + ASSERT_COLUMN_EQ( + createColumn>({0, 2, 3, std::nullopt}), + executeFunction( + "vecDims", + createColumn>( + std::make_tuple(std::make_shared()), // + {Array{}, Array{1.0, 2.0}, Array{1.0, 2.0, 3.0}, std::nullopt}))); + + // Fn(Column) + ASSERT_COLUMN_EQ( + createColumn({0, 2, 3}), + executeFunction( + "vecDims", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{}, Array{1.0, 2.0}, Array{1.0, 2.0, 3.0}}))); + + // Fn(Const) + ASSERT_COLUMN_EQ( + createConstColumn(3, 2), + executeFunction( + "vecDims", + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, // + Array{1.0, 2.0}))); +} +CATCH + +TEST_F(Vector, L2Norm) +try +{ + // Fn(Column) + ASSERT_COLUMN_EQ( + createColumn>({0.0, 5.0, 1.0}), + executeFunction( + "vecL2Norm", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{}, Array{3.0, 4.0}, Array{0.0, 1.0}}))); + + // Fn(Column) + ASSERT_COLUMN_EQ( + createColumn>({0.0, 5.0, 1.0, std::nullopt}), + executeFunction( + "vecL2Norm", + createColumn>( + std::make_tuple(std::make_shared()), // + {Array{}, Array{3.0, 4.0}, Array{0.0, 1.0}, std::nullopt}))); + + // Fn(Const) + ASSERT_COLUMN_EQ( + createConstColumn>(3, 5.0), + executeFunction( + "vecL2Norm", + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, // + Array{3.0, 4.0}))); +} +CATCH + +TEST_F(Vector, L2Distance) +try +{ + // Fn(NullableColumn, Column) + ASSERT_COLUMN_EQ( + createColumn>({5.0, 1.0, INFINITY, std::nullopt}), + executeFunction( + "vecL2Distance", + createColumn>( + std::make_tuple(std::make_shared()), // + {Array{0.0, 0.0}, Array{0.0, 0.0}, Array{3e38}, std::nullopt}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0, 4.0}, Array{0.0, 1.0}, Array{-3e38}, Array{1}}))); + + // Fn(Column, Column) + ASSERT_COLUMN_EQ( + createColumn>({5.0, 1.0, INFINITY}), + executeFunction( + "vecL2Distance", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{0.0, 0.0}, Array{0.0, 0.0}, Array{3e38}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0, 4.0}, Array{0.0, 1.0}, Array{-3e38}}))); + + ASSERT_THROW( + executeFunction( + "vecL2Distance", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0}})), + Exception); + + // Fn(Const, Const) + ASSERT_COLUMN_EQ( + createConstColumn>(3, 5.0), + executeFunction( + "vecL2Distance", + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, + Array{0.0, 0.0}), + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, + Array{3.0, 4.0}))); + + // Fn(Const, Column) + ASSERT_COLUMN_EQ( + createColumn>({5.0, 1.0, 1.0}), + executeFunction( + "vecL2Distance", + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, + Array{0.0, 0.0}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0, 4.0}, Array{0.0, 1.0}, Array{0.0, 1.0}}))); +} +CATCH + +TEST_F(Vector, NegativeInnerProduct) +try +{ + ASSERT_COLUMN_EQ( + createColumn>({-11.0, -INFINITY}), + executeFunction( + "vecNegativeInnerProduct", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, Array{3e38}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0, 4.0}, Array{3e38}}))); + + ASSERT_COLUMN_EQ( + createConstColumn>(3, -11.0), + executeFunction( + "vecNegativeInnerProduct", + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, + Array{1.0, 2.0}), + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, + Array{3.0, 4.0}))); + + ASSERT_THROW( + executeFunction( + "vecNegativeInnerProduct", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0}})), + Exception); +} +CATCH + +TEST_F(Vector, CosineDistance) +try +{ + ASSERT_COLUMN_EQ( + createColumn>({0.0, std::nullopt, 0.0, 1.0, 2.0, 0.0, 2.0, std::nullopt}), + executeFunction( + "vecCosineDistance", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, + Array{1.0, 2.0}, + Array{1.0, 1.0}, + Array{1.0, 0.0}, + Array{1.0, 1.0}, + Array{1.0, 1.0}, + Array{1.0, 1.0}, + Array{3e38}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{2.0, 4.0}, + Array{0.0, 0.0}, + Array{1.0, 1.0}, + Array{0.0, 2.0}, + Array{-1.0, -1.0}, + Array{1.1, 1.1}, + Array{-1.1, -1.1}, + Array{3e38}}))); + + ASSERT_THROW( + executeFunction( + "vecCosineDistance", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0}})), + Exception); +} +CATCH + +TEST_F(Vector, L1Distance) +try +{ + ASSERT_COLUMN_EQ( + createColumn>({7.0, 1.0, INFINITY}), + executeFunction( + "vecL1Distance", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{0.0, 0.0}, Array{0.0, 0.0}, Array{3e38}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0, 4.0}, Array{0.0, 1.0}, Array{-3e38}}))); + + ASSERT_THROW( + executeFunction( + "vecL1Distance", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0}})), + Exception); +} +CATCH + +TEST_F(Vector, IsNull) +try +{ + ASSERT_COLUMN_EQ( + createColumn({0, 1}), + executeFunction( + "isNull", + createColumn>( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, std::nullopt}))); +} +CATCH + +TEST_F(Vector, CastAsString) +try +{ + ASSERT_COLUMN_EQ( + createColumn({"[]", "[1,2]"}), + executeFunction( + "cast_vector_float32_as_string", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{}, Array{1.0, 2.0}}))); +} +CATCH + +TEST_F(Vector, CastAsVector) +try +{ + ASSERT_COLUMN_EQ( + createColumn( + std::make_tuple(std::make_shared()), // + {Array{}, Array{1.0, 2.0}}), + executeFunction( + "cast_vector_float32_as_vector_float32", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{}, Array{1.0, 2.0}}))); +} +CATCH + +TEST_F(Vector, Compare) +try +{ + ASSERT_COLUMN_EQ( + createColumn({0, 1, 0, 1, 0, 1}), + executeFunction( + "less", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, Array{1.0, 2.0}, Array{1.0, 2.0}, Array{1.0, 1.0}, Array{1.0, 2.0}, Array{}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, + Array{2.0, 4.0}, + Array{0.0, 1.0}, + Array{1.0, 1.0, 1.0}, + Array{0.0, 2.0, 3.0}, + Array{0.0}}))); + + ASSERT_COLUMN_EQ( + createColumn({0, 0, 1, 0, 1, 0}), + executeFunction( + "greater", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, Array{1.0, 2.0}, Array{1.0, 2.0}, Array{1.0, 1.0}, Array{1.0, 2.0}, Array{}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, + Array{2.0, 4.0}, + Array{0.0, 1.0}, + Array{1.0, 1.0, 1.0}, + Array{0.0, 2.0, 3.0}, + Array{0.0}}))); + + // equals + ASSERT_COLUMN_EQ( + createColumn({1, 0, 1, 0}), + executeFunction( + "equals", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, Array{1.0, 2.0}, Array{}, Array{}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, Array{2.0, 4.0}, Array{}, Array{1.0, 1.0, 1.0}}))); +} +CATCH + +} // namespace tests +} // namespace DB diff --git a/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp b/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp index 1d45b9c50f2..514d68524cf 100644 --- a/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp +++ b/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp @@ -65,6 +65,7 @@ inline bool isRoughSetFilterSupportType(const Int32 field_type) case TiDB::TypeString: return false; // Unknown. + case TiDB::TypeTiDBVectorFloat32: case TiDB::TypeDecimal: case TiDB::TypeNewDecimal: case TiDB::TypeFloat: diff --git a/dbms/src/TiDB/Decode/DatumCodec.cpp b/dbms/src/TiDB/Decode/DatumCodec.cpp index 5b7e0b55ce0..2242791e754 100644 --- a/dbms/src/TiDB/Decode/DatumCodec.cpp +++ b/dbms/src/TiDB/Decode/DatumCodec.cpp @@ -13,10 +13,13 @@ // limitations under the License. #include +#include #include +#include #include #include #include +#include #include namespace DB @@ -342,6 +345,44 @@ Field DecodeDatumForCHRow(size_t & cursor, const String & raw_value, const TiDB: } } +void EncodeVectorFloat32(const Array & val, WriteBuffer & ss) +{ + RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little); + + writeIntBinary(static_cast(val.size()), ss); + for (const auto & s : val) + writeFloatBinary(static_cast(s.safeGet::Type>()), ss); +} + +void SkipVectorFloat32(size_t & cursor, const String & raw_value) +{ + RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little); + + auto elements_n = readLittleEndian(&raw_value[cursor]); + auto size = sizeof(elements_n) + elements_n * sizeof(Float32); + cursor += size; +} + +Field DecodeVectorFloat32(size_t & cursor, const String & raw_value) +{ + RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little); + + auto n = readLittleEndian(&raw_value[cursor]); + cursor += sizeof(UInt32); + + Array res; + res.reserve(n); + + for (size_t i = 0; i < n; i++) + { + auto v = readLittleEndian(&raw_value[cursor]); + res.emplace_back(static_cast(v)); + cursor += sizeof(Float32); + } + + return res; +} + Field DecodeDatum(size_t & cursor, const String & raw_value) { switch (raw_value[cursor++]) @@ -368,6 +409,8 @@ Field DecodeDatum(size_t & cursor, const String & raw_value) return DecodeDecimal(cursor, raw_value); case TiDB::CodecFlagJson: return JsonBinary::DecodeJsonAsBinary(cursor, raw_value); + case TiDB::CodecFlagVectorFloat32: + return DecodeVectorFloat32(cursor, raw_value); default: throw Exception("Unknown Type:" + std::to_string(raw_value[cursor - 1]), ErrorCodes::LOGICAL_ERROR); } @@ -409,6 +452,9 @@ void SkipDatum(size_t & cursor, const String & raw_value) case TiDB::CodecFlagJson: JsonBinary::SkipJson(cursor, raw_value); return; + case TiDB::CodecFlagVectorFloat32: + SkipVectorFloat32(cursor, raw_value); + return; default: throw Exception("Unknown Type:" + std::to_string(raw_value[cursor - 1]), ErrorCodes::LOGICAL_ERROR); } @@ -666,6 +712,8 @@ void EncodeDatum(const Field & field, TiDB::CodecFlag flag, WriteBuffer & ss) return EncodeInt64(field.safeGet(), ss); case TiDB::CodecFlagJson: return EncodeJSON(field.safeGet(), ss); + case TiDB::CodecFlagVectorFloat32: + return EncodeVectorFloat32(field.safeGet(), ss); case TiDB::CodecFlagNil: return; default: diff --git a/dbms/src/TiDB/Decode/DatumCodec.h b/dbms/src/TiDB/Decode/DatumCodec.h index f86add34f65..055c9dd34e2 100644 --- a/dbms/src/TiDB/Decode/DatumCodec.h +++ b/dbms/src/TiDB/Decode/DatumCodec.h @@ -59,6 +59,8 @@ UInt64 DecodeVarUInt(size_t & cursor, const StringRef & raw_value); Int64 DecodeVarInt(size_t & cursor, const String & raw_value); +Field DecodeVectorFloat32(size_t & cursor, const String & raw_value); + Field DecodeDecimal(size_t & cursor, const String & raw_value); Field DecodeDecimalForCHRow(size_t & cursor, const String & raw_value, const TiDB::ColumnInfo & column_info); @@ -89,6 +91,8 @@ void EncodeCompactBytes(const String & str, WriteBuffer & ss); void EncodeJSON(const String & str, WriteBuffer & ss); +void EncodeVectorFloat32(const Array & val, WriteBuffer & ss); + void EncodeVarUInt(UInt64 num, WriteBuffer & ss); void EncodeVarInt(Int64 num, WriteBuffer & ss); diff --git a/dbms/src/TiDB/Decode/RowCodec.cpp b/dbms/src/TiDB/Decode/RowCodec.cpp index d201ef4dc82..cda807e87a7 100644 --- a/dbms/src/TiDB/Decode/RowCodec.cpp +++ b/dbms/src/TiDB/Decode/RowCodec.cpp @@ -128,6 +128,9 @@ TiKVValue::Base encodeNotNullColumn(const Field & field, const ColumnInfo & colu case TiDB::TypeLongBlob: case TiDB::TypeJSON: return field.safeGet(); + case TiDB::TypeTiDBVectorFloat32: + // unsupported, only used in tests. + throw Exception("unsupported encode TiDBVectorFloat32"); case TiDB::TypeNewDecimal: EncodeDecimalForRow(field, ss, column_info); break; diff --git a/dbms/src/TiDB/Decode/TypeMapping.cpp b/dbms/src/TiDB/Decode/TypeMapping.cpp index 0ea4b5bc4dd..705c75e10a4 100644 --- a/dbms/src/TiDB/Decode/TypeMapping.cpp +++ b/dbms/src/TiDB/Decode/TypeMapping.cpp @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include +#include #include #include #include @@ -34,6 +36,8 @@ #include #include +#include +#include #include namespace DB @@ -106,9 +110,21 @@ struct EnumType : public std::true_type template inline constexpr bool IsEnumType = EnumType::value; +template +struct ArrayType : public std::false_type +{ +}; +template <> +struct ArrayType : public std::true_type +{ +}; +template +inline constexpr bool IsArrayType = ArrayType::value; + template std::enable_if_t< - !IsSignedType && !IsDecimalType && !IsEnumType && !std::is_same_v, + !IsSignedType && !IsDecimalType && !IsEnumType && !std::is_same_v + && !IsArrayType, DataTypePtr> // getDataTypeByColumnInfoBase(const ColumnInfo &, const T *) { @@ -134,6 +150,13 @@ std::enable_if_t, DataTypePtr> getDataTypeByColumnInfoBase(cons return createDecimal(column_info.flen, column_info.decimal); } +template +std::enable_if_t, DataTypePtr> getDataTypeByColumnInfoBase(const ColumnInfo & column_info, const T *) +{ + RUNTIME_CHECK(column_info.tp == TiDB::TypeTiDBVectorFloat32, magic_enum::enum_name(column_info.tp)); + const auto nested_type = std::make_shared(); + return std::make_shared(nested_type); +} template std::enable_if_t, DataTypePtr> // @@ -427,6 +450,9 @@ ColumnInfo reverseGetColumnInfo(const NameAndTypePair & column, ColumnID id, con case TypeIndex::Enum16: column_info.tp = TiDB::TypeEnum; break; + case TypeIndex::Array: + column_info.tp = TiDB::TypeTiDBVectorFloat32; + break; default: throw DB::Exception( "Unable reverse map TiFlash type " + nested_type->getName() + " to TiDB type", diff --git a/dbms/src/TiDB/Decode/Vector.cpp b/dbms/src/TiDB/Decode/Vector.cpp new file mode 100644 index 00000000000..6a11c5a0737 --- /dev/null +++ b/dbms/src/TiDB/Decode/Vector.cpp @@ -0,0 +1,184 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int BAD_ARGUMENTS; +} // namespace ErrorCodes + +VectorFloat32Ref::VectorFloat32Ref(const Float32 * elements, size_t n) + : elements(elements) + , elements_n(n) +{ + for (size_t i = 0; i < n; ++i) + { + if (unlikely(std::isnan(elements[i]))) + throw Exception("NaN not allowed in vector", ErrorCodes::BAD_ARGUMENTS); + if (unlikely(std::isinf(elements[i]))) + throw Exception("infinite value not allowed in vector", ErrorCodes::BAD_ARGUMENTS); + } +} + +void VectorFloat32Ref::checkDims(VectorFloat32Ref b) const +{ + if (unlikely(size() != b.size())) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "vectors have different dimensions: {} and {}", size(), b.size()); +} + +Float64 VectorFloat32Ref::l2SquaredDistance(VectorFloat32Ref b) const +{ + checkDims(b); + + Float32 distance = 0.0; + Float32 diff; + + for (size_t i = 0, i_max = size(); i < i_max; ++i) + { + // Hope this can be vectorized. + diff = elements[i] - b[i]; + distance += diff * diff; + } + + return distance; +} + +Float64 VectorFloat32Ref::innerProduct(VectorFloat32Ref b) const +{ + checkDims(b); + + Float32 distance = 0.0; + + for (size_t i = 0, i_max = size(); i < i_max; ++i) + { + // Hope this can be vectorized. + distance += elements[i] * b[i]; + } + + return distance; +} + +Float64 VectorFloat32Ref::cosineDistance(VectorFloat32Ref b) const +{ + checkDims(b); + + Float32 distance = 0.0; + Float32 norma = 0.0; + Float32 normb = 0.0; + + for (size_t i = 0, i_max = size(); i < i_max; ++i) + { + // Hope this can be vectorized. + distance += elements[i] * b[i]; + norma += elements[i] * elements[i]; + normb += b[i] * b[i]; + } + + Float64 similarity + = static_cast(distance) / std::sqrt(static_cast(norma) * static_cast(normb)); + + if (std::isnan(similarity)) + { + // When norma or normb is zero, distance is zero, and similarity is NaN. + // similarity can not be Inf in this case. + return std::nan(""); + } + + similarity = std::clamp(similarity, -1.0, 1.0); + return 1.0 - similarity; +} + +Float64 VectorFloat32Ref::l1Distance(VectorFloat32Ref b) const +{ + checkDims(b); + + Float32 distance = 0.0; + + for (size_t i = 0, i_max = size(); i < i_max; ++i) + { + // Hope this can be vectorized. + Float32 diff = std::abs(elements[i] - b[i]); + distance += diff; + } + + return distance; +} + +Float64 VectorFloat32Ref::l2Norm() const +{ + // Note: We align the impl with pgvector: Only l2_norm use double + // precision during calculation. + + Float64 norm = 0.0; + + for (size_t i = 0, i_max = size(); i < i_max; ++i) + { + // Hope this can be vectorized. + norm += static_cast(elements[i]) * static_cast(elements[i]); + } + + return std::sqrt(norm); +} + +std::strong_ordering VectorFloat32Ref::operator<=>(const VectorFloat32Ref & b) const +{ + auto la = size(); + auto lb = b.size(); + auto common_len = std::min(la, lb); + + const auto * va = elements; + const auto * vb = b.elements; + + for (size_t i = 0; i < common_len; i++) + { + if (va[i] < vb[i]) + return std::strong_ordering::less; + else if (va[i] > vb[i]) + return std::strong_ordering::greater; + } + return la <=> lb; +} + +String VectorFloat32Ref::toString() const +{ + WriteBufferFromOwnString write_buffer; + toStringInBuffer(write_buffer); + write_buffer.finalize(); + return write_buffer.releaseStr(); +} + +void VectorFloat32Ref::toStringInBuffer(WriteBuffer & write_buffer) const +{ + write_buffer.write('['); + for (size_t i = 0; i < elements_n; i++) + { + if (i > 0) + { + write_buffer.write(','); + } + writeFloatText(elements[i], write_buffer); + } + write_buffer.write(']'); +} + +} // namespace DB diff --git a/dbms/src/TiDB/Decode/Vector.h b/dbms/src/TiDB/Decode/Vector.h new file mode 100644 index 00000000000..6ed4578fa60 --- /dev/null +++ b/dbms/src/TiDB/Decode/Vector.h @@ -0,0 +1,68 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include + +namespace DB +{ + +class VectorFloat32Ref +{ +public: + explicit VectorFloat32Ref(const Float32 * elements, size_t n); + + explicit VectorFloat32Ref(const StringRef & data) + : VectorFloat32Ref(reinterpret_cast(data.data), data.size / sizeof(Float32)) + {} + + size_t size() const { return elements_n; } + + bool empty() const { return size() == 0; } + + const Float32 & operator[](size_t n) const { return elements[n]; } + + void checkDims(VectorFloat32Ref b) const; + + Float64 l2SquaredDistance(VectorFloat32Ref b) const; + + Float64 l2Distance(VectorFloat32Ref b) const { return std::sqrt(l2SquaredDistance(b)); } + + Float64 innerProduct(VectorFloat32Ref b) const; + + Float64 negativeInnerProduct(VectorFloat32Ref b) const { return innerProduct(b) * -1; } + + Float64 cosineDistance(VectorFloat32Ref b) const; + + Float64 l1Distance(VectorFloat32Ref b) const; + + Float64 l2Norm() const; + + std::strong_ordering operator<=>(const VectorFloat32Ref & b) const; + + String toString() const; + + void toStringInBuffer(WriteBuffer & write_buffer) const; + +private: + const Float32 * elements; + const size_t elements_n; +}; + +} // namespace DB diff --git a/dbms/src/TiDB/Schema/TiDB.cpp b/dbms/src/TiDB/Schema/TiDB.cpp index e294443490b..e6e728721f2 100644 --- a/dbms/src/TiDB/Schema/TiDB.cpp +++ b/dbms/src/TiDB/Schema/TiDB.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -76,6 +77,8 @@ Field GenDefaultField(const TiDB::ColumnInfo & col_info) return Field(static_cast(0)); case TiDB::CodecFlagJson: return TiDB::genJsonNull(); + case TiDB::CodecFlagVectorFloat32: + return Field(Array(0)); case TiDB::CodecFlagDuration: return Field(static_cast(0)); default: diff --git a/dbms/src/TiDB/Schema/TiDBTypes.h b/dbms/src/TiDB/Schema/TiDBTypes.h index a1dfbf27dab..1ada12060fd 100644 --- a/dbms/src/TiDB/Schema/TiDBTypes.h +++ b/dbms/src/TiDB/Schema/TiDBTypes.h @@ -16,10 +16,10 @@ namespace TiDB { - // Column types. // In format: // TiDB type, int value, codec flag, CH type. +// The int value is defined in pingcap/tidb pkg/parser/mysql/type.go #ifdef M #error "Please undefine macro M first." #endif @@ -51,7 +51,8 @@ namespace TiDB M(Blob, 0xfc, CompactBytes, String) \ M(VarString, 0xfd, CompactBytes, String) \ M(String, 0xfe, CompactBytes, String) \ - M(Geometry, 0xff, CompactBytes, String) + M(Geometry, 0xff, CompactBytes, String) \ + M(TiDBVectorFloat32, 0xe1, VectorFloat32, Array) enum TP { @@ -66,21 +67,23 @@ enum TP // Codec flags. // In format: TiDB codec flag, int value. +// Defined in pingcap/tidb pkg/util/codec/codec.go #ifdef M #error "Please undefine macro M first." #endif -#define CODEC_FLAGS(M) \ - M(Nil, 0) \ - M(Bytes, 1) \ - M(CompactBytes, 2) \ - M(Int, 3) \ - M(UInt, 4) \ - M(Float, 5) \ - M(Decimal, 6) \ - M(Duration, 7) \ - M(VarInt, 8) \ - M(VarUInt, 9) \ - M(Json, 10) \ +#define CODEC_FLAGS(M) \ + M(Nil, 0) \ + M(Bytes, 1) \ + M(CompactBytes, 2) \ + M(Int, 3) \ + M(UInt, 4) \ + M(Float, 5) \ + M(Decimal, 6) \ + M(Duration, 7) \ + M(VarInt, 8) \ + M(VarUInt, 9) \ + M(Json, 10) \ + M(VectorFloat32, 20) \ M(Max, 250) enum CodecFlag