Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Provide support for RAFT-based indexes (#712)
Browse files Browse the repository at this point in the history
* Add RAFT to thirdparty libraries

Signed-off-by: William Hicks <[email protected]>

* Add RAFT stubs

Signed-off-by: William Hicks <[email protected]>

* Update CMake config

Signed-off-by: William Hicks <[email protected]>

* Try different CMake configuration for RAFT

Signed-off-by: William Hicks <[email protected]>

* Remove RAFT and RMM submodules

Signed-off-by: William Hicks <[email protected]>

* Correct various RAFT invocation details

Signed-off-by: William Hicks <[email protected]>

* Correct build and metric options

Signed-off-by: William Hicks <[email protected]>

* Split compilation into two units for parallelism

Also correct indices passed during extend

Signed-off-by: William Hicks <[email protected]>

* Add error-handling for RAFT-internal errors

Signed-off-by: William Hicks <[email protected]>

* Correct scope for dataset variables

Signed-off-by: William Hicks <[email protected]>

* Reindent for consistency with existing code

Signed-off-by: William Hicks <[email protected]>

* Update CMakeList and adapt to changes in main

Signed-off-by: William Hicks <[email protected]>

* Remove dependencies provided transitively through RAFT

Signed-off-by: William Hicks <[email protected]>

* Ensure that build does not require RAFT

Signed-off-by: William Hicks <[email protected]>

* Use pool allocator for RAFT indexes

Signed-off-by: William Hicks <[email protected]>

* Pass through all RAFT parameters from config

Signed-off-by: William Hicks <[email protected]>

* Fix style

Signed-off-by: William Hicks <[email protected]>

---------

Signed-off-by: William Hicks <[email protected]>
  • Loading branch information
wphicks authored Mar 13, 2023
1 parent db738c2 commit 11f606c
Show file tree
Hide file tree
Showing 13 changed files with 875 additions and 2 deletions.
36 changes: 34 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,30 @@
# License for the specific language governing permissions and limitations under
# the License

cmake_minimum_required(VERSION 3.2)
cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR)
project(knowhere CXX C)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules/")
include(GNUInstallDirs)
include(ExternalProject)
include(cmake/utils/utils.cmake)

knowhere_option(USE_CUDA "Build with CUDA" OFF)
knowhere_option(WITH_UT "Build with UT test" OFF)
knowhere_option(WITH_ASAN "Build with ASAN" OFF)
knowhere_option(WITH_DISKANN "Build with diskann index" OFF)
knowhere_option(WITH_RAFT "Build with RAFT indexes" OFF)
knowhere_option(WITH_BENCHMARK "Build with benchmark" OFF)
knowhere_option(WITH_COVERAGE "Build with coverage" OFF)
knowhere_option(WITH_CCACHE "Build with ccache" ON)
knowhere_option(WITH_PROFILER "Build with profiler" OFF)

# WITH_RAFT requires USE_CUDA (see the warning text below). When RAFT is
# requested without CUDA, warn and set USE_CUDA to ON as a normal (non-cache)
# variable so the checks further down see CUDA as enabled.
if(WITH_RAFT AND NOT USE_CUDA)
message(WARNING "WITH_RAFT requires USE_CUDA. Setting USE_CUDA to ON.")
set(USE_CUDA ON)
endif()

if(KNOWHERE_VERSION)
message(STATUS "Building KNOWHERE version: ${KNOWHERE_VERSION}")
add_definitions(-DKNOWHERE_VERSION=${KNOWHERE_VERSION})
Expand Down Expand Up @@ -68,6 +75,23 @@ include_directories(thirdparty/bitset)
include_directories(thirdparty)

find_package(OpenMP REQUIRED)
# RAFT-only configuration: define the feature macro, bootstrap rapids-cmake,
# fetch RAFT, and extend the global CUDA/C++ compiler flags.
if(WITH_RAFT)
# Directory-scoped definition; C++ sources test KNOWHERE_WITH_RAFT with #ifdef.
add_definitions(-DKNOWHERE_WITH_RAFT)
# fetch_rapids.cmake downloads and includes the RAPIDS.cmake bootstrap script.
# NOTE(review): presumably that script is what makes the rapids-* modules
# below resolvable by include() - confirm.
include(cmake/libs/fetch_rapids.cmake)
include(rapids-cmake)
include(rapids-cpm)
include(rapids-cuda)
include(rapids-export)
include(rapids-find)

rapids_cpm_init()
# libraft.cmake derives RAFT_VERSION and the pinned branch from RAPIDS_VERSION,
# so it must be set before the include. fetch_rapids.cmake also sets
# RAPIDS_VERSION to "23.02"; keep the two values in sync.
set(RAPIDS_VERSION 23.02)
include(cmake/libs/libraft.cmake)

# These flags are appended to the global flag strings, so they apply to every
# CUDA / C++ target configured after this point.
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -std=c++17")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -std=c++17")
endif()

if(OPENMP_FOUND)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
Expand All @@ -80,7 +104,7 @@ if(WITH_COVERAGE)
endif()

knowhere_file_glob(GLOB_RECURSE KNOWHERE_SRCS src/common/*.cc src/index/*.cc
src/io/*.cc)
src/io/*.cc src/index/*.cu)

set(KNOWHERE_LINKER_LIBS "")

Expand All @@ -97,6 +121,11 @@ if(NOT USE_CUDA)
list(REMOVE_ITEM KNOWHERE_SRCS ${KNOWHERE_GPU_SRCS})
endif()

# Without RAFT support, collect every .cc/.cu file under src/index/index_raft/
# and remove those files from KNOWHERE_SRCS so they are not compiled.
# NOTE(review): knowhere_file_glob is defined elsewhere; presumably a wrapper
# around file(GLOB_RECURSE ...) - confirm.
if(NOT WITH_RAFT)
knowhere_file_glob(GLOB_RECURSE KNOWHERE_RAFT_SRCS src/index/index_raft/*.cc src/index/index_raft/*.cu)
list(REMOVE_ITEM KNOWHERE_SRCS ${KNOWHERE_RAFT_SRCS})
endif()

include_directories(src)
include_directories(include)

Expand All @@ -105,6 +134,9 @@ list(APPEND KNOWHERE_LINKER_LIBS easyloggingpp)

add_library(knowhere SHARED ${KNOWHERE_SRCS})
add_dependencies(knowhere ${KNOWHERE_LINKER_LIBS})
# Add the RAFT targets to the link list only when RAFT support is enabled.
# This append sits after the add_dependencies() call on KNOWHERE_LINKER_LIBS
# and before target_link_libraries(), so the RAFT targets are linked but are
# not passed to add_dependencies().
if(WITH_RAFT)
list(APPEND KNOWHERE_LINKER_LIBS raft::raft raft::distance raft::nn)
endif()
target_link_libraries(knowhere PUBLIC ${KNOWHERE_LINKER_LIBS})
target_include_directories(
knowhere PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include
Expand Down
20 changes: 20 additions & 0 deletions cmake/libs/fetch_rapids.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# =============================================================================
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

# RAPIDS release whose rapids-cmake bootstrap script (RAPIDS.cmake) is fetched.
set(RAPIDS_VERSION "23.02")

# Location of the downloaded bootstrap script inside the build tree. The file
# name is kept unchanged so existing build trees keep using their cached copy.
set(KNOWHERE_RAPIDS_BOOTSTRAP "${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake")

# Download the bootstrap script once per build tree.
# file(DOWNLOAD) reports failures through STATUS; if the status is not checked,
# a failed transfer can leave an empty file behind, which would satisfy the
# EXISTS guard on later configure runs and be included silently. On failure,
# remove the partial file and stop with a clear error instead.
if(NOT EXISTS "${KNOWHERE_RAPIDS_BOOTSTRAP}")
  set(KNOWHERE_RAPIDS_BOOTSTRAP_URL
      "https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake")
  file(DOWNLOAD "${KNOWHERE_RAPIDS_BOOTSTRAP_URL}" "${KNOWHERE_RAPIDS_BOOTSTRAP}"
       STATUS KNOWHERE_RAPIDS_DOWNLOAD_STATUS)
  # STATUS is a two-element list: numeric code (0 == success) and message.
  list(GET KNOWHERE_RAPIDS_DOWNLOAD_STATUS 0 KNOWHERE_RAPIDS_DOWNLOAD_CODE)
  if(NOT KNOWHERE_RAPIDS_DOWNLOAD_CODE EQUAL 0)
    list(GET KNOWHERE_RAPIDS_DOWNLOAD_STATUS 1 KNOWHERE_RAPIDS_DOWNLOAD_MESSAGE)
    file(REMOVE "${KNOWHERE_RAPIDS_BOOTSTRAP}")
    message(FATAL_ERROR
            "Failed to download ${KNOWHERE_RAPIDS_BOOTSTRAP_URL}: "
            "${KNOWHERE_RAPIDS_DOWNLOAD_MESSAGE}")
  endif()
endif()
include("${KNOWHERE_RAPIDS_BOOTSTRAP}")
103 changes: 103 additions & 0 deletions cmake/libs/libcutlass.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# =============================================================================
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
# =============================================================================

# find_and_configure_cutlass(VERSION <ver> REPOSITORY <git-url> PINNED_TAG <git-ref>)
#
# Locates NVIDIA CUTLASS through rapids_cpm_find(), fetching it from
# REPOSITORY at PINNED_TAG when it is not already available, and registers the
# package with the raft-distance-exports and raft-nn-exports export sets.
#
#   VERSION     version passed to rapids_cpm_find()
#   REPOSITORY  git repository to clone CUTLASS from
#   PINNED_TAG  git tag / branch / commit to check out
function(find_and_configure_cutlass)
  # Declare all three keyword lists locally. A function inherits the caller's
  # variables, so leaving `options` / `multiValueArgs` unset here would let
  # values from an enclosing scope change how the arguments are parsed.
  set(options "")
  set(oneValueArgs VERSION REPOSITORY PINNED_TAG)
  set(multiValueArgs "")
  cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

  # CUTLASS build configuration, stored as cache entries so the CUTLASS
  # sub-build picks them up.
  set(CUTLASS_ENABLE_HEADERS_ONLY
      ON
      CACHE BOOL "Enable only the header library"
  )
  set(CUTLASS_NAMESPACE
      "raft_cutlass"
      CACHE STRING "Top level namespace of CUTLASS"
  )
  set(CUTLASS_ENABLE_CUBLAS
      OFF
      CACHE BOOL "Disable CUTLASS to build with cuBLAS library."
  )

  # When the static CUDA runtime is requested, force CUDART_LIBRARY to the
  # static cudart library.
  if(CUDA_STATIC_RUNTIME)
    set(CUDART_LIBRARY "${CUDA_cudart_static_LIBRARY}" CACHE FILEPATH "fixing cutlass cmake code" FORCE)
  endif()

  rapids_cpm_find(
    NvidiaCutlass ${PKG_VERSION}
    GLOBAL_TARGETS nvidia::cutlass::cutlass
    CPM_ARGS
    GIT_REPOSITORY ${PKG_REPOSITORY}
    GIT_TAG ${PKG_PINNED_TAG}
    GIT_SHALLOW TRUE
    OPTIONS "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}"
  )

  # Provide the namespaced name when only the plain CUTLASS target exists.
  if(TARGET CUTLASS AND NOT TARGET nvidia::cutlass::cutlass)
    add_library(nvidia::cutlass::cutlass ALIAS CUTLASS)
  endif()

  # Only export the package ourselves when CPM added (built) it locally.
  if(NvidiaCutlass_ADDED)
    rapids_export(
      BUILD NvidiaCutlass
      EXPORT_SET NvidiaCutlass
      GLOBAL_TARGETS nvidia::cutlass::cutlass
      NAMESPACE nvidia::cutlass::
    )
  endif()

  # We generate the cutlass-config files when we built cutlass locally, so always do
  # `find_dependency`
  rapids_export_package(
    BUILD NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
  )
  rapids_export_package(
    INSTALL NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
  )
  rapids_export_package(
    BUILD NvidiaCutlass raft-nn-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
  )
  rapids_export_package(
    INSTALL NvidiaCutlass raft-nn-exports GLOBAL_TARGETS nvidia::cutlass::cutlass
  )

  # Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote.
  # The helper module only needs to be included once for all four calls below.
  include("${rapids-cmake-dir}/export/find_package_root.cmake")
  rapids_export_find_package_root(
    INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports
  )
  rapids_export_find_package_root(
    BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports
  )
  rapids_export_find_package_root(
    INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-nn-exports
  )
  rapids_export_find_package_root(
    BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-nn-exports
  )
endfunction()

# Default CUTLASS source. Either variable may be set before this file is
# included to point at a different repository or git ref.
if(NOT RAFT_CUTLASS_GIT_REPOSITORY)
  set(RAFT_CUTLASS_GIT_REPOSITORY "https://github.com/NVIDIA/cutlass.git")
endif()

if(NOT RAFT_CUTLASS_GIT_TAG)
  set(RAFT_CUTLASS_GIT_TAG "v2.9.1")
endif()

# Fetch/configure CUTLASS using the repository and tag resolved above.
find_and_configure_cutlass(
  VERSION 2.9.1
  REPOSITORY "${RAFT_CUTLASS_GIT_REPOSITORY}"
  PINNED_TAG "${RAFT_CUTLASS_GIT_TAG}"
)
51 changes: 51 additions & 0 deletions cmake/libs/libraft.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# =============================================================================
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

# RAFT is pinned to the RAPIDS release named by RAPIDS_VERSION, which must be
# set before this file is included.
set(RAFT_VERSION "${RAPIDS_VERSION}")
# GitHub account to clone RAFT from; it is passed as FORK to
# find_and_configure_raft() and becomes part of the repository URL.
set(RAFT_FORK "rapidsai")
# Git ref to check out: the release branch of that RAPIDS version.
set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}")

# find_and_configure_raft(VERSION <ver> FORK <github-account> PINNED_TAG <git-ref>)
#
# Locates RAFT through rapids_cpm_find(), requesting the "distance" and "nn"
# components, and fetches it from https://github.com/<FORK>/raft.git at
# PINNED_TAG when it is not already available.
#
#   VERSION     version passed to rapids_cpm_find()
#   FORK        GitHub account that hosts the raft repository
#   PINNED_TAG  git tag / branch / commit to check out
function(find_and_configure_raft)
  # Declare all three keyword lists locally. A function inherits the caller's
  # variables, so leaving `options` / `multiValueArgs` unset here would let
  # values from an enclosing scope change how the arguments are parsed.
  set(options "")
  set(oneValueArgs VERSION FORK PINNED_TAG)
  set(multiValueArgs "")
  cmake_parse_arguments(PKG "${options}" "${oneValueArgs}"
                        "${multiValueArgs}" ${ARGN})

  #-----------------------------------------------------
  # Invoke CPM find_package()
  #-----------------------------------------------------
  # NOTE(review): the export-set name "faiss-exports" looks carried over from
  # another project - confirm it is the intended export set for knowhere.
  rapids_cpm_find(raft ${PKG_VERSION}
    GLOBAL_TARGETS raft::raft
    BUILD_EXPORT_SET faiss-exports
    INSTALL_EXPORT_SET faiss-exports
    COMPONENTS "distance nn"
    CPM_ARGS
      GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git
      GIT_TAG ${PKG_PINNED_TAG}
      SOURCE_SUBDIR cpp
      OPTIONS
        "BUILD_TESTS OFF"
        "BUILD_BENCH OFF"
        "RAFT_COMPILE_LIBRARIES OFF"
        "RAFT_COMPILE_NN_LIBRARY OFF"
        "RAFT_USE_FAISS_STATIC OFF" # Turn this on to build FAISS into your binary
        "RAFT_ENABLE_NN_DEPENDENCIES OFF"
  )
endfunction()

# Change pinned tag here to test a commit in CI
# To use a different RAFT locally, set the CMake variable
# CPM_raft_SOURCE=/path/to/local/raft
# The requested version is RAFT_VERSION with a ".00" component appended.
find_and_configure_raft(VERSION ${RAFT_VERSION}.00
FORK ${RAFT_FORK}
PINNED_TAG ${RAFT_PINNED_TAG}
)
23 changes: 23 additions & 0 deletions cmake/libs/librmm.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#=============================================================================
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#=============================================================================

# find_and_configure_rmm()
#
# Loads the RMM recipe shipped in rapids-cmake's cpm/ directory and invokes it,
# attaching RMM to the raft-exports set for both the build and install trees.
# Takes no arguments.
function(find_and_configure_rmm)
  include("${rapids-cmake-dir}/cpm/rmm.cmake")
  rapids_cpm_rmm(
    BUILD_EXPORT_SET raft-exports
    INSTALL_EXPORT_SET raft-exports
  )
endfunction()

# Run the lookup as soon as this file is included.
find_and_configure_rmm()
1 change: 1 addition & 0 deletions include/knowhere/expected.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ enum class Status {
diskann_file_error = 15,
invalid_value_in_json = 16,
arithmetic_overflow = 17,
raft_inner_error = 18,
};

template <typename E>
Expand Down
27 changes: 27 additions & 0 deletions include/knowhere/gpu/gpu_res_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
#pragma once

#include <faiss/gpu/StandardGpuResources.h>
#ifdef KNOWHERE_WITH_RAFT
#include <rmm/cuda_device.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <vector>
#endif

#include <memory>
#include <mutex>
Expand Down Expand Up @@ -75,6 +82,18 @@ class GPUResMgr {
LOG_KNOWHERE_DEBUG_ << "InitDevice gpu_id " << gpu_id_ << ", resource count " << gpu_params_.res_num_
<< ", tmp_mem_sz " << gpu_params_.tmp_mem_sz_ / MB << "MB, pin_mem_sz "
<< gpu_params_.pin_mem_sz_ / MB << "MB";
#ifdef KNOWHERE_WITH_RAFT
if (gpu_id >= std::numeric_limits<int>::min() && gpu_id <= std::numeric_limits<int>::max()) {
auto rmm_id = rmm::cuda_device_id{int(gpu_id)};
rmm_memory_resources_.push_back(
std::make_unique<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>(
rmm::mr::get_per_device_resource(rmm_id)));
rmm::mr::set_per_device_resource(rmm_id, rmm_memory_resources_.back().get());
} else {
LOG_KNOWHERE_WARNING_ << "Could not init pool memory resource on GPU " << gpu_id_
<< ". ID is outside expected range.";
}
#endif
}

void
Expand Down Expand Up @@ -106,6 +125,11 @@ class GPUResMgr {
res_bq_.Take();
}
init_ = false;
#ifdef KNOWHERE_WITH_RAFT
for (auto&& rmm_res : rmm_memory_resources_) {
rmm_res.release();
}
#endif
}

ResPtr
Expand All @@ -132,6 +156,9 @@ class GPUResMgr {
int64_t gpu_id_ = 0;
GPUParams gpu_params_;
ResBQ res_bq_;
#ifdef KNOWHERE_WITH_RAFT
std::vector<std::unique_ptr<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>> rmm_memory_resources_;
#endif
};

class ResScope {
Expand Down
48 changes: 48 additions & 0 deletions src/common/raft_metric.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef COMMON_RAFT_METRIC_H
#define COMMON_RAFT_METRIC_H

#include <algorithm>
#include <cctype>
#include <string>
#include <unordered_map>

#include "knowhere/comp/index_param.h"
#include "knowhere/expected.h"
#include "raft/distance/distance_types.hpp"

namespace knowhere {

// Translate a knowhere metric name into the corresponding RAFT distance type.
//
// The name is upper-cased before the lookup, so matching does not depend on
// the caller's letter case (assuming the metric::* keys are themselves
// upper-case - TODO confirm against knowhere/comp/index_param.h).
//
// @param metric  metric name; taken by value because it is upper-cased in place
// @return the mapped raft::distance::DistanceType, or
//         unexpected(Status::invalid_metric_type) when the name is not in the table
inline expected<raft::distance::DistanceType, Status>
Str2RaftMetricType(std::string metric) {
    static const std::unordered_map<std::string, raft::distance::DistanceType> metric_map = {
        {metric::L2, raft::distance::DistanceType::L2Expanded},
        {metric::IP, raft::distance::DistanceType::InnerProduct},
        {metric::HAMMING, raft::distance::DistanceType::HammingUnexpanded},
        {metric::JACCARD, raft::distance::DistanceType::JaccardExpanded},
    };

    // std::toupper requires its argument to be representable as unsigned char
    // (or EOF); passing a plain, possibly negative, char is undefined
    // behaviour. Going through unsigned char keeps non-ASCII bytes safe.
    std::transform(metric.begin(), metric.end(), metric.begin(),
                   [](unsigned char ch) { return static_cast<char>(std::toupper(ch)); });

    const auto it = metric_map.find(metric);
    if (it == metric_map.end()) {
        return unexpected(Status::invalid_metric_type);
    }
    return it->second;
}

} // namespace knowhere

#endif /* COMMON_RAFT_METRIC_H */
Loading

0 comments on commit 11f606c

Please sign in to comment.