Skip to content

Commit

Permalink
*: use SimSIMD for vectors (#9372)
Browse files Browse the repository at this point in the history
ref #9032

*: use SimSIMD for vectors

Signed-off-by: Lloyd-Pottiger <[email protected]>
  • Loading branch information
Lloyd-Pottiger authored Aug 27, 2024
1 parent 3f15689 commit 5f08ae6
Show file tree
Hide file tree
Showing 23 changed files with 404 additions and 73 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,6 @@
[submodule "contrib/usearch"]
path = contrib/usearch
url = https://github.com/unum-cloud/usearch.git
[submodule "contrib/simsimd"]
path = contrib/simsimd
url = https://github.com/ashvardanian/SimSIMD
2 changes: 1 addition & 1 deletion cmake/cpu_features.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ elseif (ARCH_AMD64)
# so we do not set the flags to avoid core dump in old machines
option (TIFLASH_ENABLE_AVX_SUPPORT "Use AVX/AVX2 instructions on x86_64" ON)
option (TIFLASH_ENABLE_AVX512_SUPPORT "Use AVX512 instructions on x86_64" ON)

# `haswell` has been available since 2013 and provides the cpu features avx2 and bmi2. It's a practical target arch for the optimizer
option (TIFLASH_ENABLE_ARCH_HASWELL_SUPPORT "Use instructions based on architecture `haswell` on x86_64" ON)

Expand Down
2 changes: 2 additions & 0 deletions contrib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,5 @@ add_subdirectory(simdjson)
add_subdirectory(fastpforlib)

add_subdirectory(usearch-cmake)

add_subdirectory(simsimd-cmake)
1 change: 1 addition & 0 deletions contrib/simsimd
Submodule simsimd added at 3e2193
13 changes: 13 additions & 0 deletions contrib/simsimd-cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Exposes the SimSIMD headers from the contrib/simsimd submodule as the
# interface target `tiflash_contrib::simsimd`.
set(SIMSIMD_PROJECT_DIR "${TiFlash_SOURCE_DIR}/contrib/simsimd")
set(SIMSIMD_SOURCE_DIR "${SIMSIMD_PROJECT_DIR}/include")

# Abort configuration when the main header is absent, i.e. the submodule
# has not been checked out.
if (NOT EXISTS "${SIMSIMD_SOURCE_DIR}/simsimd/simsimd.h")
    message (FATAL_ERROR "submodule contrib/simsimd not found")
endif ()

# INTERFACE library: nothing is compiled here, the target only carries the
# include path. SYSTEM makes consumers treat it as a system include directory.
add_library(_simsimd INTERFACE)
target_include_directories(_simsimd SYSTEM INTERFACE
    ${SIMSIMD_SOURCE_DIR})

# Namespaced alias used by the targets that link against SimSIMD.
add_library(tiflash_contrib::simsimd ALIAS _simsimd)
2 changes: 1 addition & 1 deletion contrib/usearch-cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ if (NOT EXISTS "${USEARCH_SOURCE_DIR}/usearch/index.hpp")
endif ()

target_include_directories(_usearch SYSTEM INTERFACE
${USEARCH_PROJECT_DIR}/simsimd/include
# ${USEARCH_PROJECT_DIR}/simsimd/include # Use our simsimd
${USEARCH_PROJECT_DIR}/fp16/include
${USEARCH_SOURCE_DIR})

Expand Down
18 changes: 16 additions & 2 deletions dbms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ add_headers_and_sources(dbms src/Client)
add_headers_only(dbms src/Flash/Coprocessor)
add_headers_only(dbms src/Server)

add_headers_and_sources(tiflash_vector_search src/VectorSearch)

check_then_add_sources_compile_flag (
TIFLASH_ENABLE_ARCH_HASWELL_SUPPORT
"${TIFLASH_COMPILER_ARCH_HASWELL_FLAG}"
Expand Down Expand Up @@ -203,13 +205,25 @@ target_link_libraries (tiflash_common_io
)

target_include_directories (tiflash_common_io BEFORE PRIVATE ${kvClient_SOURCE_DIR}/include)
target_compile_definitions(tiflash_common_io PUBLIC -DTIFLASH_SOURCE_PREFIX=\"${TiFlash_SOURCE_DIR}\")
target_compile_definitions (tiflash_common_io PUBLIC -DTIFLASH_SOURCE_PREFIX=\"${TiFlash_SOURCE_DIR}\")

add_library(tiflash_vector_search
${tiflash_vector_search_headers}
${tiflash_vector_search_sources}
)
target_link_libraries(tiflash_vector_search
tiflash_contrib::usearch
tiflash_contrib::simsimd

fmt
)

target_link_libraries (dbms
${OPENSSL_CRYPTO_LIBRARY}
${BTRIE_LIBRARIES}
absl::synchronization
tiflash_contrib::usearch
tiflash_contrib::aws_s3
tiflash_vector_search

etcdpb
tiflash_parsers
Expand Down
13 changes: 13 additions & 0 deletions dbms/src/Common/TiFlashBuildInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include <Common/TiFlashBuildInfo.h>
#include <Common/config.h>
#include <Common/config_version.h>
#include <VectorSearch/DistanceSIMDFeatures.h>
#include <VectorSearch/SIMDFeatures.h>
#include <common/config_common.h>
#include <common/logger_useful.h>
#include <fmt/core.h>
Expand Down Expand Up @@ -140,6 +142,17 @@ String getEnabledFeatures()
"fdo",
#endif
};
{
auto f = DB::DM::VectorIndexHNSWSIMDFeatures::get();
for (const auto & feature : f)
features.push_back(feature);
}
{
auto f = DB::VectorDistanceSIMDFeatures::get();
for (const auto & feature : f)
features.push_back(feature);
}

return fmt::format("{}", fmt::join(features.begin(), features.end(), " "));
}
// clang-format on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class S3LockServiceTest : public DB::base::TiFlashStorageTestBasic
#define CHECK_S3_ENABLED \
if (!is_s3_test_enabled) \
{ \
const auto * t = ::testing::UnitTest::GetInstance()->current_test_info(); \
const auto * t = ::testing::UnitTest::GetInstance() -> current_test_info(); \
LOG_INFO(log, "{}.{} is skipped because S3ClientFactory is not inited.", t->test_case_name(), t->name()); \
return; \
}
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/Functions/FunctionsVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}

class FunctionsCastVectorFloat32AsString : public IFunction
Expand Down
55 changes: 33 additions & 22 deletions dbms/src/Functions/tests/gtest_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,29 +203,40 @@ TEST_F(Vector, CosineDistance)
try
{
ASSERT_COLUMN_EQ(
createColumn<Nullable<Float64>>({0.0, std::nullopt, 0.0, 1.0, 2.0, 0.0, 2.0, std::nullopt}),
createColumn<Nullable<Float64>>(
{0.0,
1.0, // CosDistance to (0,0) cannot be calculated, clamped to 1.0
0.0,
1.0,
2.0,
0.0,
2.0,
std::nullopt}),
executeFunction(
"vecCosineDistance",
createColumn<Array>(
std::make_tuple(std::make_shared<DataTypeFloat32>()), //
{Array{1.0, 2.0},
Array{1.0, 2.0},
Array{1.0, 1.0},
Array{1.0, 0.0},
Array{1.0, 1.0},
Array{1.0, 1.0},
Array{1.0, 1.0},
Array{3e38}}),
createColumn<Array>(
std::make_tuple(std::make_shared<DataTypeFloat32>()), //
{Array{2.0, 4.0},
Array{0.0, 0.0},
Array{1.0, 1.0},
Array{0.0, 2.0},
Array{-1.0, -1.0},
Array{1.1, 1.1},
Array{-1.1, -1.1},
Array{3e38}})));
"tidbRoundWithFrac",
executeFunction(
"vecCosineDistance",
createColumn<Array>(
std::make_tuple(std::make_shared<DataTypeFloat32>()), //
{Array{1.0, 2.0},
Array{1.0, 2.0},
Array{1.0, 1.0},
Array{1.0, 0.0},
Array{1.0, 1.0},
Array{1.0, 1.0},
Array{1.0, 1.0},
Array{3e38}}),
createColumn<Array>(
std::make_tuple(std::make_shared<DataTypeFloat32>()), //
{Array{2.0, 4.0},
Array{0.0, 0.0},
Array{1.0, 1.0},
Array{0.0, 2.0},
Array{-1.0, -1.0},
Array{1.1, 1.1},
Array{-1.1, -1.1},
Array{3e38}})),
createConstColumn<int>(8, 1)));

ASSERT_THROW(
executeFunction(
Expand Down
12 changes: 5 additions & 7 deletions dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@

#include <algorithm>
#include <ext/scope_guard.h>
#include <usearch/index.hpp>
#include <usearch/index_plugins.hpp>

namespace DB::ErrorCodes
{
Expand Down Expand Up @@ -192,16 +190,16 @@ std::vector<VectorIndexBuilder::Key> VectorIndexHNSWViewer::search(
std::atomic<size_t> discarded_nodes = 0;
std::atomic<bool> has_exception_in_search = false;

// The non-valid rows should be discarded by this lambda
auto predicate = [&](typename USearchImplType::member_cref_t const & member) {
// The non-valid rows should be discarded by this lambda.
auto predicate = [&](const Key & key) {
// Must catch exceptions in the predicate, because search runs on other threads.
try
{
// Note: We don't increase the thread_local perf, because search runs on other threads.
visited_nodes++;
if (!valid_rows[member.key])
if (!valid_rows[key])
discarded_nodes++;
return valid_rows[member.key];
return valid_rows[key];
}
catch (...)
{
Expand All @@ -215,7 +213,7 @@ std::vector<VectorIndexBuilder::Key> VectorIndexHNSWViewer::search(
SCOPE_EXIT({ GET_METRIC(tiflash_vector_index_duration, type_search).Observe(w.elapsedSeconds()); });

// TODO(vector-index): Support efSearch.
auto result = index.search( //
auto result = index.filtered_search( //
reinterpret_cast<const Float32 *>(query_info->ref_vec_f32().data() + sizeof(UInt32)),
query_info->top_k(),
predicate);
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include <Storages/DeltaMerge/File/dtpb/dmfile.pb.h>
#include <Storages/DeltaMerge/Index/VectorIndex.h>
#include <Storages/DeltaMerge/Index/VectorIndexHNSW/usearch_index_dense.h>
#include <VectorSearch/USearch.h>

namespace DB::DM
{
Expand Down
88 changes: 54 additions & 34 deletions dbms/src/TiDB/Decode/Vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <IO/Buffer/WriteBufferFromString.h>
#include <IO/WriteHelpers.h>
#include <TiDB/Decode/Vector.h>
#include <VectorSearch/simdsimd-internals.h>

#include <compare>

Expand Down Expand Up @@ -50,15 +51,25 @@ Float64 VectorFloat32Ref::l2SquaredDistance(VectorFloat32Ref b) const
{
checkDims(b);

Float32 distance = 0.0;
Float32 diff;
static simsimd_metric_punned_t metric = nullptr;
static std::once_flag init_flag;

for (size_t i = 0, i_max = size(); i < i_max; ++i)
{
// Hope this can be vectorized.
diff = elements[i] - b[i];
distance += diff * diff;
}
std::call_once(init_flag, []() {
simsimd_capability_t used_capability;
simsimd_find_metric_punned(
simsimd_metric_l2sq_k,
simsimd_datatype_f32_k,
simsimd_details::simd_capabilities(),
simsimd_cap_any_k,
&metric,
&used_capability);
});

if (!metric)
return std::numeric_limits<double>::quiet_NaN();

simsimd_distance_t distance;
metric(elements, b.elements, elements_n, &distance);

return distance;
}
Expand All @@ -67,13 +78,25 @@ Float64 VectorFloat32Ref::innerProduct(VectorFloat32Ref b) const
{
checkDims(b);

Float32 distance = 0.0;
static simsimd_metric_punned_t metric = nullptr;
static std::once_flag init_flag;

for (size_t i = 0, i_max = size(); i < i_max; ++i)
{
// Hope this can be vectorized.
distance += elements[i] * b[i];
}
std::call_once(init_flag, []() {
simsimd_capability_t used_capability;
simsimd_find_metric_punned(
simsimd_metric_dot_k,
simsimd_datatype_f32_k,
simsimd_details::simd_capabilities(),
simsimd_cap_any_k,
&metric,
&used_capability);
});

if (!metric)
return std::numeric_limits<double>::quiet_NaN();

simsimd_distance_t distance;
metric(elements, b.elements, elements_n, &distance);

return distance;
}
Expand All @@ -82,30 +105,27 @@ Float64 VectorFloat32Ref::cosineDistance(VectorFloat32Ref b) const
{
checkDims(b);

Float32 distance = 0.0;
Float32 norma = 0.0;
Float32 normb = 0.0;
static simsimd_metric_punned_t metric = nullptr;
static std::once_flag init_flag;

for (size_t i = 0, i_max = size(); i < i_max; ++i)
{
// Hope this can be vectorized.
distance += elements[i] * b[i];
norma += elements[i] * elements[i];
normb += b[i] * b[i];
}
std::call_once(init_flag, []() {
simsimd_capability_t used_capability;
simsimd_find_metric_punned(
simsimd_metric_cos_k,
simsimd_datatype_f32_k,
simsimd_details::simd_capabilities(),
simsimd_cap_any_k,
&metric,
&used_capability);
});

Float64 similarity
= static_cast<Float64>(distance) / std::sqrt(static_cast<Float64>(norma) * static_cast<Float64>(normb));
if (!metric)
return std::numeric_limits<double>::quiet_NaN();

if (std::isnan(similarity))
{
// When norma or normb is zero, distance is zero, and similarity is NaN.
// similarity can not be Inf in this case.
return std::nan("");
}
simsimd_distance_t distance;
metric(elements, b.elements, elements_n, &distance);

similarity = std::clamp(similarity, -1.0, 1.0);
return 1.0 - similarity;
return distance;
}

Float64 VectorFloat32Ref::l1Distance(VectorFloat32Ref b) const
Expand Down
Loading

0 comments on commit 5f08ae6

Please sign in to comment.