Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Allgather to collective communicator #8765

Merged
merged 2 commits into from
Feb 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions plugin/federated/federated.proto
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ syntax = "proto3";
package xgboost.federated;

service Federated {
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
}
Expand All @@ -30,6 +31,17 @@ enum ReduceOperation {
BITWISE_XOR = 5;
}

message AllgatherRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
}

message AllgatherReply {
bytes receive_buffer = 1;
}

message AllreduceRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
Expand Down
19 changes: 19 additions & 0 deletions plugin/federated/federated_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,25 @@ class FederatedClient {
}()},
rank_{rank} {}

std::string Allgather(std::string const &send_buffer) {
AllgatherRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(rank_);
request.set_send_buffer(send_buffer);

AllgatherReply reply;
grpc::ClientContext context;
context.set_wait_for_ready(true);
grpc::Status status = stub_->Allgather(&context, request, &reply);

if (status.ok()) {
return reply.receive_buffer();
} else {
std::cout << status.error_code() << ": " << status.error_message() << '\n';
throw std::runtime_error("Allgather RPC failed");
}
}

std::string Allreduce(std::string const &send_buffer, DataType data_type,
ReduceOperation reduce_operation) {
AllreduceRequest request;
Expand Down
11 changes: 11 additions & 0 deletions plugin/federated/federated_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,17 @@ class FederatedCommunicator : public Communicator {
*/
bool IsFederated() const override { return true; }

/**
* \brief Perform in-place allgather.
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param size Number of bytes to be gathered.
*/
void AllGather(void *send_receive_buffer, std::size_t size) override {
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
auto const received = client_->Allgather(send_buffer);
received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
}

/**
* \brief Perform in-place allreduce.
* \param send_receive_buffer Buffer for both sending and receiving data.
Expand Down
7 changes: 7 additions & 0 deletions plugin/federated/federated_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
namespace xgboost {
namespace federated {

grpc::Status FederatedService::Allgather(grpc::ServerContext* context,
AllgatherRequest const* request, AllgatherReply* reply) {
handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
return grpc::Status::OK;
}

grpc::Status FederatedService::Allreduce(grpc::ServerContext* context,
AllreduceRequest const* request, AllreduceReply* reply) {
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
Expand Down
3 changes: 3 additions & 0 deletions plugin/federated/federated_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class FederatedService final : public Federated::Service {
public:
explicit FederatedService(int const world_size) : handler_{world_size} {}

grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;

grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
AllreduceReply* reply) override;

Expand Down
11 changes: 11 additions & 0 deletions src/collective/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ class Communicator {
/** @brief Whether the communicator is running in federated mode. */
virtual bool IsFederated() const = 0;

/**
* @brief Gathers data from all processes and distributes it to all processes.
*
* This assumes all ranks have the same size, and input data has been sliced into the
* corresponding position.
*
* @param send_receive_buffer Buffer storing the data.
* @param size Size of the data in bytes.
*/
virtual void AllGather(void *send_receive_buffer, std::size_t size) = 0;

/**
* @brief Combines values from all processes and distributes the result back to all processes.
*
Expand Down
7 changes: 7 additions & 0 deletions src/collective/in_memory_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ class InMemoryCommunicator : public Communicator {
bool IsDistributed() const override { return true; }
bool IsFederated() const override { return false; }

void AllGather(void* in_out, std::size_t size) override {
std::string output;
handler_.Allgather(static_cast<const char*>(in_out), size, &output, sequence_number_++,
GetRank());
output.copy(static_cast<char*>(in_out), size);
}

void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override {
auto const bytes = size * GetTypeSize(data_type);
std::string output;
Expand Down
35 changes: 33 additions & 2 deletions src/collective/in_memory_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,32 @@
namespace xgboost {
namespace collective {

/**
* @brief Functor for allgather.
*/
class AllgatherFunctor {
public:
std::string const name{"Allgather"};

AllgatherFunctor(int world_size, int rank) : world_size_{world_size}, rank_{rank} {}

void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
if (buffer->empty()) {
// Copy the input if this is the first request.
buffer->assign(input, bytes);
} else {
// Splice the input into the common buffer.
auto const per_rank = bytes / world_size_;
auto const index = rank_ * per_rank;
buffer->replace(index, per_rank, input + index, per_rank);
}
}

private:
int world_size_;
int rank_;
};

/**
* @brief Functor for allreduce.
*/
Expand All @@ -17,7 +43,7 @@ class AllreduceFunctor {
std::string const name{"Allreduce"};

AllreduceFunctor(DataType dataType, Operation operation)
: data_type_(dataType), operation_(operation) {}
: data_type_{dataType}, operation_{operation} {}

void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
if (buffer->empty()) {
Expand Down Expand Up @@ -128,7 +154,7 @@ class BroadcastFunctor {
public:
std::string const name{"Broadcast"};

BroadcastFunctor(int rank, int root) : rank_(rank), root_(root) {}
BroadcastFunctor(int rank, int root) : rank_{rank}, root_{root} {}

void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
if (rank_ == root_) {
Expand Down Expand Up @@ -167,6 +193,11 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, int) {
cv_.notify_all();
}

void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, int rank) {
Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
}

void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, int rank, DataType data_type,
Operation op) {
Expand Down
11 changes: 11 additions & 0 deletions src/collective/in_memory_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ class InMemoryHandler {
*/
void Shutdown(uint64_t sequence_number, int rank);

/**
* @brief Perform allgather.
* @param input The input buffer.
* @param bytes Number of bytes in the input buffer.
* @param output The output buffer.
* @param sequence_number Call sequence number.
* @param rank Index of the worker.
*/
void Allgather(char const* input, std::size_t bytes, std::string* output,
std::size_t sequence_number, int rank);

/**
* @brief Perform allreduce.
* @param input The input buffer.
Expand Down
1 change: 1 addition & 0 deletions src/collective/noop_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class NoOpCommunicator : public Communicator {
NoOpCommunicator() : Communicator(1, 0) {}
bool IsDistributed() const override { return false; }
bool IsFederated() const override { return false; }
void AllGather(void *, std::size_t) override {}
void AllReduce(void *, std::size_t, DataType, Operation) override {}
void Broadcast(void *, std::size_t, int) override {}
std::string GetProcessorName() override { return ""; }
Expand Down
6 changes: 6 additions & 0 deletions src/collective/rabit_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ class RabitCommunicator : public Communicator {

bool IsFederated() const override { return false; }

void AllGather(void *send_receive_buffer, std::size_t size) override {
auto const per_rank = size / GetWorldSize();
auto const index = per_rank * GetRank();
rabit::Allgather(static_cast<char *>(send_receive_buffer), size, index, per_rank, per_rank);
}

void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
switch (data_type) {
Expand Down
12 changes: 12 additions & 0 deletions tests/cpp/collective/test_in_memory_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ class InMemoryCommunicatorTest : public ::testing::Test {
}
}

static void Allgather(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
char buffer[kWorldSize] = {'a', 'b', 'c'};
buffer[rank] = '0' + rank;
comm.AllGather(buffer, kWorldSize);
for (auto i = 0; i < kWorldSize; i++) {
EXPECT_EQ(buffer[i], '0' + i);
}
}

static void AllreduceMax(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank};
Expand Down Expand Up @@ -147,6 +157,8 @@ TEST(InMemoryCommunicatorSimpleTest, IsDistributed) {
EXPECT_TRUE(comm.IsDistributed());
}

TEST_F(InMemoryCommunicatorTest, Allgather) { Verify(&Allgather); }

TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); }

TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); }
Expand Down
25 changes: 25 additions & 0 deletions tests/cpp/plugin/test_federated_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ namespace collective {

class FederatedCommunicatorTest : public ::testing::Test {
public:
static void VerifyAllgather(int rank, const std::string& server_address) {
FederatedCommunicator comm{kWorldSize, rank, server_address};
CheckAllgather(comm, rank);
}

static void VerifyAllreduce(int rank, const std::string& server_address) {
FederatedCommunicator comm{kWorldSize, rank, server_address};
CheckAllreduce(comm);
Expand Down Expand Up @@ -56,6 +61,15 @@ class FederatedCommunicatorTest : public ::testing::Test {
server_thread_->join();
}

static void CheckAllgather(FederatedCommunicator &comm, int rank) {
int buffer[kWorldSize] = {0, 0, 0};
buffer[rank] = rank;
comm.AllGather(buffer, sizeof(buffer));
for (auto i = 0; i < kWorldSize; i++) {
EXPECT_EQ(buffer[i], i);
}
}

static void CheckAllreduce(FederatedCommunicator &comm) {
int buffer[] = {1, 2, 3, 4, 5};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
Expand Down Expand Up @@ -144,6 +158,17 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
EXPECT_TRUE(comm.IsDistributed());
}

TEST_F(FederatedCommunicatorTest, Allgather) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(
std::thread(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_));
}
for (auto &thread : threads) {
thread.join();
}
}

TEST_F(FederatedCommunicatorTest, Allreduce) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
Expand Down
31 changes: 29 additions & 2 deletions tests/cpp/plugin/test_federated_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
#include <grpcpp/server_builder.h>
#include <gtest/gtest.h>

#include <ctime>
#include <iostream>
#include <thread>
#include <ctime>

#include "helpers.h"
#include "federated_client.h"
#include "federated_server.h"
#include "helpers.h"

namespace {

Expand All @@ -26,6 +26,11 @@ namespace xgboost {

class FederatedServerTest : public ::testing::Test {
public:
static void VerifyAllgather(int rank, const std::string& server_address) {
federated::FederatedClient client{server_address, rank};
CheckAllgather(client, rank);
}

static void VerifyAllreduce(int rank, const std::string& server_address) {
federated::FederatedClient client{server_address, rank};
CheckAllreduce(client);
Expand All @@ -39,6 +44,7 @@ class FederatedServerTest : public ::testing::Test {
static void VerifyMixture(int rank, const std::string& server_address) {
federated::FederatedClient client{server_address, rank};
for (auto i = 0; i < 10; i++) {
CheckAllgather(client, rank);
CheckAllreduce(client);
CheckBroadcast(client, rank);
}
Expand All @@ -62,6 +68,17 @@ class FederatedServerTest : public ::testing::Test {
server_thread_->join();
}

static void CheckAllgather(federated::FederatedClient& client, int rank) {
int data[kWorldSize] = {0, 0, 0};
data[rank] = rank;
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
auto reply = client.Allgather(send_buffer);
auto const* result = reinterpret_cast<int const*>(reply.data());
for (auto i = 0; i < kWorldSize; i++) {
EXPECT_EQ(result[i], i);
}
}

static void CheckAllreduce(federated::FederatedClient& client) {
int data[] = {1, 2, 3, 4, 5};
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
Expand All @@ -88,6 +105,16 @@ class FederatedServerTest : public ::testing::Test {
std::unique_ptr<grpc::Server> server_;
};

TEST_F(FederatedServerTest, Allgather) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank, server_address_));
}
for (auto& thread : threads) {
thread.join();
}
}

TEST_F(FederatedServerTest, Allreduce) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
Expand Down