diff --git a/src/index/annoy/annoy.cc b/src/index/annoy/annoy.cc index aed5ed1e5..f3d05ee48 100644 --- a/src/index/annoy/annoy.cc +++ b/src/index/annoy/annoy.cc @@ -139,23 +139,22 @@ class AnnoyIndexNode : public IndexNode { auto rows = dataset.GetRows(); auto dim = dataset.GetDim(); - auto p_ids = dataset.GetIds(); + auto ids = dataset.GetIds(); - float* p_x = nullptr; + float* data = nullptr; try { - p_x = new (std::nothrow) float[dim * rows]; + data = new float[dim * rows]; for (int64_t i = 0; i < rows; i++) { - int64_t id = p_ids[i]; + int64_t id = ids[i]; assert(id >= 0 && id < index_->get_n_items()); - index_->get_item(id, p_x + i * dim); + index_->get_item(id, data + i * dim); } + return GenResultDataSet(data); } catch (const std::exception& e) { - std::unique_ptr auto_del(p_x); - LOG_KNOWHERE_WARNING_ << "error in annoy, " << e.what(); + std::unique_ptr auto_del(data); + LOG_KNOWHERE_WARNING_ << "error in annoy: " << e.what(); return unexpected(Status::annoy_inner_error); } - - return GenResultDataSet(p_x); } expected diff --git a/src/index/flat/flat.cc b/src/index/flat/flat.cc index f3855cb5d..8b304fdd1 100644 --- a/src/index/flat/flat.cc +++ b/src/index/flat/flat.cc @@ -26,7 +26,7 @@ class FlatIndexNode : public IndexNode { public: FlatIndexNode(const Object&) : index_(nullptr) { static_assert(std::is_same::value || std::is_same::value, - "not suppprt."); + "not support"); pool_ = ThreadPool::GetGlobalThreadPool(); } @@ -45,7 +45,7 @@ class FlatIndexNode : public IndexNode { const FlatConfig& f_cfg = static_cast(cfg); auto metric = Str2FaissMetricType(f_cfg.metric_type); if (!metric.has_value()) { - LOG_KNOWHERE_WARNING_ << "please check metric type, " << f_cfg.metric_type; + LOG_KNOWHERE_WARNING_ << "please check metric type: " << f_cfg.metric_type; return metric.error(); } index_ = std::make_unique(dataset.GetDim(), metric.value()); @@ -112,7 +112,7 @@ class FlatIndexNode : public IndexNode { } catch (const std::exception& e) { std::unique_ptr auto_delete_ids(ids); std::unique_ptr auto_delete_dis(distances); - LOG_KNOWHERE_WARNING_ << "error inner faiss, " << e.what(); + LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what(); return unexpected(Status::faiss_inner_error); } @@ -122,7 +122,7 @@ class FlatIndexNode : public IndexNode { expected RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override { if (!index_) { - LOG_KNOWHERE_WARNING_ << "range search on empty index."; + LOG_KNOWHERE_WARNING_ << "range search on empty index"; return unexpected(Status::empty_index); } @@ -174,7 +174,7 @@ class FlatIndexNode : public IndexNode { GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims); } catch (const std::exception& e) { - LOG_KNOWHERE_WARNING_ << "error inner faiss, " << e.what(); + LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what(); return unexpected(Status::faiss_inner_error); } @@ -183,32 +183,34 @@ class FlatIndexNode : public IndexNode { expected GetVectorByIds(const DataSet& dataset, const Config& cfg) const override { - auto nq = dataset.GetRows(); + auto rows = dataset.GetRows(); auto dim = dataset.GetDim(); - auto in_ids = dataset.GetIds(); + auto ids = dataset.GetIds(); if constexpr (std::is_same::value) { + float* data = nullptr; try { - float* xq = new (std::nothrow) float[nq * dim]; - for (int64_t i = 0; i < nq; i++) { - int64_t id = in_ids[i]; - index_->reconstruct(id, xq + i * dim); + data = new float[rows * dim]; + for (int64_t i = 0; i < rows; i++) { + index_->reconstruct(ids[i], data + i * dim); } - return GenResultDataSet(xq); + return GenResultDataSet(data); } catch (const std::exception& e) { + std::unique_ptr auto_del(data); LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return unexpected(Status::faiss_inner_error); } } if constexpr (std::is_same::value) { + uint8_t* data = nullptr; try { - uint8_t* xq = new (std::nothrow) uint8_t[nq * dim / 8]; - for (int64_t i = 0; i < nq; i++) { - int64_t id = in_ids[i]; - index_->reconstruct(id, xq + i * dim / 8); + data = new uint8_t[rows * dim / 8]; + for (int64_t i = 0; i < rows; i++) { + index_->reconstruct(ids[i], data + i * dim / 8); } - return GenResultDataSet(xq); + return GenResultDataSet(data); } catch (const std::exception& e) { - LOG_KNOWHERE_WARNING_ << "error inner faiss, " << e.what(); + std::unique_ptr auto_del(data); + LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what(); return unexpected(Status::faiss_inner_error); } } @@ -241,7 +243,7 @@ class FlatIndexNode : public IndexNode { } return Status::success; } catch (const std::exception& e) { - LOG_KNOWHERE_WARNING_ << "error inner faiss, " << e.what(); + LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what(); return Status::faiss_inner_error; } } @@ -315,7 +317,7 @@ class FlatIndexNode : public IndexNode { return knowhere::IndexEnum::INDEX_FAISS_IDMAP; } if constexpr (std::is_same::value) { - return knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT; + return knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP; } } diff --git a/src/index/hnsw/hnsw.cc b/src/index/hnsw/hnsw.cc index d7056f4b1..d843b84a2 100644 --- a/src/index/hnsw/hnsw.cc +++ b/src/index/hnsw/hnsw.cc @@ -55,7 +55,7 @@ class HnswIndexNode : public IndexNode { } else if (hnsw_cfg.metric_type == metric::IP) { space = new (std::nothrow) hnswlib::InnerProductSpace(dim); } else { - LOG_KNOWHERE_WARNING_ << "metric type not support in hnsw, " << hnsw_cfg.metric_type; + LOG_KNOWHERE_WARNING_ << "metric type not support in hnsw: " << hnsw_cfg.metric_type; return Status::invalid_metric_type; } auto index = @@ -66,7 +66,7 @@ class HnswIndexNode : public IndexNode { } if (this->index_) { delete this->index_; - LOG_KNOWHERE_WARNING_ << "index not empty, deleted old index."; + LOG_KNOWHERE_WARNING_ << "index not empty, deleted old index"; } this->index_ = index; return Status::success; @@ -160,7 +160,7 @@ class HnswIndexNode : public IndexNode { expected RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override { if (!index_) { - LOG_KNOWHERE_WARNING_ << "range search on empty index."; + LOG_KNOWHERE_WARNING_ << "range search on empty index"; return unexpected(Status::empty_index); } @@ -243,26 +243,26 @@ class HnswIndexNode : public IndexNode { auto rows = dataset.GetRows(); auto ids = dataset.GetIds(); - float* p_x = nullptr; + float* data = nullptr; try { - p_x = new float[dim * rows]; + data = new float[dim * rows]; for (int64_t i = 0; i < rows; i++) { int64_t id = ids[i]; assert(id >= 0 && id < (int64_t)index_->cur_element_count); - memcpy(p_x + i * dim, index_->getDataByInternalId(id), dim * sizeof(float)); + std::copy_n((float*)index_->getDataByInternalId(id), dim, data + i * dim); } + return GenResultDataSet(data); } catch (std::exception& e) { - LOG_KNOWHERE_WARNING_ << "hnsw inner error, " << e.what(); - std::unique_ptr auto_delete_px(p_x); + LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what(); + std::unique_ptr auto_del(data); return unexpected(Status::hnsw_inner_error); } - return GenResultDataSet(p_x); } expected GetIndexMeta(const Config& cfg) const override { if (!index_) { - LOG_KNOWHERE_WARNING_ << "get index meta on empty index."; + LOG_KNOWHERE_WARNING_ << "get index meta on empty index"; return unexpected(Status::empty_index); } @@ -299,7 +299,7 @@ class HnswIndexNode : public IndexNode { std::shared_ptr data(writer.data_); binset.Append("HNSW", data, writer.rp); } catch (std::exception& e) { - LOG_KNOWHERE_WARNING_ << "hnsw inner error, " << e.what(); + LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what(); return Status::hnsw_inner_error; } return Status::success; @@ -321,7 +321,7 @@ class HnswIndexNode : public IndexNode { index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); index_->loadIndex(reader); } catch (std::exception& e) { - LOG_KNOWHERE_WARNING_ << "hnsw inner error, " << e.what(); + LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what(); return Status::hnsw_inner_error; } return Status::success; @@ -337,7 +337,7 @@ class HnswIndexNode : public IndexNode { index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); index_->loadIndex(filename, config); } catch (std::exception& e) { - LOG_KNOWHERE_WARNING_ << "hnsw inner error, " << e.what(); + LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what(); return Status::hnsw_inner_error; } return Status::success; diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index 4a03cf56f..770099801 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -45,7 +45,7 @@ class IvfIndexNode : public IndexNode { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value, - "not support."); + "not support"); pool_ = ThreadPool::GetGlobalThreadPool(); } Status @@ -256,7 +256,7 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { if (qzr) { delete qzr; } - LOG_KNOWHERE_WARNING_ << "faiss inner error, " << e.what(); + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; } index_ = std::move(index); @@ -300,7 +300,7 @@ IvfIndexNode::Add(const DataSet& dataset, const Config&) { } } catch (std::exception& e) { - LOG_KNOWHERE_WARNING_ << "faiss inner error, " << e.what(); + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; } return Status::success; @@ -366,7 +366,7 @@ IvfIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetV } catch (const std::exception& e) { delete[] ids; delete[] distances; - LOG_KNOWHERE_WARNING_ << "faiss inner error, " << e.what(); + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return unexpected(Status::faiss_inner_error); } @@ -379,11 +379,11 @@ template expected IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { if (!this->index_) { - LOG_KNOWHERE_WARNING_ << "range search on empty index."; + LOG_KNOWHERE_WARNING_ << "range search on empty index"; return unexpected(Status::empty_index); } if (!this->index_->is_trained) { - LOG_KNOWHERE_WARNING_ << "index not trained."; + LOG_KNOWHERE_WARNING_ << "index not trained"; return unexpected(Status::index_not_trained); } @@ -451,7 +451,7 @@ IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi } GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims); } catch (const std::exception& e) { - LOG_KNOWHERE_WARNING_ << "faiss inner error, " << e.what(); + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return unexpected(Status::faiss_inner_error); } @@ -469,25 +469,26 @@ IvfIndexNode::GetVectorByIds(const DataSet& dataset, const Config& cfg) const } auto rows = dataset.GetRows(); auto dim = dataset.GetDim(); - float* p_x(new (std::nothrow) float[dim * rows]); - index_->make_direct_map(true); - auto p_ids = dataset.GetIds(); + auto ids = dataset.GetIds(); + float* data = nullptr; try { + data = new float[dim * rows]; + index_->make_direct_map(true); for (int64_t i = 0; i < rows; i++) { - int64_t id = p_ids[i]; + int64_t id = ids[i]; assert(id >= 0 && id < index_->ntotal); if constexpr (std::is_same::value) { - index_->reconstruct_without_codes(id, p_x + i * dim); + index_->reconstruct_without_codes(id, data + i * dim); } else { - index_->reconstruct(id, p_x + i * dim); + index_->reconstruct(id, data + i * dim); } } + return GenResultDataSet(data); } catch (const std::exception& e) { - std::unique_ptr p_x_auto_delete(p_x); - LOG_KNOWHERE_WARNING_ << "faiss inner error, " << e.what(); + std::unique_ptr auto_del(data); + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return unexpected(Status::faiss_inner_error); } - return GenResultDataSet(p_x); } template <> @@ -501,28 +502,29 @@ IvfIndexNode::GetVectorByIds(const DataSet& dataset, cons } auto rows = dataset.GetRows(); auto dim = dataset.GetDim(); - uint8_t* p_x(new (std::nothrow) uint8_t[dim * rows / 8]); - index_->make_direct_map(true); - auto p_ids = dataset.GetIds(); + auto ids = dataset.GetIds(); + uint8_t* data = nullptr; try { + data = new uint8_t[dim * rows / 8]; + index_->make_direct_map(true); for (int64_t i = 0; i < rows; i++) { - int64_t id = p_ids[i]; + int64_t id = ids[i]; assert(id >= 0 && id < index_->ntotal); - index_->reconstruct(id, p_x + i * dim / 8); + index_->reconstruct(id, data + i * dim / 8); } + return GenResultDataSet(data); } catch (const std::exception& e) { - std::unique_ptr p_x_auto_delete(p_x); - LOG_KNOWHERE_WARNING_ << "faiss inner error, " << e.what(); + std::unique_ptr auto_del(data); + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return unexpected(Status::faiss_inner_error); } - return GenResultDataSet(p_x); } template <> expected IvfIndexNode::GetIndexMeta(const Config& config) const { if (!index_) { - LOG_KNOWHERE_WARNING_ << "get index meta on empty index."; + LOG_KNOWHERE_WARNING_ << "get index meta on empty index"; return unexpected(Status::empty_index); } @@ -577,7 +579,7 @@ IvfIndexNode::Serialize(BinarySet& binset) const { } return Status::success; } catch (const std::exception& e) { - LOG_KNOWHERE_WARNING_ << "faiss inner error, " << e.what(); + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; } } @@ -601,7 +603,7 @@ IvfIndexNode::Deserialize(const BinarySet& binset) { index_.reset(static_cast(faiss::read_index(&reader))); } } catch (const std::exception& e) { - LOG_KNOWHERE_WARNING_ << "faiss inner error, " << e.what(); + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; } return Status::success; @@ -621,7 +623,7 @@ IvfIndexNode::DeserializeFromFile(const std::string& filename, const LoadConf index_.reset(static_cast(faiss::read_index(filename.data(), io_flags))); } } catch (const std::exception& e) { - LOG_KNOWHERE_WARNING_ << "faiss inner error, " << e.what(); + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; } return Status::success; @@ -659,7 +661,7 @@ IvfIndexNode::Deserialize(const BinarySet& binset) { curr_index += list_size; } } catch (const std::exception& e) { - LOG_KNOWHERE_WARNING_ << "faiss inner error, " << e.what(); + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; } return Status::success; diff --git a/tests/ut/test_get_vector.cc b/tests/ut/test_get_vector.cc new file mode 100644 index 000000000..982c7e907 --- /dev/null +++ b/tests/ut/test_get_vector.cc @@ -0,0 +1,143 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "catch2/catch_approx.hpp" +#include "catch2/catch_test_macros.hpp" +#include "catch2/generators/catch_generators.hpp" +#include "knowhere/comp/index_param.h" +#include "knowhere/comp/knowhere_config.h" +#include "knowhere/factory.h" +#include "utils.h" + +TEST_CASE("Test Get Vector By Ids", "[GetVectorByIds]") { + using Catch::Approx; + + int64_t nb = 10000; + int64_t dim = 128; + int64_t seed = 42; + + auto base_gen = [&]() { + knowhere::Json json; + json[knowhere::meta::DIM] = dim; + json[knowhere::meta::METRIC_TYPE] = knowhere::metric::L2; + json[knowhere::meta::TOPK] = 1; + return json; + }; + + auto base_bin_gen = [&]() { + knowhere::Json json; + json[knowhere::meta::DIM] = dim; + json[knowhere::meta::METRIC_TYPE] = knowhere::metric::HAMMING; + json[knowhere::meta::TOPK] = 1; + return json; + }; + + auto annoy_gen = [&base_gen]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::N_TREES] = 16; + json[knowhere::indexparam::SEARCH_K] = 100; + return json; + }; + + auto hnsw_gen = [&base_gen]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::HNSW_M] = 128; + json[knowhere::indexparam::EFCONSTRUCTION] = 200; + json[knowhere::indexparam::EF] = 32; + return json; + }; + + auto ivfflat_gen = [&base_gen]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::NLIST] = 16; + json[knowhere::indexparam::NPROBE] = 4; + return json; + }; + + auto bin_ivfflat_gen = [&base_bin_gen]() { + knowhere::Json json = base_bin_gen(); + json[knowhere::indexparam::NLIST] = 16; + json[knowhere::indexparam::NPROBE] = 4; + return json; + }; + + auto flat_gen = base_gen; + auto bin_flat_gen = base_bin_gen; + + auto load_raw_data = [](knowhere::Index& index, const knowhere::DataSet& dataset, + const knowhere::Json& conf) { + auto rows = dataset.GetRows(); + auto dim = dataset.GetDim(); + auto p_data = dataset.GetTensor(); + knowhere::BinarySet bs; + auto res = index.Serialize(bs); + REQUIRE(res == knowhere::Status::success); + knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)p_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append("RAW_DATA", bptr); + res = index.Deserialize(bs); + REQUIRE(res == knowhere::Status::success); + }; + + SECTION("Test binary index") { + using std::make_tuple; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, bin_flat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, bin_ivfflat_gen), + })); + auto idx = knowhere::IndexFactory::Instance().Create(name); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + auto train_ds = GenBinDataSet(nb, dim, seed); + auto ids_ds = GenIdsDataSet(nb, dim); + REQUIRE(idx.Type() == name); + auto res = idx.Build(*train_ds, json); + REQUIRE(res == knowhere::Status::success); + auto results = idx.GetVectorByIds(*ids_ds, json); + REQUIRE(results.has_value()); + auto xb = (uint8_t*)train_ds->GetTensor(); + auto data = (uint8_t*)results.value()->GetTensor(); + for (int i = 0; i < nb * dim / 8; ++i) { + CHECK(data[i] == xb[i]); + } + } + + SECTION("Test float index") { + using std::make_tuple; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IDMAP, flat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen), + make_tuple(knowhere::IndexEnum::INDEX_ANNOY, annoy_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), + })); + auto idx = knowhere::IndexFactory::Instance().Create(name); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + auto train_ds = GenDataSet(nb, dim, seed); + auto ids_ds = GenIdsDataSet(nb, dim); + REQUIRE(idx.Type() == name); + auto res = idx.Build(*train_ds, json); + REQUIRE(res == knowhere::Status::success); + if (name == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) { + load_raw_data(idx, *train_ds, json); + } + auto results = idx.GetVectorByIds(*ids_ds, json); + REQUIRE(results.has_value()); + auto xb = (float*)train_ds->GetTensor(); + auto data = (float*)results.value()->GetTensor(); + for (int i = 0; i < nb * dim; ++i) { + CHECK(data[i] == xb[i]); + } + } +} diff --git a/tests/ut/utils.h b/tests/ut/utils.h index 199d2d869..f72e694f8 100644 --- a/tests/ut/utils.h +++ b/tests/ut/utils.h @@ -28,6 +28,7 @@ struct DisPairLess { } }; }; // namespace + inline std::unique_ptr GenDataSet(int rows, int dim, int seed = 42) { std::mt19937 rng(seed); @@ -42,6 +43,32 @@ GenDataSet(int rows, int dim, int seed = 42) { return ds; } +inline std::unique_ptr +GenBinDataSet(int rows, int dim, int seed = 42) { + std::mt19937 rng(seed); + std::uniform_int_distribution<> distrib(0.0, 100.0); + + auto ds = std::make_unique(); + ds->SetRows(rows); + ds->SetDim(dim); + int uint8_num = dim / 8; + uint8_t* ts = new uint8_t[rows * uint8_num]; + for (int i = 0; i < rows * uint8_num; ++i) ts[i] = (uint8_t)distrib(rng); + ds->SetTensor(ts); + return ds; +} + +inline std::unique_ptr +GenIdsDataSet(int rows, int dim) { + auto ds = std::make_unique(); + ds->SetRows(rows); + ds->SetDim(dim); + int64_t* ids = new int64_t[rows]; + for (int i = 0; i < rows; ++i) ids[i] = i; + ds->SetIds(ids); + return ds; +} + inline std::unique_ptr GetKNNGroundTruth(const knowhere::DataSet& base, const knowhere::DataSet& query, const std::string& metric, const int topk, const knowhere::BitsetView bitset = nullptr) {