Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ucx linkage explicit and add a new CMake target for it #1032

Merged
merged 9 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,21 @@ target_link_libraries(
raft_nn INTERFACE raft::raft $<TARGET_NAME_IF_EXISTS:raft::raft_nn_lib> nvidia::cutlass::cutlass
)

# ##################################################################################################
# * raft_distributed -------------------------------------------------------------------------------
# Interface-only target: it compiles nothing itself and only forwards the ucx::ucp link
# requirement to whatever links against it.
add_library(raft_distributed INTERFACE)
set_property(TARGET raft_distributed PROPERTY EXPORT_NAME distributed)
target_link_libraries(raft_distributed INTERFACE ucx::ucp)

# Namespaced alias for in-tree consumers; skipped when raft::distributed is already defined.
if(NOT TARGET raft::distributed)
  add_library(raft::distributed ALIAS raft_distributed)
endif()

# Associate the ucx package with the raft-distributed-exports export set for both the build
# tree and the install tree.
rapids_export_package(BUILD ucx raft-distributed-exports)
rapids_export_package(INSTALL ucx raft-distributed-exports)

# ##################################################################################################
# * install targets-----------------------------------------------------------
rapids_cmake_install_lib_dir(lib_dir)
Expand Down Expand Up @@ -518,6 +533,13 @@ if(TARGET raft_nn_lib)
)
endif()

# Add raft_distributed to the raft-distributed-exports export set.
# NOTE(review): this uses the `distributed` install component, while `raft_install_comp`
# lists `raft` for the distributed entry -- confirm the difference is intended.
install(
  TARGETS raft_distributed
  EXPORT raft-distributed-exports
  DESTINATION ${lib_dir}
  COMPONENT distributed
)

install(
DIRECTORY include/raft
COMPONENT raft
Expand All @@ -542,8 +564,8 @@ install(

include("${rapids-cmake-dir}/export/write_dependencies.cmake")

set(raft_components distance nn)
set(raft_install_comp raft raft)
set(raft_components distance nn distributed)
set(raft_install_comp raft raft raft)
if(TARGET raft_distance_lib)
list(APPEND raft_components distance-lib)
list(APPEND raft_install_comp distance)
Expand Down Expand Up @@ -588,11 +610,13 @@ for data science and machine learning.
Optional Components:
- nn
- distance
- distributed

Imported Targets:
- raft::raft
- raft::nn brought in by the `nn` optional component
- raft::distance brought in by the `distance` optional component
- raft::distributed brought in by the `distributed` optional component

]=]
)
Expand Down Expand Up @@ -634,15 +658,32 @@ endif()
# Use `rapids_export` for 22.04 as it will have COMPONENT support
include(cmake/modules/raft_export.cmake)
raft_export(
INSTALL raft COMPONENTS nn distance EXPORT_SET raft-exports GLOBAL_TARGETS raft nn distance
NAMESPACE raft:: DOCUMENTATION doc_string FINAL_CODE_BLOCK code_string
INSTALL raft COMPONENTS nn distance distributed EXPORT_SET raft-exports GLOBAL_TARGETS raft nn
distance distributed NAMESPACE raft:: DOCUMENTATION doc_string FINAL_CODE_BLOCK code_string
)

# ##################################################################################################
# * build export -------------------------------------------------------------
raft_export(
BUILD raft EXPORT_SET raft-exports COMPONENTS nn distance GLOBAL_TARGETS raft raft_distance
raft_nn DOCUMENTATION doc_string NAMESPACE raft:: FINAL_CODE_BLOCK code_string
BUILD
raft
EXPORT_SET
raft-exports
COMPONENTS
nn
distance
distributed
GLOBAL_TARGETS
raft
raft_distance
distributed
raft_nn
DOCUMENTATION
doc_string
NAMESPACE
raft::
FINAL_CODE_BLOCK
code_string
)

# ##################################################################################################
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/comms/detail/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ class std_comms : public comms_iface {
bool restart = false; // resets the timeout when any progress was made

// Causes UCP to progress through the send/recv message queue
while (ucp_handler_.ucp_progress(ucp_worker_) != 0) {
while (ucp_worker_progress(ucp_worker_) != 0) {
restart = true;
}

Expand Down
97 changes: 4 additions & 93 deletions cpp/include/raft/comms/detail/ucp_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#pragma once

#include <dlfcn.h>
#include <raft/util/cudart_utils.hpp>
#include <stdio.h>
#include <ucp/api/ucp.h>
Expand All @@ -26,23 +25,6 @@ namespace raft {
namespace comms {
namespace detail {

typedef void (*dlsym_print_info)(ucp_ep_h, FILE*);

typedef void (*dlsym_rec_free)(void*);

typedef int (*dlsym_worker_progress)(ucp_worker_h);

typedef ucs_status_ptr_t (*dlsym_send)(
ucp_ep_h, const void*, size_t, ucp_datatype_t, ucp_tag_t, ucp_send_callback_t);

typedef ucs_status_ptr_t (*dlsym_recv)(ucp_worker_h,
void*,
size_t count,
ucp_datatype_t datatype,
ucp_tag_t,
ucp_tag_t,
ucp_tag_recv_callback_t);

/**
* Standard UCX request object that will be passed
* around asynchronously. This object is really
Expand Down Expand Up @@ -90,96 +72,25 @@ static void recv_callback(void* request, ucs_status_t status, ucp_tag_recv_info_
}

/**
* Helper class for managing `dlopen` state and
* interacting with ucp.
* Helper class for interacting with ucp.
*/
class comms_ucp_handler {
public:
comms_ucp_handler()
{
load_ucp_handle();
load_send_func();
load_recv_func();
load_free_req_func();
load_print_info_func();
load_worker_progress_func();
}

~comms_ucp_handler() { dlclose(ucp_handle); }

private:
void* ucp_handle;

dlsym_print_info print_info_func;
dlsym_rec_free req_free_func;
dlsym_worker_progress worker_progress_func;
dlsym_send send_func;
dlsym_recv recv_func;

void load_ucp_handle()
{
ucp_handle = dlopen("libucp.so", RTLD_LAZY | RTLD_NOLOAD | RTLD_NODELETE);
if (!ucp_handle) {
ucp_handle = dlopen("libucp.so", RTLD_LAZY | RTLD_NODELETE);
ASSERT(ucp_handle, "Cannot open UCX library: %s\n", dlerror());
}
// Reset any potential error
dlerror();
}

void assert_dlerror()
{
char* error = dlerror();
ASSERT(error == NULL, "Error loading function symbol: %s\n", error);
}

void load_send_func()
{
send_func = (dlsym_send)dlsym(ucp_handle, "ucp_tag_send_nb");
assert_dlerror();
}

void load_free_req_func()
{
req_free_func = (dlsym_rec_free)dlsym(ucp_handle, "ucp_request_free");
assert_dlerror();
}

void load_print_info_func()
{
print_info_func = (dlsym_print_info)dlsym(ucp_handle, "ucp_ep_print_info");
assert_dlerror();
}

void load_worker_progress_func()
{
worker_progress_func = (dlsym_worker_progress)dlsym(ucp_handle, "ucp_worker_progress");
assert_dlerror();
}

void load_recv_func()
{
recv_func = (dlsym_recv)dlsym(ucp_handle, "ucp_tag_recv_nb");
assert_dlerror();
}

/**
 * @brief Combines a rank and a tag into a single UCP message tag.
 *
 * The rank is kept in the low bits so it can be read off directly when debugging.
 *
 * NOTE(review): the shift is applied to a 32-bit value, so with a shift of 31 only the
 * lowest bit of `tag` appears to survive (in bit 31) -- confirm this is intended.
 *
 * @param rank rank placed in the low bits of the result
 * @param tag  tag value, shifted left by 31 bits before being combined with the rank
 * @return the combined message tag
 */
ucp_tag_t build_message_tag(int rank, int tag) const
{
  const auto rank_bits = static_cast<uint32_t>(rank);
  const auto tag_bits  = static_cast<uint32_t>(tag) << 31;
  return tag_bits | rank_bits;
}

public:
int ucp_progress(ucp_worker_h worker) const { return (*(worker_progress_func))(worker); }

/**
* @brief Frees any memory underlying the given ucp request object
*/
void free_ucp_request(ucp_request* request) const
{
if (request->needs_release) {
request->req->completed = 0;
(*(req_free_func))(request->req);
ucp_request_free(request->req);
}
free(request);
}
Expand All @@ -198,7 +109,7 @@ class comms_ucp_handler {
ucp_tag_t ucp_tag = build_message_tag(rank, tag);

ucs_status_ptr_t send_result =
(*(send_func))(ep_ptr, buf, size, ucp_dt_make_contig(1), ucp_tag, send_callback);
ucp_tag_send_nb(ep_ptr, buf, size, ucp_dt_make_contig(1), ucp_tag, send_callback);
struct ucx_context* ucp_req = (struct ucx_context*)send_result;

if (UCS_PTR_IS_ERR(send_result)) {
Expand Down Expand Up @@ -240,7 +151,7 @@ class comms_ucp_handler {
ucp_tag_t ucp_tag = build_message_tag(sender_rank, tag);

ucs_status_ptr_t recv_result =
(*(recv_func))(worker, buf, size, ucp_dt_make_contig(1), ucp_tag, tag_mask, recv_callback);
ucp_tag_recv_nb(worker, buf, size, ucp_dt_make_contig(1), ucp_tag, tag_mask, recv_callback);

struct ucx_context* ucp_req = (struct ucx_context*)recv_result;

Expand Down
5 changes: 3 additions & 2 deletions python/raft-dask/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ option(FIND_RAFT_CPP "Search for existing RAFT C++ installations before defaulti

# If the user requested it we attempt to find RAFT.
if(FIND_RAFT_CPP)
find_package(raft ${raft_dask_version} REQUIRED)
find_package(raft ${raft_dask_version} REQUIRED COMPONENTS distributed)
else()
set(raft_FOUND OFF)
endif()
Expand All @@ -47,7 +47,8 @@ if(NOT raft_FOUND)
enable_language(CUDA)
# Since raft-dask only enables CUDA optionally we need to manually include the file that
# rapids_cuda_init_architectures relies on `project` including.
include("${CMAKE_PROJECT_raft_dask_INCLUDE}")
include("${CMAKE_PROJECT_raft-dask_INCLUDE}")
find_package(ucx REQUIRED)

# raft-dask doesn't actually use raft libraries, it just needs the headers, so we can turn off all
# library compilation and we don't need to install anything here.
Expand Down
3 changes: 1 addition & 2 deletions python/raft-dask/raft_dask/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
# =============================================================================

include(${raft-dask-python_SOURCE_DIR}/cmake/thirdparty/get_nccl.cmake)
find_package(ucx REQUIRED)

set(cython_sources comms_utils.pyx nccl.pyx)
set(linked_libraries raft::raft NCCL::NCCL ucx::ucp)
set(linked_libraries raft::raft raft::distributed NCCL::NCCL)
rapids_cython_create_modules(
SOURCE_FILES "${cython_sources}" LINKED_LIBRARIES "${linked_libraries}" CXX
)