Skip to content

Commit

Permalink
feat(search): HNSW
Browse files Browse the repository at this point in the history
Signed-off-by: Vladislav Oleshko <[email protected]>
  • Loading branch information
dranikpg committed Sep 6, 2023
1 parent 4e393cf commit 56fe552
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 24 deletions.
13 changes: 13 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ add_third_party(
LIB "none"
)

add_third_party(
hnswlib
URL https://github.com/nmslib/hnswlib/archive/refs/tags/v0.7.0.tar.gz

INSTALL_COMMAND cp -r <SOURCE_DIR>/hnswlib/* ${THIRD_PARTY_LIB_DIR}/include
LIB "none"
)

add_library(TRDP::jsoncons INTERFACE IMPORTED)
add_dependencies(TRDP::jsoncons jsoncons_project)
set_target_properties(TRDP::jsoncons PROPERTIES
Expand All @@ -102,6 +110,11 @@ add_dependencies(TRDP::croncpp croncpp_project)
set_target_properties(TRDP::croncpp PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${CRONCPP_INCLUDE_DIR}")

add_library(TRDP::hnswlib INTERFACE IMPORTED)
add_dependencies(TRDP::hnswlib hnswlib_project)
set_target_properties(TRDP::hnswlib PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${HNSWLIB_INCLUDE_DIR}")

Message(STATUS "THIRD_PARTY_LIB_DIR ${THIRD_PARTY_LIB_DIR}")


Expand Down
3 changes: 1 addition & 2 deletions src/core/search/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ find_package(ICU REQUIRED COMPONENTS uc i18n)
add_library(query_parser ast_expr.cc query_driver.cc search.cc indices.cc vector_utils.cc compressed_sorted_set.cc
${gen_dir}/parser.cc ${gen_dir}/lexer.cc)

target_link_libraries(query_parser ICU::uc ICU::i18n)
target_link_libraries(query_parser base absl::strings TRDP::reflex TRDP::hnswlib ICU::uc ICU::i18n)

target_link_libraries(query_parser base absl::strings TRDP::reflex)
cxx_test(compressed_sorted_set_test query_parser LABELS DFLY)
cxx_test(search_parser_test query_parser LABELS DFLY)
cxx_test(search_test query_parser LABELS DFLY)
105 changes: 99 additions & 6 deletions src/core/search/indices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@
#include <absl/strings/ascii.h>
#include <absl/strings/numbers.h>
#include <absl/strings/str_split.h>
#include <hnswalg.h>
#include <hnswlib.h>
#include <space_ip.h>
#include <space_l2.h>
#include <unicode/brkiter.h>
#include <unicode/unistr.h>

#include <algorithm>
#include <cctype>
#include <regex>
#include <variant>

#include "base/logging.h"

Expand Down Expand Up @@ -151,10 +156,18 @@ absl::flat_hash_set<std::string> TagIndex::Tokenize(std::string_view value) cons
return NormalizeTags(value);
}

VectorIndex::VectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim}, entries_{} {
BaseVectorIndex::BaseVectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim} {
}

void VectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
std::pair<size_t /*dim*/, VectorSimilarity> BaseVectorIndex::Info() const {
return {dim_, sim_};
}

FlatVectorIndex::FlatVectorIndex(size_t dim, VectorSimilarity sim)
: BaseVectorIndex{dim, sim}, entries_{} {
}

void FlatVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
DCHECK_LE(id * dim_, entries_.size());
if (id * dim_ == entries_.size())
entries_.resize((id + 1) * dim_);
Expand All @@ -166,16 +179,96 @@ void VectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
memcpy(&entries_[id * dim_], ptr.get(), dim_ * sizeof(float));
}

void VectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
void FlatVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
// noop
}

const float* VectorIndex::Get(DocId doc) const {
const float* FlatVectorIndex::Get(DocId doc) const {
return &entries_[doc * dim_];
}

std::pair<size_t /*dim*/, VectorSimilarity> VectorIndex::Info() const {
return {dim_, sim_};
struct HnswlibAdapter {
HnswlibAdapter(size_t dim, VectorSimilarity sim, size_t cap)
: space_{MakeSpace(dim, sim)}, world_{GetSpacePtr(), cap} {
}

void Add(float* data, DocId id) {
world_.addPoint(data, id);
}

void Remove(DocId id) {
world_.markDelete(id);
}

vector<pair<float, DocId>> Knn(float* target, size_t k) {
return QueueToVec(world_.searchKnn(target, k));
}

vector<pair<float, DocId>> Knn(float* target, size_t k, const vector<DocId>& allowed) {
struct BinsearchFilter : hnswlib::BaseFilterFunctor {
virtual bool operator()(hnswlib::labeltype id) {
return binary_search(allowed->begin(), allowed->end(), id);
}

BinsearchFilter(const vector<DocId>* allowed) : allowed{allowed} {
}
const vector<DocId>* allowed;
};

BinsearchFilter filter{&allowed};
return QueueToVec(world_.searchKnn(target, k, &filter));
}

private:
using SpaceUnion = std::variant<hnswlib::L2Space, hnswlib::InnerProductSpace>;

static SpaceUnion MakeSpace(size_t dim, VectorSimilarity sim) {
if (sim == VectorSimilarity::L2)
return hnswlib::L2Space{dim};
else
return hnswlib::InnerProductSpace{dim};
}

hnswlib::SpaceInterface<float>* GetSpacePtr() {
return visit([](auto& space) -> hnswlib::SpaceInterface<float>* { return &space; }, space_);
}

template <typename Q> static vector<pair<float, DocId>> QueueToVec(Q queue) {
vector<pair<float, DocId>> out;
while (!queue.empty()) {
out.push_back(queue.top());
queue.pop();
}
return out;
}

SpaceUnion space_;
hnswlib::HierarchicalNSW<float> world_;
};

HnswVectorIndex::HnswVectorIndex(size_t dim, VectorSimilarity sim, size_t capacity)
: BaseVectorIndex{dim, sim}, adapter_{make_unique<HnswlibAdapter>(dim, sim, capacity)} {
}

HnswVectorIndex::~HnswVectorIndex() {
}

void HnswVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
auto [ptr, size] = doc->GetVector(field);
if (size == dim_)
adapter_->Add(ptr.get(), id);
}

std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k) const {
return adapter_->Knn(target, k);
}
std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k,
const std::vector<DocId>& allowed) const {
return adapter_->Knn(target, k, allowed);
}

void HnswVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
adapter_->Remove(id);
}

} // namespace dfly::search
35 changes: 30 additions & 5 deletions src/core/search/indices.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <absl/container/flat_hash_set.h>

#include <map>
#include <memory>
#include <optional>
#include <vector>

Expand Down Expand Up @@ -54,21 +55,45 @@ struct TagIndex : public BaseStringIndex {
absl::flat_hash_set<std::string> Tokenize(std::string_view value) const override;
};

struct BaseVectorIndex : public BaseIndex {
std::pair<size_t /*dim*/, VectorSimilarity> Info() const;

protected:
BaseVectorIndex(size_t dim, VectorSimilarity sim);

size_t dim_;
VectorSimilarity sim_;
};

// Index for vector fields.
// Only supports lookup by id.
struct VectorIndex : public BaseIndex {
VectorIndex(size_t dim, VectorSimilarity sim);
struct FlatVectorIndex : public BaseVectorIndex {
FlatVectorIndex(size_t dim, VectorSimilarity sim);

void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;

const float* Get(DocId doc) const;
std::pair<size_t /*dim*/, VectorSimilarity> Info() const;

private:
size_t dim_;
VectorSimilarity sim_;
std::vector<float> entries_;
};

struct HnswlibAdapter;

struct HnswVectorIndex : public BaseVectorIndex {
HnswVectorIndex(size_t dim, VectorSimilarity sim, size_t capacity);
~HnswVectorIndex();

void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;

std::vector<std::pair<float, DocId>> Knn(float* target, size_t k) const;
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k,
const std::vector<DocId>& allowed) const;

private:
std::unique_ptr<HnswlibAdapter> adapter_;
};

} // namespace dfly::search
53 changes: 42 additions & 11 deletions src/core/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,14 +274,10 @@ struct BasicSearch {
return UnifyResults(GetSubResults(node.tags, mapping), LogicOp::OR);
}

// [KNN limit @field vec]: Compute distance from `vec` to all vectors keep closest `limit`
IndexResult Search(const AstKnnNode& knn, string_view active_field) {
DCHECK(active_field.empty());
auto sub_results = SearchGeneric(*knn.filter, active_field);

auto* vec_index = GetIndex<VectorIndex>(knn.field);
void SearchKnnFlat(const AstKnnNode& knn, IndexResult&& sub_results) {
auto* vec_index = GetIndex<FlatVectorIndex>(knn.field);
if (auto [dim, _] = vec_index->Info(); dim != knn.vec.second)
return IndexResult{};
return;

distances_.reserve(sub_results.Size());
auto cb = [&](auto* set) {
Expand All @@ -295,11 +291,34 @@ struct BasicSearch {

size_t prefix_size = min(knn.limit, distances_.size());
partial_sort(distances_.begin(), distances_.begin() + prefix_size, distances_.end());
distances_.resize(prefix_size);
}

vector<DocId> out(prefix_size);
for (size_t i = 0; i < out.size(); i++)
out[i] = distances_[i].second;
void SearchKnnHnsw(const AstKnnNode& knn, IndexResult&& sub_results) {
auto* vec_index = GetIndex<HnswVectorIndex>(knn.field);
if (auto [dim, _] = vec_index->Info(); dim != knn.vec.second)
return;

if (indices_->GetAllDocs().size() == sub_results.Size())
distances_ = vec_index->Knn(knn.vec.first.get(), knn.limit);
else
distances_ = vec_index->Knn(knn.vec.first.get(), knn.limit, sub_results.Take());
}

// [KNN limit @field vec]: Compute distance from `vec` to all vectors keep closest `limit`
IndexResult Search(const AstKnnNode& knn, string_view active_field) {
DCHECK(active_field.empty());
auto sub_results = SearchGeneric(*knn.filter, active_field);

distances_.clear();
if (indices_->GetSchema().fields.at(knn.field).hnsw_capacity.has_value())
SearchKnnHnsw(knn, std::move(sub_results));
else
SearchKnnFlat(knn, std::move(sub_results));

vector<DocId> out(distances_.size());
for (size_t i = 0; i < distances_.size(); i++)
out[i] = distances_[i].second;
return out;
}

Expand Down Expand Up @@ -364,7 +383,15 @@ FieldIndices::FieldIndices(Schema schema) : schema_{move(schema)}, all_ids_{}, i
indices_[field_ident] = make_unique<NumericIndex>();
break;
case SchemaField::VECTOR:
indices_[field_ident] = make_unique<VectorIndex>(field_info.knn_dim, field_info.knn_sim);
unique_ptr<BaseVectorIndex> vector_index;

if (auto capacity = field_info.hnsw_capacity; capacity)
vector_index =
make_unique<HnswVectorIndex>(field_info.knn_dim, field_info.knn_sim, *capacity);
else
vector_index = make_unique<FlatVectorIndex>(field_info.knn_dim, field_info.knn_sim);

indices_[field_ident] = std::move(vector_index);
break;
}
}
Expand Down Expand Up @@ -411,6 +438,10 @@ const vector<DocId>& FieldIndices::GetAllDocs() const {
return all_ids_;
}

const Schema& FieldIndices::GetSchema() const {
return schema_;
}

SearchAlgorithm::SearchAlgorithm() = default;
SearchAlgorithm::~SearchAlgorithm() = default;

Expand Down
3 changes: 3 additions & 0 deletions src/core/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct SchemaField {

size_t knn_dim = 0u; // dimension of knn vectors
VectorSimilarity knn_sim = VectorSimilarity::L2; // similarity type
std::optional<size_t> hnsw_capacity; // if set, capacity for hnsw world
};

// Describes the fields of an index
Expand All @@ -52,6 +53,8 @@ class FieldIndices {
std::vector<TextIndex*> GetAllTextIndices() const;
const std::vector<DocId>& GetAllDocs() const;

const Schema& GetSchema() const;

private:
Schema schema_;
std::vector<DocId> all_ids_;
Expand Down
1 change: 1 addition & 0 deletions src/core/search/search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ std::string ToBytes(absl::Span<const float> vec) {
TEST_F(SearchParserTest, SimpleKnn) {
auto schema = MakeSimpleSchema({{"even", SchemaField::TAG}, {"pos", SchemaField::VECTOR}});
schema.fields["pos"].knn_dim = 1;
schema.fields["pos"].hnsw_capacity = 120;
FieldIndices indices{schema};

// Place points on a straight line
Expand Down

0 comments on commit 56fe552

Please sign in to comment.