From 9e453a331ef60194f43d97cf31c2975da6e69979 Mon Sep 17 00:00:00 2001
From: "Yusheng.Ma"
Date: Tue, 14 Mar 2023 13:14:40 +0000
Subject: [PATCH] adapt for raft ivf

Signed-off-by: Yusheng.Ma
---
 benchmark/CMakeLists.txt                        |   7 +-
 .../hdf5/benchmark_knowhere_float_qps.cpp       |  84 +++++-
 benchmark/prepare.sh                            |  13 -
 include/knowhere/comp/index_param.h             |   2 +
 include/knowhere/gpu/gpu_res_mgr.h              |  20 --
 src/index/ivf_raft/ivf_raft.cuh                 | 244 ++++++++++++++++--
 6 files changed, 304 insertions(+), 66 deletions(-)

diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt
index 72c6d3170..a0eac07bf 100644
--- a/benchmark/CMakeLists.txt
+++ b/benchmark/CMakeLists.txt
@@ -12,15 +12,14 @@
 include_directories(${CMAKE_SOURCE_DIR})
 include_directories(${CMAKE_SOURCE_DIR}/include)
 
-include_directories(/usr/local/hdf5/include)
-link_directories(/usr/local/hdf5/lib)
-
+find_package(HDF5 REQUIRED)
+include_directories(${HDF5_INCLUDE_DIRS})
 
 set(unittest_libs
         gtest gmock gtest_main gmock_main)
 
 set(depend_libs
         knowhere
-        hdf5
+        ${HDF5_LIBRARIES}
         ${OpenBLAS_LIBRARIES}
         ${LAPACK_LIBRARIES}
         )

diff --git a/benchmark/hdf5/benchmark_knowhere_float_qps.cpp b/benchmark/hdf5/benchmark_knowhere_float_qps.cpp
index f3e02e778..fc8cec524 100644
--- a/benchmark/hdf5/benchmark_knowhere_float_qps.cpp
+++ b/benchmark/hdf5/benchmark_knowhere_float_qps.cpp
@@ -18,8 +18,9 @@
 #include "knowhere/comp/knowhere_config.h"
 #include "knowhere/dataset.h"
 
-const int32_t GPU_DEVICE_ID = 0;
-const int32_t CLIENT_NUM = 4;
+constexpr int32_t GPU_DEVICE_ID = 0;
+constexpr int32_t CLIENT_NUM = 1;
+constexpr int32_t THREAD_NUM = 8;
 
 class Benchmark_knowhere_float_qps : public Benchmark_knowhere, public ::testing::Test {
  public:
@@ -65,6 +66,49 @@ class Benchmark_knowhere_float_qps : public Benchmark_knowhere, public ::testing
         }
     }
 
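+    // For each recall target, double nprobe until the golden-query recall is reached,
+    // then measure QPS with CLIENT_NUM * THREAD_NUM concurrent search workers.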
+    void
+    test_raft_ivf(const knowhere::Json& cfg) {
+        auto conf = cfg;
+        auto nlist = conf[knowhere::indexparam::NLIST].get<int32_t>();
+
+        auto find_smallest_nprobe = [&](float expected_recall) -> int32_t {
+            int32_t golden_nq = 10000, golden_topk = 100;
+            int32_t nprobe = 1;
+            float recall;
+            while (nprobe <= NLIST_) {
+                conf[knowhere::indexparam::NPROBE] = nprobe;
+                conf[knowhere::meta::TOPK] = golden_topk;
+                auto ds_ptr = knowhere::GenDataSet(golden_nq, dim_, xq_);
+
+                auto result = index_.Search(*ds_ptr, conf, nullptr);
+                recall = CalcRecall(result.value()->GetIds(), golden_nq, golden_topk);
+                printf("\n[%0.3f s] iterate IVF param for recall %.4f: nlist=%d, nprobe=%d, k=%d, R@=%.4f\n",
+                       get_time_diff(), expected_recall, nlist, nprobe, golden_topk, recall);
+                if (recall >= expected_recall) {
+                    break;
+                }
+                nprobe *= 2;
+            }
+            return std::min(nprobe, NLIST_);
+        };
+
+        for (auto expected_recall : EXPECTED_RECALLs_) {
+            auto nprobe = find_smallest_nprobe(expected_recall);
+            conf[knowhere::indexparam::NPROBE] = nprobe;
+            conf[knowhere::meta::TOPK] = topk_;
+
+            printf("\n[%0.3f s] %s | %s | nlist=%d, nprobe=%d, k=%d, R@=%.4f\n", get_time_diff(),
+                   ann_test_name_.c_str(), index_type_.c_str(), nlist, nprobe, topk_, expected_recall);
+            printf("================================================================================\n");
+            CALC_TIME_SPAN(task(conf, CLIENT_NUM * THREAD_NUM, nq_));
+            printf(" client_num = %d, elapse = %6.3fs, QPS = %.3f\n", CLIENT_NUM, t_diff,
+                   nq_ * CLIENT_NUM * THREAD_NUM / t_diff);
+            std::fflush(stdout);
+            printf("================================================================================\n");
+            printf("[%.3f s] Test '%s/%s' done\n\n", get_time_diff(), ann_test_name_.c_str(), index_type_.c_str());
+        }
+    }
+
     void
     test_hnsw(const knowhere::Json& cfg) {
         auto conf = cfg;
@@ -120,7 +164,7 @@ class Benchmark_knowhere_float_qps : public Benchmark_knowhere, public ::testing
         std::vector<std::thread> thread_vector(worker_num);
         for (int32_t i = 0; i < worker_num; i++) {
-            thread_vector[i] = std::thread(worker, i, nq);
+            thread_vector[i] = std::thread(worker, i % CLIENT_NUM, nq);
         }
         for (int32_t i = 0; i < worker_num; i++) {
             thread_vector[i].join();
@@ -140,7 +184,7 @@ class Benchmark_knowhere_float_qps : public Benchmark_knowhere, public ::testing
         cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
         knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AUTO);
 #ifdef USE_CUDA
-        knowhere::KnowhereConfig::InitGPUResource(GPU_DEVICE_ID, CLIENT_NUM);
+        // knowhere::KnowhereConfig::InitGPUResource(GPU_DEVICE_ID, CLIENT_NUM);
         cfg_[knowhere::meta::DEVICE_ID] = GPU_DEVICE_ID;
 #endif
     }
@@ -227,3 +271,35 @@ TEST_F(Benchmark_knowhere_float_qps, TEST_HNSW) {
     binary_set_.clear();
     test_hnsw(conf);
 }
+
+TEST_F(Benchmark_knowhere_float_qps, TEST_RAFT_IVF_FLAT) {
+    index_type_ = knowhere::IndexEnum::INDEX_RAFT_IVFFLAT;
+
+    knowhere::Json conf = cfg_;
+    conf[knowhere::indexparam::NLIST] = NLIST_;
+
+    std::string index_file_name = get_index_name({NLIST_});
+
+    for (int i = 0; i < CLIENT_NUM; i++) {
+        indices_.emplace_back(create_index(index_file_name, conf));
+        indices_.back().Deserialize(binary_set_);
+    }
+    binary_set_.clear();
+    test_raft_ivf(conf);
+}
+
+TEST_F(Benchmark_knowhere_float_qps, TEST_RAFT_IVF_PQ) {
+    index_type_ = knowhere::IndexEnum::INDEX_RAFT_IVFPQ;
+
+    knowhere::Json conf = cfg_;
+    conf[knowhere::indexparam::NLIST] = NLIST_;
+
+    std::string index_file_name = get_index_name({NLIST_});
+
+    for (int i = 0; i < CLIENT_NUM; i++) {
+        indices_.emplace_back(create_index(index_file_name, conf));
+        indices_.back().Deserialize(binary_set_);
+    }
+    binary_set_.clear();
+    test_raft_ivf(conf);
+}

diff --git a/benchmark/prepare.sh b/benchmark/prepare.sh
index 69f6212ff..c79f9c6fe 100755
--- a/benchmark/prepare.sh
+++ b/benchmark/prepare.sh
@@ -3,19 +3,6 @@
 SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
 ROOT="$(dirname "$(dirname "$SCRIPTPATH")")"
 
-HDF5_DIR=$ROOT"/hdf5-hdf5-1_13_2"
-wget https://github.com/HDFGroup/hdf5/archive/refs/tags/hdf5-1_13_2.tar.gz
-tar xvfz hdf5-1_13_2.tar.gz
-rm hdf5-1_13_2.tar.gz
-cd $HDF5_DIR
-./configure --prefix=/usr/local/hdf5 --enable-fortran
-make -j8
-make install
-cd $ROOT
-rm -r hdf5-hdf5-1_13_2
-
-./build.sh -u -t Release -b
-
 SIFT_FILE=$ROOT"/output/unittest/sift-128-euclidean.hdf5"
 wget -P $SIFT_FILE http://ann-benchmarks.com/sift-128-euclidean.hdf5

diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h
index d52bcb796..b65bf2b06 100644
--- a/include/knowhere/comp/index_param.h
+++ b/include/knowhere/comp/index_param.h
@@ -38,6 +38,8 @@
 constexpr const char* INDEX_ANNOY = "ANNOY";
 constexpr const char* INDEX_HNSW = "HNSW";
 constexpr const char* INDEX_DISKANN = "DISKANN";
+constexpr const char* INDEX_RAFT_IVFFLAT = "RAFT_IVF_FLAT";
+constexpr const char* INDEX_RAFT_IVFPQ = "RAFT_IVF_PQ";
 
 }  // namespace IndexEnum

diff --git a/include/knowhere/gpu/gpu_res_mgr.h b/include/knowhere/gpu/gpu_res_mgr.h
index 155167bfb..db5363d87 100644
--- a/include/knowhere/gpu/gpu_res_mgr.h
+++ b/include/knowhere/gpu/gpu_res_mgr.h
@@ -82,18 +82,6 @@ class GPUResMgr {
         LOG_KNOWHERE_DEBUG_ << "InitDevice gpu_id " << gpu_id_ << ", resource count " << gpu_params_.res_num_
                             << ", tmp_mem_sz " << gpu_params_.tmp_mem_sz_ / MB << "MB, pin_mem_sz "
                             << gpu_params_.pin_mem_sz_ / MB << "MB";
-#ifdef KNOWHERE_WITH_RAFT
-        if (gpu_id >= std::numeric_limits<int>::min() && gpu_id <= std::numeric_limits<int>::max()) {
-            auto rmm_id = rmm::cuda_device_id{int(gpu_id)};
-            rmm_memory_resources_.push_back(
-                std::make_unique<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>(
-                    rmm::mr::get_per_device_resource(rmm_id)));
-            rmm::mr::set_per_device_resource(rmm_id, rmm_memory_resources_.back().get());
-        } else {
-            LOG_KNOWHERE_WARNING_ << "Could not init pool memory resource on GPU " << gpu_id_
-                                  << ". ID is outside expected range.";
-        }
-#endif
     }
 
     void
@@ -125,11 +113,6 @@ class GPUResMgr {
             res_bq_.Take();
         }
         init_ = false;
-#ifdef KNOWHERE_WITH_RAFT
-        for (auto&& rmm_res : rmm_memory_resources_) {
-            rmm_res.release();
-        }
-#endif
     }
 
     ResPtr
@@ -156,9 +139,6 @@ class GPUResMgr {
     int64_t gpu_id_ = 0;
     GPUParams gpu_params_;
    ResBQ res_bq_;
-#ifdef KNOWHERE_WITH_RAFT
-    std::vector<std::unique_ptr<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>> rmm_memory_resources_;
-#endif
 };
 
 class ResScope {

diff --git a/src/index/ivf_raft/ivf_raft.cuh b/src/index/ivf_raft/ivf_raft.cuh
index 9c9f4703e..fc62bfc62 100644
--- a/src/index/ivf_raft/ivf_raft.cuh
+++ b/src/index/ivf_raft/ivf_raft.cuh
@@ -37,6 +37,58 @@
 
 namespace knowhere {
 
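+// Process-wide pool of RAFT resources: one context per GPU id, each holding a fixed-size
+// set of raft::device_resources (each on its own CUDA stream) handed out round-robin,
+// backed by a pooled device memory resource layered over the default CUDA resource.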
+namespace raft_res_pool {
+
+struct context {
+    static constexpr std::size_t default_size{16};
+    context() : stream_(default_size), up_mr_(), mr_(&up_mr_) {
+        rmm::mr::set_current_device_resource(&mr_);
+        for (size_t i = 0; i < default_size; ++i) {
+            resources_.emplace_back(stream_[i], nullptr, &mr_);
+        }
+    }
+    ~context() = default;
+
+    context(context&&) = delete;
+    context(context const&) = delete;
+    context&
+    operator=(context&&) = delete;
+    context&
+    operator=(context const&) = delete;
+    std::vector<rmm::cuda_stream> stream_;
+    rmm::mr::cuda_memory_resource up_mr_;
+    rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> mr_;
+    std::vector<raft::device_resources> resources_;
+    mutable std::atomic_size_t next_{};
+};
+
+class resource {
+ public:
+    static resource&
+    instance() {
+        static resource res;
+        return res;
+    }
+    raft::device_resources*
+    get_raft_res(rmm::cuda_device_id::value_type device_id) {
+        std::lock_guard<std::mutex> lock(mtx_);
+        auto it = map_.find(device_id);
+        if (it == map_.end()) {
+            map_[device_id] = std::make_unique<context>();
+            it = map_.find(device_id);
+        }
+        it->second->next_++;
+        return &(it->second->resources_[it->second->next_ % context::default_size]);
+    }
+
+ private:
+    resource(){};
+    std::map<rmm::cuda_device_id::value_type, std::unique_ptr<context>> map_;
+    mutable std::mutex mtx_;
+};
+
+};  // namespace raft_res_pool
+
 namespace detail {
 using raft_ivf_flat_index = raft::neighbors::ivf_flat::index;
 using raft_ivf_pq_index = raft::neighbors::ivf_pq::index;
@@ -100,14 +152,16 @@ auto static constexpr const CUDA_R_8F_E5M2 = "CUDA_R_8F_E5M2";
 
 inline expected str_to_cuda_dtype(std::string const& str) {
     static const std::unordered_map<std::string, cudaDataType_t> name_map = {
-        {cuda_type::CUDA_R_16F, CUDA_R_16F},         {cuda_type::CUDA_C_16F, CUDA_C_16F},
-        {cuda_type::CUDA_R_16BF, CUDA_R_16BF},       {cuda_type::CUDA_C_16BF, CUDA_C_16BF},
-        {cuda_type::CUDA_R_32F, CUDA_R_32F},         {cuda_type::CUDA_C_32F, CUDA_C_32F},
-        {cuda_type::CUDA_R_64F, CUDA_R_64F},         {cuda_type::CUDA_C_64F, CUDA_C_64F},
-        {cuda_type::CUDA_R_8I, CUDA_R_8I},           {cuda_type::CUDA_C_8I, CUDA_C_8I},
-        {cuda_type::CUDA_R_8U, CUDA_R_8U},           {cuda_type::CUDA_C_8U, CUDA_C_8U},
-        {cuda_type::CUDA_R_32I, CUDA_R_32I},         {cuda_type::CUDA_C_32I, CUDA_C_32I},
-        {cuda_type::CUDA_R_8F_E4M3, CUDA_R_8F_E4M3}, {cuda_type::CUDA_R_8F_E5M2, CUDA_R_8F_E5M2},
+        {cuda_type::CUDA_R_16F, CUDA_R_16F},          {cuda_type::CUDA_C_16F, CUDA_C_16F},
+        {cuda_type::CUDA_R_16BF, CUDA_R_16BF},        {cuda_type::CUDA_C_16BF, CUDA_C_16BF},
+        {cuda_type::CUDA_R_32F, CUDA_R_32F},          {cuda_type::CUDA_C_32F, CUDA_C_32F},
+        {cuda_type::CUDA_R_64F, CUDA_R_64F},          {cuda_type::CUDA_C_64F, CUDA_C_64F},
+        {cuda_type::CUDA_R_8I, CUDA_R_8I},            {cuda_type::CUDA_C_8I, CUDA_C_8I},
+        {cuda_type::CUDA_R_8U, CUDA_R_8U},            {cuda_type::CUDA_C_8U, CUDA_C_8U},
+        {cuda_type::CUDA_R_32I, CUDA_R_32I},          {cuda_type::CUDA_C_32I, CUDA_C_32I},
+        // not supported when building with CUDA 11.6
+        //{cuda_type::CUDA_R_8F_E4M3, CUDA_R_8F_E4M3}, {cuda_type::CUDA_R_8F_E5M2, CUDA_R_8F_E5M2},
+
     };
 
     auto it = name_map.find(str);
@@ -133,7 +187,7 @@ struct KnowhereConfigType {
 
 template <typename T>
 class RaftIvfIndexNode : public IndexNode {
  public:
-    RaftIvfIndexNode(const Object& object) : devs_{}, res_{std::make_unique<raft::device_resources>()}, gpu_index_{} {
+    RaftIvfIndexNode(const Object& object) : devs_{}, gpu_index_{} {
     }
 
     virtual Status
@@ -158,15 +212,14 @@ class RaftIvfIndexNode : public IndexNode {
             return metric.error();
         }
         if (metric.value() != raft::distance::DistanceType::L2Expanded &&
-            metric.value() != raft::distance::DistanceType::L2Unexpanded &&
             metric.value() != raft::distance::DistanceType::InnerProduct) {
             LOG_KNOWHERE_WARNING_ << "selected metric not supported in RAFT IVF indexes: " << ivf_raft_cfg.metric_type;
             return Status::invalid_metric_type;
         }
-
+        devs_.insert(devs_.begin(), ivf_raft_cfg.gpu_ids.begin(), ivf_raft_cfg.gpu_ids.end());
         auto scoped_device = detail::device_setter{*ivf_raft_cfg.gpu_ids.begin()};
-        res_ = std::make_unique<raft::device_resources>();
+        auto res_ = raft_res_pool::resource::instance().get_raft_res(devs_[0]);
         auto rows = dataset.GetRows();
         auto dim = dataset.GetDim();
         auto* data = reinterpret_cast<float const*>(dataset.GetTensor());
@@ -204,6 +257,10 @@ class RaftIvfIndexNode : public IndexNode {
         } else {
             static_assert(std::is_same_v);
         }
+        dim_ = dim;
+        counts_ = rows;
+        stream.synchronize();
+
     } catch (std::exception& e) {
         LOG_KNOWHERE_WARNING_ << "RAFT inner error, " << e.what();
         return Status::raft_inner_error;
@@ -225,6 +282,9 @@ class RaftIvfIndexNode : public IndexNode {
         auto rows = dataset.GetRows();
         auto dim = dataset.GetDim();
         auto* data = reinterpret_cast<float const*>(dataset.GetTensor());
+        auto scoped_device = detail::device_setter{devs_[0]};
+
+        auto res_ = raft_res_pool::resource::instance().get_raft_res(devs_[0]);
         auto stream = res_->get_stream();
         // TODO(wphicks): Clean up transfer with raft
@@ -245,6 +305,8 @@
         } else {
             static_assert(std::is_same_v);
         }
+        dim_ = dim;
+        counts_ = rows;
     } catch (std::exception& e) {
         LOG_KNOWHERE_WARNING_ << "RAFT inner error, " << e.what();
         result = Status::raft_inner_error;
@@ -265,6 +327,8 @@
         auto dis = std::unique_ptr<float[]>(new float[output_size]);
         try {
+            auto scoped_device = detail::device_setter{devs_[0]};
+            auto res_ = raft_res_pool::resource::instance().get_raft_res(devs_[0]);
             auto stream = res_->get_stream();
             // TODO(wphicks): Clean up transfer with raft
             // buffer objects when available
@@ -344,12 +408,149 @@
 
     virtual Status
     Serialize(BinarySet& binset) const override {
-        return Status::not_implemented;
+        if (!gpu_index_.has_value())
+            return Status::empty_index;
+        std::stringbuf buf;
+
+        std::ostream os(&buf);
+
+        os.write((char*)(&this->dim_), sizeof(this->dim_));
+        os.write((char*)(&this->counts_), sizeof(this->counts_));
+        os.write((char*)(&this->devs_[0]), sizeof(this->devs_[0]));
+
+        auto scoped_device = detail::device_setter{devs_[0]};
+
+        auto res_ = raft_res_pool::resource::instance().get_raft_res(devs_[0]);
+
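+        // Serialize the index field by field; the write order below must stay in sync
+        // with the reads in Deserialize().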
+        if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
+            raft::serialize_scalar(*res_, os, gpu_index_->size());
+            raft::serialize_scalar(*res_, os, gpu_index_->dim());
+            raft::serialize_scalar(*res_, os, gpu_index_->n_lists());
+            raft::serialize_scalar(*res_, os, gpu_index_->metric());
+            raft::serialize_scalar(*res_, os, gpu_index_->veclen());
+            raft::serialize_scalar(*res_, os, gpu_index_->adaptive_centers());
+            raft::serialize_mdspan(*res_, os, gpu_index_->data());
+            raft::serialize_mdspan(*res_, os, gpu_index_->indices());
+            raft::serialize_mdspan(*res_, os, gpu_index_->list_sizes());
+            raft::serialize_mdspan(*res_, os, gpu_index_->list_offsets());
+            raft::serialize_mdspan(*res_, os, gpu_index_->centers());
+            if (gpu_index_->center_norms()) {
+                bool has_norms = true;
+                serialize_scalar(*res_, os, has_norms);
+                serialize_mdspan(*res_, os, *gpu_index_->center_norms());
+            } else {
+                bool has_norms = false;
+                serialize_scalar(*res_, os, has_norms);
+            }
+        }
+        if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
+            raft::serialize_scalar(*res_, os, gpu_index_->size());
+            raft::serialize_scalar(*res_, os, gpu_index_->dim());
+            raft::serialize_scalar(*res_, os, gpu_index_->pq_bits());
+            raft::serialize_scalar(*res_, os, gpu_index_->pq_dim());
+
+            raft::serialize_scalar(*res_, os, gpu_index_->metric());
+            raft::serialize_scalar(*res_, os, gpu_index_->codebook_kind());
+            raft::serialize_scalar(*res_, os, gpu_index_->n_lists());
+            raft::serialize_scalar(*res_, os, gpu_index_->n_nonempty_lists());
+
+            raft::serialize_mdspan(*res_, os, gpu_index_->pq_centers());
+            raft::serialize_mdspan(*res_, os, gpu_index_->pq_dataset());
+            raft::serialize_mdspan(*res_, os, gpu_index_->indices());
+            raft::serialize_mdspan(*res_, os, gpu_index_->rotation_matrix());
+            raft::serialize_mdspan(*res_, os, gpu_index_->list_offsets());
+            raft::serialize_mdspan(*res_, os, gpu_index_->list_sizes());
+            raft::serialize_mdspan(*res_, os, gpu_index_->centers());
+            raft::serialize_mdspan(*res_, os, gpu_index_->centers_rot());
+        }
+
+        os.flush();
+        std::shared_ptr<uint8_t[]> index_binary(new (std::nothrow) uint8_t[buf.str().size()]);
+
+        memcpy(index_binary.get(), buf.str().c_str(), buf.str().size());
+        binset.Append(this->Type(), index_binary, buf.str().size());
+        return Status::success;
     }
 
     virtual Status
     Deserialize(const BinarySet& binset) override {
-        return Status::not_implemented;
+        std::stringbuf buf;
+        auto binary = binset.GetByName(this->Type());
+        buf.sputn((char*)binary->data.get(), binary->size);
+        std::istream is(&buf);
+
+        is.read((char*)(&this->dim_), sizeof(this->dim_));
+        is.read((char*)(&this->counts_), sizeof(this->counts_));
+        this->devs_.resize(1);
+        is.read((char*)(&this->devs_[0]), sizeof(this->devs_[0]));
+        auto scoped_device = detail::device_setter{devs_[0]};
+
+        auto res_ = raft_res_pool::resource::instance().get_raft_res(devs_[0]);
+
+        if constexpr (std::is_same_v<detail::raft_ivf_flat_index, T>) {
+            auto n_rows = raft::deserialize_scalar(*res_, is);
+            auto dim = raft::deserialize_scalar(*res_, is);
+            auto n_lists = raft::deserialize_scalar(*res_, is);
+            auto metric = raft::deserialize_scalar(*res_, is);
+            auto veclen = raft::deserialize_scalar(*res_, is);
+            bool adaptive_centers = raft::deserialize_scalar(*res_, is);
+
+            T index_ = T(*res_, metric, n_lists, adaptive_centers, dim);
+
+            index_.allocate(*res_, n_rows);
+            raft::deserialize_mdspan(*res_, is, index_.data());
+            raft::deserialize_mdspan(*res_, is, index_.indices());
+            raft::deserialize_mdspan(*res_, is, index_.list_sizes());
+            raft::deserialize_mdspan(*res_, is, index_.list_offsets());
+            raft::deserialize_mdspan(*res_, is, index_.centers());
+            bool has_norms = raft::deserialize_scalar(*res_, is);
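+            // Center norms are optional: Serialize() writes a bool flag ahead of them,
+            // so they are only read back when the flag is set.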
+            if (has_norms) {
+                if (!index_.center_norms()) {
+                    RAFT_FAIL("Error inconsistent center norms");
+                } else {
+                    auto center_norms = *index_.center_norms();
+                    raft::deserialize_mdspan(*res_, is, center_norms);
+                }
+            }
+            res_->sync_stream();
+            is.sync();
+            gpu_index_ = T(std::move(index_));
+        }
+        if constexpr (std::is_same_v<detail::raft_ivf_pq_index, T>) {
+            auto n_rows = raft::deserialize_scalar(*res_, is);
+            auto dim = raft::deserialize_scalar(*res_, is);
+            auto pq_bits = raft::deserialize_scalar(*res_, is);
+            auto pq_dim = raft::deserialize_scalar(*res_, is);
+
+            auto metric = raft::deserialize_scalar(*res_, is);
+            auto codebook_kind = raft::deserialize_scalar(*res_, is);
+            auto n_lists = raft::deserialize_scalar(*res_, is);
+            auto n_nonempty_lists = raft::deserialize_scalar(*res_, is);
+
+            T index_ = T(*res_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, n_nonempty_lists);
+            index_.allocate(*res_, n_rows);
+
+            raft::deserialize_mdspan(*res_, is, index_.pq_centers());
+            raft::deserialize_mdspan(*res_, is, index_.pq_dataset());
+            raft::deserialize_mdspan(*res_, is, index_.indices());
+            raft::deserialize_mdspan(*res_, is, index_.rotation_matrix());
+            raft::deserialize_mdspan(*res_, is, index_.list_offsets());
+            raft::deserialize_mdspan(*res_, is, index_.list_sizes());
+            raft::deserialize_mdspan(*res_, is, index_.centers());
+            raft::deserialize_mdspan(*res_, is, index_.centers_rot());
+            res_->sync_stream();
+            is.sync();
+            gpu_index_ = T(std::move(index_));
+        }
+        // TODO(yusheng.ma): support no raw data mode
+        /*
+#define RAW_DATA "RAW_DATA"
+    auto data = binset.GetByName(RAW_DATA);
+    raft_gpu::raw_data_copy(*this->index_, data->data.get(), data->size);
+    */
+        is.sync();
+
+        return Status::success;
     }
 
     virtual std::unique_ptr
@@ -359,11 +560,7 @@
 
     virtual int64_t
     Dim() const override {
-        auto result = std::int64_t{};
-        if (gpu_index_) {
-            result = gpu_index_->dim();
-        }
-        return result;
+        return dim_;
     }
 
     virtual int64_t
@@ -373,11 +570,7 @@
 
     virtual int64_t
     Count() const override {
-        auto result = std::int64_t{};
-        if (gpu_index_) {
-            result = gpu_index_->size();
-        }
-        return result;
+        return counts_;
     }
 
     virtual std::string
@@ -392,7 +585,8 @@
 
  private:
    std::vector<int32_t> devs_;
-    std::unique_ptr<raft::device_resources> res_;
+    int64_t dim_ = 0;
+    int64_t counts_ = 0;
    std::optional<T> gpu_index_;
 };
 }  // namespace knowhere