Skip to content

Commit

Permalink
No repeated IPC open (vllm-project#2642)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 authored Jan 29, 2024
1 parent 55bf954 commit 2d63ea9
Showing 1 changed file with 25 additions and 18 deletions.
43 changes: 25 additions & 18 deletions csrc/custom_all_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <array>
#include <cstring>
#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>

Expand Down Expand Up @@ -327,6 +328,10 @@ __global__ void __launch_bounds__(512, 1)
}
}

// Byte-wise image of a cudaIpcMemHandle_t, used as the key of the map from IPC
// handles to opened IPC pointers. std::array provides the ordering (operator<)
// that std::map requires of its key type.
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
// The raw handle bytes are reinterpreted as an IPC_KEY when looking up the
// map, so both types must have identical size and alignment.
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

class CustomAllreduce {
public:
int rank_;
Expand All @@ -341,7 +346,8 @@ class CustomAllreduce {
// stores the registered device pointers from all ranks
RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void *> graph_unreg_buffers_;
std::vector<void *> ipc_handles_;
// a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char *> ipc_handles_;

/**
* meta is a pointer to device metadata and temporary buffer for allreduce.
Expand All @@ -365,10 +371,7 @@ class CustomAllreduce {
for (int i = 0; i < world_size_; i++) {
Metadata *rank_meta;
if (i != rank_) {
char *handle;
CUDACHECK(cudaIpcOpenMemHandle((void **)&handle, handles[i],
cudaIpcMemLazyEnablePeerAccess));
ipc_handles_.push_back(handle);
char *handle = open_ipc_handle(&handles[i]);
handle += offsets[i];
rank_meta = (Metadata *)handle;
} else {
Expand All @@ -378,6 +381,19 @@ class CustomAllreduce {
}
}

char *open_ipc_handle(const void *ipc_handle) {
auto [it, new_handle] =
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
if (new_handle) {
char *ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
*((const cudaIpcMemHandle_t *)ipc_handle),
cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}

std::pair<std::vector<uint8_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta() {
auto num_buffers = graph_unreg_buffers_.size();
Expand Down Expand Up @@ -413,11 +429,7 @@ class CustomAllreduce {
RankData data;
for (int i = 0; i < world_size_; i++) {
if (i != rank_) {
char *handle;
CUDACHECK(cudaIpcOpenMemHandle(
(void **)&handle, *((const cudaIpcMemHandle_t *)handles[i].data()),
cudaIpcMemLazyEnablePeerAccess));
ipc_handles_.push_back(handle);
char *handle = open_ipc_handle(handles[i].data());
handle += offsets[i];
data.ptrs[i] = handle;
} else {
Expand Down Expand Up @@ -448,13 +460,8 @@ class CustomAllreduce {
auto &rd = rank_data[i];
for (int j = 0; j < world_size_; j++) {
if (j != rank_) {
char *handle;
CUDACHECK(cudaIpcOpenMemHandle(
(void **)&handle,
*((cudaIpcMemHandle_t *)&handles[j]
[i * sizeof(cudaIpcMemHandle_t)]),
cudaIpcMemLazyEnablePeerAccess));
ipc_handles_.push_back(handle);
char *handle =
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
Expand Down Expand Up @@ -541,7 +548,7 @@ class CustomAllreduce {
}

/**
 * Closes every IPC mapping recorded in ipc_handles_.
 *
 * ipc_handles_ is a map from IPC handle bytes to the pointer returned when the
 * handle was opened, so each distinct handle is closed exactly once.
 */
~CustomAllreduce() {
  for (const auto &[_, ptr] : ipc_handles_) {
    // An entry can still hold nullptr if opening its handle failed after the
    // key had been inserted; there is nothing to close in that case.
    if (ptr == nullptr) {
      continue;
    }
    CUDACHECK(cudaIpcCloseMemHandle(ptr));
  }
}
Expand Down

0 comments on commit 2d63ea9

Please sign in to comment.