Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]Support for specifying RDMA devices when multiple RDMA devices are present. #2006

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ if (${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
--disable-perf
--disable-efa
--disable-mrail
--with-cuda=no
--enable-verbs > /dev/null
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/thirdparty/libfabric
)
Expand All @@ -719,6 +720,7 @@ if (${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
--disable-perf
--disable-efa
--disable-mrail
--with-cuda=no
--disable-verbs > /dev/null
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/thirdparty/libfabric
)
Expand Down
51 changes: 34 additions & 17 deletions src/client/rpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,18 @@ Status RPCClient::Connect(const std::string& rpc_endpoint) {
Status RPCClient::Connect(const std::string& rpc_endpoint,
std::string const& username,
std::string const& password,
const std::string& rdma_endpoint) {
const std::string& rdma_endpoint,
std::string src_rdma_ednpoint) {
return this->Connect(rpc_endpoint, RootSessionID(), username, password,
rdma_endpoint);
rdma_endpoint, src_rdma_ednpoint);
}

Status RPCClient::Connect(const std::string& rpc_endpoint,
const SessionID session_id,
std::string const& username,
std::string const& password,
const std::string& rdma_endpoint) {
const std::string& rdma_endpoint,
std::string src_rdma_ednpoint) {
size_t pos = rpc_endpoint.find(":");
std::string host, port;
if (pos == std::string::npos) {
Expand All @@ -125,28 +127,32 @@ Status RPCClient::Connect(const std::string& rpc_endpoint,

return this->Connect(host, static_cast<uint32_t>(std::stoul(port)),
session_id, username, password, rdma_host,
static_cast<uint32_t>(std::stoul(rdma_port)));
static_cast<uint32_t>(std::stoul(rdma_port)),
src_rdma_ednpoint);
}

Status RPCClient::Connect(const std::string& host, uint32_t port,
const std::string& rdma_host, uint32_t rdma_port) {
const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_ednpoint) {
return this->Connect(host, port, RootSessionID(), "", "", rdma_host,
rdma_port);
rdma_port, src_rdma_ednpoint);
}

Status RPCClient::Connect(const std::string& host, uint32_t port,
std::string const& username,
std::string const& password,
const std::string& rdma_host, uint32_t rdma_port) {
const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_ednpoint) {
return this->Connect(host, port, RootSessionID(), username, password,
rdma_host, rdma_port);
rdma_host, rdma_port, src_rdma_ednpoint);
}

Status RPCClient::Connect(const std::string& host, uint32_t port,
const SessionID session_id,
std::string const& username,
std::string const& password,
const std::string& rdma_host, uint32_t rdma_port) {
const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_ednpoint) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);
std::string rpc_endpoint = host + ":" + std::to_string(port);
RETURN_ON_ASSERT(!connected_ || rpc_endpoint == rpc_endpoint_);
Expand Down Expand Up @@ -183,7 +189,8 @@ Status RPCClient::Connect(const std::string& host, uint32_t port,
instance_id_ = UnspecifiedInstanceID() - 1;

if (rdma_host.length() > 0) {
Status status = ConnectRDMA(rdma_host, rdma_port);
src_rdma_endpoint_ = src_rdma_ednpoint;
Status status = ConnectRDMA(rdma_host, rdma_port, src_rdma_ednpoint);
if (status.ok()) {
rdma_endpoint_ = rdma_host + ":" + std::to_string(rdma_port);
std::cout << "Connected to RPC server: " << rpc_endpoint
Expand All @@ -192,33 +199,38 @@ Status RPCClient::Connect(const std::string& host, uint32_t port,
} else {
std::cout << "Connect RDMA server failed! Fall back to RPC mode. Error:"
<< status.message() << std::endl;
std::cout << "Failed src_rdma_ednpoint: " << src_rdma_ednpoint
<< std::endl;
}
}

return Status::OK();
}

Status RPCClient::ConnectRDMA(const std::string& rdma_host,
uint32_t rdma_port) {
Status RPCClient::ConnectRDMA(const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_endpoint) {
if (this->rdma_connected_) {
return Status::OK();
}

RETURN_ON_ERROR(RDMAClientCreator::Create(this->rdma_client_, rdma_host,
static_cast<int>(rdma_port)));
static_cast<int>(rdma_port),
src_rdma_endpoint));

int retry = 0;
do {
if (this->rdma_client_->Connect().ok()) {
Status status = this->rdma_client_->Connect();
if (status.ok()) {
break;
}
if (retry == 10) {
return Status::Invalid("Failed to connect to RDMA server.");
}
retry++;
usleep(300 * 1000);
std::cout << "Connect rdma server failed! retry: " << retry << " times."
<< std::endl;
std::cout << "Connect rdma server failed! Error:" + status.message() +
"retry: "
<< retry << " times." << std::endl;
} while (true);
this->rdma_connected_ = true;
return Status::OK();
Expand Down Expand Up @@ -272,6 +284,9 @@ Status RPCClient::RDMAReleaseMemInfo(RegisterMemInfo& remote_info) {

Status RPCClient::StopRDMA() {
if (!rdma_connected_) {
RETURN_ON_ERROR(
RDMAClientCreator::Release(RDMAClientCreator::buildConnectionKey(
rdma_endpoint_, src_rdma_endpoint_)));
return Status::OK();
}
rdma_connected_ = false;
Expand All @@ -285,7 +300,9 @@ Status RPCClient::StopRDMA() {

RETURN_ON_ERROR(rdma_client_->Stop());
RETURN_ON_ERROR(rdma_client_->Close());
RETURN_ON_ERROR(RDMAClientCreator::Release(rdma_endpoint_));
RETURN_ON_ERROR(
RDMAClientCreator::Release(RDMAClientCreator::buildConnectionKey(
rdma_endpoint_, src_rdma_endpoint_)));

return Status::OK();
}
Expand Down
19 changes: 13 additions & 6 deletions src/client/rpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ class RPCClient final : public ClientBase {
*/
Status Connect(const std::string& rpc_endpoint, std::string const& username,
std::string const& password,
const std::string& rdma_endpoint = "");
const std::string& rdma_endpoint = "",
std::string src_rdma_ednpoint = "");

/**
* @brief Connect to vineyardd using the given TCP endpoint `rpc_endpoint`.
Expand All @@ -104,7 +105,8 @@ class RPCClient final : public ClientBase {
Status Connect(const std::string& rpc_endpoint, const SessionID session_id,
std::string const& username = "",
std::string const& password = "",
const std::string& rdma_endpoint = "");
const std::string& rdma_endpoint = "",
std::string src_rdma_ednpoint = "");

/**
* @brief Connect to vineyardd using the given TCP `host` and `port`.
Expand All @@ -117,7 +119,8 @@ class RPCClient final : public ClientBase {
* @return Status that indicates whether the connect has succeeded.
*/
Status Connect(const std::string& host, uint32_t port,
const std::string& rdma_host = "", uint32_t rdma_port = -1);
const std::string& rdma_host = "", uint32_t rdma_port = -1,
std::string src_rdma_ednpoint = "");

/**
* @brief Connect to vineyardd using the given TCP `host` and `port`.
Expand All @@ -131,7 +134,8 @@ class RPCClient final : public ClientBase {
*/
Status Connect(const std::string& host, uint32_t port,
std::string const& username, std::string const& password,
const std::string& rdma_host = "", uint32_t rdma_port = -1);
const std::string& rdma_host = "", uint32_t rdma_port = -1,
std::string src_rdma_ednpoint = "");

/**
* @brief Connect to vineyardd using the given TCP `host` and `port`.
Expand All @@ -147,7 +151,8 @@ class RPCClient final : public ClientBase {
Status Connect(const std::string& host, uint32_t port,
const SessionID session_id, std::string const& username = "",
std::string const& password = "",
const std::string& rdma_host = "", uint32_t rdma_port = -1);
const std::string& rdma_host = "", uint32_t rdma_port = -1,
std::string src_rdma_ednpoint = "");

/**
* @brief Create a new client using self endpoint.
Expand Down Expand Up @@ -436,7 +441,8 @@ class RPCClient final : public ClientBase {
const std::string rdma_endpoint() { return rdma_endpoint_; }

private:
Status ConnectRDMA(const std::string& rdma_host, uint32_t rdma_port);
Status ConnectRDMA(const std::string& rdma_host, uint32_t rdma_port,
std::string src_rdma_endpoint = "");

Status StopRDMA();

Expand Down Expand Up @@ -479,6 +485,7 @@ class RPCClient final : public ClientBase {
std::string rdma_endpoint_;
std::shared_ptr<RDMAClient> rdma_client_;
mutable bool rdma_connected_ = false;
std::string src_rdma_endpoint_ = "";

friend class Client;
};
Expand Down
23 changes: 19 additions & 4 deletions src/common/rdma/rdma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ Status IRDMA::RegisterMemory(fid_mr** mr, fid_domain* domain, void* address,
return Status::IOError("Failed to register memory region:" +
std::to_string(ret));
}
CHECK_ERROR(!ret, "Failed to register memory region:" + std::to_string(ret));
CHECK_ERROR(ret, "Failed to register memory region:" + std::to_string(ret));

mr_desc = fi_mr_desc(*mr);

Expand Down Expand Up @@ -177,10 +177,25 @@ int IRDMA::GetCompletion(fid_cq* cq, int timeout, void** context) {
return ret < 0 ? ret : 0;
}

void IRDMA::FreeInfo(fi_info* info) {
if (info) {
fi_freeinfo(info);
void IRDMA::FreeInfo(fi_info* info, bool is_hints) {
if (!info) {
return;
}

if (is_hints) {
if (info->src_addr) {
free(info->src_addr);
info->src_addr = nullptr;
info->src_addrlen = 0;
}
if (info->dest_addr) {
free(info->dest_addr);
info->dest_addr = nullptr;
info->dest_addrlen = 0;
}
}

fi_freeinfo(info);
}

} // namespace vineyard
Expand Down
2 changes: 1 addition & 1 deletion src/common/rdma/rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class IRDMA {

static int GetCompletion(fid_cq* cq, int timeout, void** context);

static void FreeInfo(fi_info* info);
static void FreeInfo(fi_info* info, bool is_hints);

template <typename FIDType>
static Status CloseResource(FIDType* res, const char* resource_name) {
Expand Down
Loading
Loading