Skip to content

Commit

Permalink
feat(search): HNSW (#1799)
Browse files Browse the repository at this point in the history
* feat(search): HNSW

---------

Signed-off-by: Vladislav Oleshko <[email protected]>
  • Loading branch information
dranikpg authored Sep 8, 2023
1 parent 55737a6 commit e69f182
Show file tree
Hide file tree
Showing 9 changed files with 264 additions and 63 deletions.
13 changes: 13 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ add_third_party(
BUILD_IN_SOURCE 1
)

# hnswlib v0.7.0, fetched as a source archive. No library artifact is registered
# (LIB "none"); the install step only copies the hnswlib/ header directory into
# ${THIRD_PARTY_LIB_DIR}/hnswlib/include/hnswlib.
# NOTE(review): `cp -T` looks like a GNU coreutils flag - confirm it is available on all build hosts.
add_third_party(
hnswlib
URL https://github.com/nmslib/hnswlib/archive/refs/tags/v0.7.0.tar.gz

INSTALL_COMMAND cp -RT <SOURCE_DIR>/hnswlib ${THIRD_PARTY_LIB_DIR}/hnswlib/include/hnswlib
LIB "none"
)

add_library(TRDP::jsoncons INTERFACE IMPORTED)
add_dependencies(TRDP::jsoncons jsoncons_project)
set_target_properties(TRDP::jsoncons PROPERTIES
Expand All @@ -109,6 +117,11 @@ add_dependencies(TRDP::croncpp croncpp_project)
set_target_properties(TRDP::croncpp PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${CRONCPP_INCLUDE_DIR}")

# Interface-only imported target that exposes the hnswlib include directory.
# It depends on hnswlib_project (presumably created by add_third_party(hnswlib ...) -
# confirm) so that the headers are fetched before any consumer is built.
add_library(TRDP::hnswlib INTERFACE IMPORTED)
add_dependencies(TRDP::hnswlib hnswlib_project)
set_target_properties(TRDP::hnswlib PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${HNSWLIB_INCLUDE_DIR}")

Message(STATUS "THIRD_PARTY_LIB_DIR ${THIRD_PARTY_LIB_DIR}")


Expand Down
2 changes: 1 addition & 1 deletion src/core/search/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ cur_gen_dir(gen_dir)
add_library(query_parser ast_expr.cc query_driver.cc search.cc indices.cc vector_utils.cc compressed_sorted_set.cc
${gen_dir}/parser.cc ${gen_dir}/lexer.cc)

target_link_libraries(query_parser base absl::strings TRDP::reflex TRDP::uni-algo)
target_link_libraries(query_parser base absl::strings TRDP::reflex TRDP::uni-algo TRDP::hnswlib)

cxx_test(compressed_sorted_set_test query_parser LABELS DFLY)
cxx_test(search_parser_test query_parser LABELS DFLY)
Expand Down
105 changes: 99 additions & 6 deletions src/core/search/indices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

#define UNI_ALGO_DISABLE_NFKC_NFKD

#include <hnswlib/hnswalg.h>
#include <hnswlib/hnswlib.h>
#include <hnswlib/space_ip.h>
#include <hnswlib/space_l2.h>
#include <stdexcept>
#include <uni_algo/case.h>
#include <uni_algo/ranges_word.h>

Expand Down Expand Up @@ -106,10 +110,18 @@ absl::flat_hash_set<std::string> TagIndex::Tokenize(std::string_view value) cons
return NormalizeTags(value);
}

VectorIndex::VectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim}, entries_{} {
// Remembers the vector dimension and similarity metric of the indexed field.
BaseVectorIndex::BaseVectorIndex(size_t dim, VectorSimilarity sim) : dim_(dim), sim_(sim) {
}

void VectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
// Reports the (dimension, similarity) pair this index was constructed with.
std::pair<size_t /*dim*/, VectorSimilarity> BaseVectorIndex::Info() const {
  return std::make_pair(dim_, sim_);
}

// Creates an empty flat index; entries_ starts out as an empty vector.
FlatVectorIndex::FlatVectorIndex(size_t dim, VectorSimilarity sim) : BaseVectorIndex(dim, sim) {
}

void FlatVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
DCHECK_LE(id * dim_, entries_.size());
if (id * dim_ == entries_.size())
entries_.resize((id + 1) * dim_);
Expand All @@ -121,16 +133,97 @@ void VectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
memcpy(&entries_[id * dim_], ptr.get(), dim_ * sizeof(float));
}

void VectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
// Removal does nothing: whatever was stored for `id` in entries_ stays in place.
// None of the parameters are used.
void FlatVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
// noop
}

const float* VectorIndex::Get(DocId doc) const {
// Returns a pointer to the first component of the vector stored for `doc`.
// Vectors are laid out back to back in entries_, dim_ floats per document.
const float* FlatVectorIndex::Get(DocId doc) const {
  const size_t offset = doc * dim_;
  return &entries_[offset];
}

std::pair<size_t /*dim*/, VectorSimilarity> VectorIndex::Info() const {
return {dim_, sim_};
// Adapter that keeps every hnswlib type out of the header. It owns the distance
// space and the HNSW graph ("world") that is built on top of that space.
struct HnswlibAdapter {
  // Creates a world for vectors of dimension `dim`, compared with `sim`, with room
  // for `cap` elements initially. space_ is declared before world_, so the pointer
  // handed to world_ refers to an already constructed space.
  HnswlibAdapter(size_t dim, VectorSimilarity sim, size_t cap)
      : space_{MakeSpace(dim, sim)}, world_{GetSpacePtr(), cap} {
  }

  // Stores the vector `data` under the label `id`.
  void Add(float* data, DocId id) {
    // hnswlib throws when a new point is added to a world that already holds
    // max_elements_ points, so grow a full world (doubling it) before inserting.
    const size_t used = world_.cur_element_count;
    if (used >= world_.max_elements_)
      world_.resizeIndex(std::max<size_t>(used * 2, used + 1));
    world_.addPoint(data, id);
  }

  // Marks the point labelled `id` as deleted.
  void Remove(DocId id) {
    // markDelete throws std::runtime_error for labels that are unknown or already
    // deleted. A document may legitimately have no point here (its vector can be
    // skipped on insertion), so removing it must not fail.
    try {
      world_.markDelete(id);
    } catch (const std::runtime_error&) {
      // Nothing to remove.
    }
  }

  // Returns up to `k` (distance, id) pairs closest to `target`.
  vector<pair<float, DocId>> Knn(float* target, size_t k) {
    return QueueToVec(world_.searchKnn(target, k));
  }

  // Same as above, but only ids contained in `allowed` are considered.
  // `allowed` is probed with binary_search and therefore has to be sorted ascending.
  vector<pair<float, DocId>> Knn(float* target, size_t k, const vector<DocId>& allowed) {
    struct BinsearchFilter : hnswlib::BaseFilterFunctor {
      explicit BinsearchFilter(const vector<DocId>* allowed) : allowed{allowed} {
      }

      bool operator()(hnswlib::labeltype id) override {
        return binary_search(allowed->begin(), allowed->end(), id);
      }

      const vector<DocId>* allowed;
    };

    BinsearchFilter filter{&allowed};
    return QueueToVec(world_.searchKnn(target, k, &filter));
  }

 private:
  using SpaceUnion = std::variant<hnswlib::L2Space, hnswlib::InnerProductSpace>;

  // L2 similarity maps to L2Space, every other value to InnerProductSpace.
  static SpaceUnion MakeSpace(size_t dim, VectorSimilarity sim) {
    if (sim == VectorSimilarity::L2)
      return hnswlib::L2Space{dim};
    return hnswlib::InnerProductSpace{dim};
  }

  // Pointer to whichever space alternative is currently held by space_.
  hnswlib::SpaceInterface<float>* GetSpacePtr() {
    return visit([](auto& space) -> hnswlib::SpaceInterface<float>* { return &space; }, space_);
  }

  // Drains a priority queue into a vector. The output is filled back to front, so
  // the element popped first ends up last (presumably the queue yields the farthest
  // match first, giving ascending distances - confirm against hnswlib).
  template <typename Q> static vector<pair<float, DocId>> QueueToVec(Q queue) {
    vector<pair<float, DocId>> out(queue.size());
    size_t idx = out.size();
    while (!queue.empty()) {
      out[--idx] = queue.top();
      queue.pop();
    }
    return out;
  }

  SpaceUnion space_;
  hnswlib::HierarchicalNSW<float> world_;
};

// Creates the index together with its hnswlib adapter, sized for `capacity` elements.
HnswVectorIndex::HnswVectorIndex(size_t dim, VectorSimilarity sim, size_t capacity)
    : BaseVectorIndex(dim, sim), adapter_(make_unique<HnswlibAdapter>(dim, sim, capacity)) {
}

// Defined in this translation unit because destroying adapter_ needs the complete
// HnswlibAdapter type.
HnswVectorIndex::~HnswVectorIndex() = default;

// Reads the vector stored in `field` of `doc` and inserts it under `id`.
// Vectors whose size differs from the index dimension are skipped.
void HnswVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
  auto [vec, vec_dim] = doc->GetVector(field);
  if (vec_dim != dim_)
    return;

  adapter_->Add(vec.get(), id);
}

// Forwards to the adapter: returns the (float, DocId) pairs it finds for the
// `k` nearest neighbours of `target`.
std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k) const {
return adapter_->Knn(target, k);
}
// Forwards to the adapter's filtered search: like the two-argument overload, but
// `allowed` is passed along to restrict which documents may be returned.
std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k,
const std::vector<DocId>& allowed) const {
return adapter_->Knn(target, k, allowed);
}

// Asks the adapter to remove the entry for `id`; `doc` and `field` are not used.
void HnswVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
adapter_->Remove(id);
}

} // namespace dfly::search
35 changes: 30 additions & 5 deletions src/core/search/indices.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <absl/container/flat_hash_set.h>

#include <map>
#include <memory>
#include <optional>
#include <vector>

Expand Down Expand Up @@ -54,21 +55,45 @@ struct TagIndex : public BaseStringIndex {
absl::flat_hash_set<std::string> Tokenize(std::string_view value) const override;
};

// Common base of the vector indices: holds the vector dimension and the
// similarity metric and exposes them through Info().
struct BaseVectorIndex : public BaseIndex {
// Returns the stored (dimension, similarity) pair.
std::pair<size_t /*dim*/, VectorSimilarity> Info() const;

protected:
// Protected constructor: only derived index types can create the base.
BaseVectorIndex(size_t dim, VectorSimilarity sim);

size_t dim_;            // dimension of the indexed vectors
VectorSimilarity sim_;  // similarity metric
};

// Index for vector fields.
// Only supports lookup by id.
struct VectorIndex : public BaseIndex {
VectorIndex(size_t dim, VectorSimilarity sim);
struct FlatVectorIndex : public BaseVectorIndex {
FlatVectorIndex(size_t dim, VectorSimilarity sim);

void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;

const float* Get(DocId doc) const;
std::pair<size_t /*dim*/, VectorSimilarity> Info() const;

private:
size_t dim_;
VectorSimilarity sim_;
std::vector<float> entries_;
};

struct HnswlibAdapter;

// Vector index that delegates storage and nearest-neighbour search to a
// HnswlibAdapter, which is only forward-declared in this header.
struct HnswVectorIndex : public BaseVectorIndex {
// `capacity` is passed on to the adapter when it is created.
HnswVectorIndex(size_t dim, VectorSimilarity sim, size_t capacity);
// Declared here and defined out of line, since HnswlibAdapter is incomplete at this point.
~HnswVectorIndex();

void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;

// Nearest-neighbour queries for `target`, limited to `k` results. The second
// overload additionally takes the list of documents that may be returned.
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k) const;
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k,
const std::vector<DocId>& allowed) const;

private:
std::unique_ptr<HnswlibAdapter> adapter_;
};

} // namespace dfly::search
63 changes: 52 additions & 11 deletions src/core/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,14 +274,10 @@ struct BasicSearch {
return UnifyResults(GetSubResults(node.tags, mapping), LogicOp::OR);
}

// [KNN limit @field vec]: Compute distance from `vec` to all vectors keep closest `limit`
IndexResult Search(const AstKnnNode& knn, string_view active_field) {
DCHECK(active_field.empty());
auto sub_results = SearchGeneric(*knn.filter, active_field);

auto* vec_index = GetIndex<VectorIndex>(knn.field);
void SearchKnnFlat(const AstKnnNode& knn, IndexResult&& sub_results) {
auto* vec_index = GetIndex<FlatVectorIndex>(knn.field);
if (auto [dim, _] = vec_index->Info(); dim != knn.vec.second)
return IndexResult{};
return;

distances_.reserve(sub_results.Size());
auto cb = [&](auto* set) {
Expand All @@ -295,11 +291,42 @@ struct BasicSearch {

size_t prefix_size = min(knn.limit, distances_.size());
partial_sort(distances_.begin(), distances_.begin() + prefix_size, distances_.end());
distances_.resize(prefix_size);
}

vector<DocId> out(prefix_size);
for (size_t i = 0; i < out.size(); i++)
out[i] = distances_[i].second;
// KNN over an HNSW index: fills distances_ with the matches for `knn`.
// Leaves distances_ untouched when the query vector's dimension differs from
// the index dimension.
void SearchKnnHnsw(const AstKnnNode& knn, IndexResult&& sub_results) {
  auto* hnsw_index = GetIndex<HnswVectorIndex>(knn.field);
  const auto index_dim = hnsw_index->Info().first;
  if (index_dim != knn.vec.second)
    return;

  auto* query = knn.vec.first.get();

  // When the filter matched every document, the unfiltered search can be used.
  const bool matches_all = indices_->GetAllDocs().size() == sub_results.Size();
  if (matches_all) {
    distances_ = hnsw_index->Knn(query, knn.limit);
    return;
  }

  distances_ = hnsw_index->Knn(query, knn.limit, sub_results.Take());
}

// [KNN limit @field vec]: compute the distance from `vec` to the vectors matched
// by the filter and keep the closest `limit` of them.
IndexResult Search(const AstKnnNode& knn, string_view active_field) {
  DCHECK(active_field.empty());
  auto sub_results = SearchGeneric(*knn.filter, active_field);

  // Translate the queried name through schema.field_names when it has an entry there.
  const auto& schema = indices_->GetSchema();
  string_view knn_field = knn.field;
  auto name_it = schema.field_names.find(knn_field);
  if (name_it != schema.field_names.end())
    knn_field = name_it->second;

  const auto& field_info = schema.fields.at(knn_field);
  DCHECK(holds_alternative<SchemaField::VectorParams>(field_info.special_params));
  const bool use_hnsw = get<SchemaField::VectorParams>(field_info.special_params).use_hnsw;

  // Both helpers report their matches through distances_.
  distances_.clear();
  if (use_hnsw) {
    SearchKnnHnsw(knn, std::move(sub_results));
  } else {
    SearchKnnFlat(knn, std::move(sub_results));
  }

  // Only the document ids are returned; the distances stay in distances_.
  vector<DocId> out;
  out.reserve(distances_.size());
  for (const auto& entry : distances_)
    out.push_back(entry.second);
  return out;
}

Expand Down Expand Up @@ -364,7 +391,17 @@ FieldIndices::FieldIndices(Schema schema) : schema_{move(schema)}, all_ids_{}, i
indices_[field_ident] = make_unique<NumericIndex>();
break;
case SchemaField::VECTOR:
indices_[field_ident] = make_unique<VectorIndex>(field_info.knn_dim, field_info.knn_sim);
unique_ptr<BaseVectorIndex> vector_index;

DCHECK(holds_alternative<SchemaField::VectorParams>(field_info.special_params));
const auto& vparams = std::get<SchemaField::VectorParams>(field_info.special_params);

if (vparams.use_hnsw)
vector_index = make_unique<HnswVectorIndex>(vparams.dim, vparams.sim, vparams.capacity);
else
vector_index = make_unique<FlatVectorIndex>(vparams.dim, vparams.sim);

indices_[field_ident] = std::move(vector_index);
break;
}
}
Expand Down Expand Up @@ -411,6 +448,10 @@ const vector<DocId>& FieldIndices::GetAllDocs() const {
return all_ids_;
}

// Returns a reference to the schema held by this FieldIndices instance.
const Schema& FieldIndices::GetSchema() const {
return schema_;
}

SearchAlgorithm::SearchAlgorithm() = default;
SearchAlgorithm::~SearchAlgorithm() = default;

Expand Down
16 changes: 14 additions & 2 deletions src/core/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <optional>
#include <string>
#include <unordered_map>
#include <variant>

#include "core/search/base.h"

Expand All @@ -23,11 +24,20 @@ struct TextIndex;
struct SchemaField {
enum FieldType { TAG, TEXT, NUMERIC, VECTOR };

// Parameters of a VECTOR field.
struct VectorParams {
bool use_hnsw = false;  // true: HNSW index, false: flat index

size_t dim = 0u; // dimension of knn vectors
VectorSimilarity sim = VectorSimilarity::L2; // similarity type
size_t capacity = 1000; // initial capacity for hnsw world
};

using ParamsVariant = std::variant<std::monostate, VectorParams>;

FieldType type;
std::string short_name; // equal to ident if none provided

size_t knn_dim = 0u; // dimension of knn vectors
VectorSimilarity knn_sim = VectorSimilarity::L2; // similarity type
ParamsVariant special_params{std::monostate{}};
};

// Describes the fields of an index
Expand All @@ -52,6 +62,8 @@ class FieldIndices {
std::vector<TextIndex*> GetAllTextIndices() const;
const std::vector<DocId>& GetAllDocs() const;

const Schema& GetSchema() const;

private:
Schema schema_;
std::vector<DocId> all_ids_;
Expand Down
Loading

0 comments on commit e69f182

Please sign in to comment.