Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Support GetVectorByIds for ANNOY, HNSW and all faiss index types (#791)
Browse files Browse the repository at this point in the history
Signed-off-by: Yudong Cai <[email protected]>
  • Loading branch information
cydrain authored Apr 3, 2023
1 parent f2d4b39 commit 1811195
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 71 deletions.
17 changes: 8 additions & 9 deletions src/index/annoy/annoy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,23 +139,22 @@ class AnnoyIndexNode : public IndexNode {

auto rows = dataset.GetRows();
auto dim = dataset.GetDim();
auto p_ids = dataset.GetIds();
auto ids = dataset.GetIds();

float* p_x = nullptr;
float* data = nullptr;
try {
p_x = new (std::nothrow) float[dim * rows];
data = new float[dim * rows];
for (int64_t i = 0; i < rows; i++) {
int64_t id = p_ids[i];
int64_t id = ids[i];
assert(id >= 0 && id < index_->get_n_items());
index_->get_item(id, p_x + i * dim);
index_->get_item(id, data + i * dim);
}
return GenResultDataSet(data);
} catch (const std::exception& e) {
std::unique_ptr<float> auto_del(p_x);
LOG_KNOWHERE_WARNING_ << "error in annoy, " << e.what();
std::unique_ptr<float> auto_del(data);
LOG_KNOWHERE_WARNING_ << "error in annoy: " << e.what();
return unexpected(Status::annoy_inner_error);
}

return GenResultDataSet(p_x);
}

expected<DataSetPtr, Status>
Expand Down
42 changes: 22 additions & 20 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class FlatIndexNode : public IndexNode {
public:
FlatIndexNode(const Object&) : index_(nullptr) {
static_assert(std::is_same<T, faiss::IndexFlat>::value || std::is_same<T, faiss::IndexBinaryFlat>::value,
"not suppprt.");
"not support");
pool_ = ThreadPool::GetGlobalThreadPool();
}

Expand All @@ -45,7 +45,7 @@ class FlatIndexNode : public IndexNode {
const FlatConfig& f_cfg = static_cast<const FlatConfig&>(cfg);
auto metric = Str2FaissMetricType(f_cfg.metric_type);
if (!metric.has_value()) {
LOG_KNOWHERE_WARNING_ << "please check metric type, " << f_cfg.metric_type;
LOG_KNOWHERE_WARNING_ << "please check metric type: " << f_cfg.metric_type;
return metric.error();
}
index_ = std::make_unique<T>(dataset.GetDim(), metric.value());
Expand Down Expand Up @@ -112,7 +112,7 @@ class FlatIndexNode : public IndexNode {
} catch (const std::exception& e) {
std::unique_ptr<int64_t[]> auto_delete_ids(ids);
std::unique_ptr<float[]> auto_delete_dis(distances);
LOG_KNOWHERE_WARNING_ << "error inner faiss, " << e.what();
LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what();
return unexpected(Status::faiss_inner_error);
}

Expand All @@ -122,7 +122,7 @@ class FlatIndexNode : public IndexNode {
expected<DataSetPtr, Status>
RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override {
if (!index_) {
LOG_KNOWHERE_WARNING_ << "range search on empty index.";
LOG_KNOWHERE_WARNING_ << "range search on empty index";
return unexpected(Status::empty_index);
}

Expand Down Expand Up @@ -174,7 +174,7 @@ class FlatIndexNode : public IndexNode {
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids,
lims);
} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "error inner faiss, " << e.what();
LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what();
return unexpected(Status::faiss_inner_error);
}

Expand All @@ -183,32 +183,34 @@ class FlatIndexNode : public IndexNode {

expected<DataSetPtr, Status>
GetVectorByIds(const DataSet& dataset, const Config& cfg) const override {
auto nq = dataset.GetRows();
auto rows = dataset.GetRows();
auto dim = dataset.GetDim();
auto in_ids = dataset.GetIds();
auto ids = dataset.GetIds();
if constexpr (std::is_same<T, faiss::IndexFlat>::value) {
float* data = nullptr;
try {
float* xq = new (std::nothrow) float[nq * dim];
for (int64_t i = 0; i < nq; i++) {
int64_t id = in_ids[i];
index_->reconstruct(id, xq + i * dim);
data = new float[rows * dim];
for (int64_t i = 0; i < rows; i++) {
index_->reconstruct(ids[i], data + i * dim);
}
return GenResultDataSet(xq);
return GenResultDataSet(data);
} catch (const std::exception& e) {
std::unique_ptr<float[]> auto_del(data);
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return unexpected(Status::faiss_inner_error);
}
}
if constexpr (std::is_same<T, faiss::IndexBinaryFlat>::value) {
uint8_t* data = nullptr;
try {
uint8_t* xq = new (std::nothrow) uint8_t[nq * dim / 8];
for (int64_t i = 0; i < nq; i++) {
int64_t id = in_ids[i];
index_->reconstruct(id, xq + i * dim / 8);
data = new uint8_t[rows * dim / 8];
for (int64_t i = 0; i < rows; i++) {
index_->reconstruct(ids[i], data + i * dim / 8);
}
return GenResultDataSet(xq);
return GenResultDataSet(data);
} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "error inner faiss, " << e.what();
std::unique_ptr<uint8_t[]> auto_del(data);
LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what();
return unexpected(Status::faiss_inner_error);
}
}
Expand Down Expand Up @@ -241,7 +243,7 @@ class FlatIndexNode : public IndexNode {
}
return Status::success;
} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "error inner faiss, " << e.what();
LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what();
return Status::faiss_inner_error;
}
}
Expand Down Expand Up @@ -315,7 +317,7 @@ class FlatIndexNode : public IndexNode {
return knowhere::IndexEnum::INDEX_FAISS_IDMAP;
}
if constexpr (std::is_same<T, faiss::IndexBinaryFlat>::value) {
return knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT;
return knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP;
}
}

Expand Down
26 changes: 13 additions & 13 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class HnswIndexNode : public IndexNode {
} else if (hnsw_cfg.metric_type == metric::IP) {
space = new (std::nothrow) hnswlib::InnerProductSpace(dim);
} else {
LOG_KNOWHERE_WARNING_ << "metric type not support in hnsw, " << hnsw_cfg.metric_type;
LOG_KNOWHERE_WARNING_ << "metric type not support in hnsw: " << hnsw_cfg.metric_type;
return Status::invalid_metric_type;
}
auto index =
Expand All @@ -66,7 +66,7 @@ class HnswIndexNode : public IndexNode {
}
if (this->index_) {
delete this->index_;
LOG_KNOWHERE_WARNING_ << "index not empty, deleted old index.";
LOG_KNOWHERE_WARNING_ << "index not empty, deleted old index";
}
this->index_ = index;
return Status::success;
Expand Down Expand Up @@ -160,7 +160,7 @@ class HnswIndexNode : public IndexNode {
expected<DataSetPtr, Status>
RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override {
if (!index_) {
LOG_KNOWHERE_WARNING_ << "range search on empty index.";
LOG_KNOWHERE_WARNING_ << "range search on empty index";
return unexpected(Status::empty_index);
}

Expand Down Expand Up @@ -243,26 +243,26 @@ class HnswIndexNode : public IndexNode {
auto rows = dataset.GetRows();
auto ids = dataset.GetIds();

float* p_x = nullptr;
float* data = nullptr;
try {
p_x = new float[dim * rows];
data = new float[dim * rows];
for (int64_t i = 0; i < rows; i++) {
int64_t id = ids[i];
assert(id >= 0 && id < (int64_t)index_->cur_element_count);
memcpy(p_x + i * dim, index_->getDataByInternalId(id), dim * sizeof(float));
std::copy_n((float*)index_->getDataByInternalId(id), dim, data + i * dim);
}
return GenResultDataSet(data);
} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "hnsw inner error, " << e.what();
std::unique_ptr<float> auto_delete_px(p_x);
LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what();
std::unique_ptr<float> auto_del(data);
return unexpected(Status::hnsw_inner_error);
}
return GenResultDataSet(p_x);
}

expected<DataSetPtr, Status>
GetIndexMeta(const Config& cfg) const override {
if (!index_) {
LOG_KNOWHERE_WARNING_ << "get index meta on empty index.";
LOG_KNOWHERE_WARNING_ << "get index meta on empty index";
return unexpected(Status::empty_index);
}

Expand Down Expand Up @@ -299,7 +299,7 @@ class HnswIndexNode : public IndexNode {
std::shared_ptr<uint8_t[]> data(writer.data_);
binset.Append("HNSW", data, writer.rp);
} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "hnsw inner error, " << e.what();
LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what();
return Status::hnsw_inner_error;
}
return Status::success;
Expand All @@ -321,7 +321,7 @@ class HnswIndexNode : public IndexNode {
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<float>(space);
index_->loadIndex(reader);
} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "hnsw inner error, " << e.what();
LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what();
return Status::hnsw_inner_error;
}
return Status::success;
Expand All @@ -337,7 +337,7 @@ class HnswIndexNode : public IndexNode {
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<float>(space);
index_->loadIndex(filename, config);
} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "hnsw inner error, " << e.what();
LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what();
return Status::hnsw_inner_error;
}
return Status::success;
Expand Down
Loading

0 comments on commit 1811195

Please sign in to comment.