From 036677a588c0ba209e9623a2a2fabe1afe40a53d Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 8 Feb 2023 15:17:16 -0800 Subject: [PATCH 1/2] Add Allgather to collective communicator --- plugin/federated/federated.proto | 12 +++++++ plugin/federated/federated_client.h | 19 ++++++++++ plugin/federated/federated_communicator.h | 11 ++++++ plugin/federated/federated_server.cc | 7 ++++ plugin/federated/federated_server.h | 3 ++ src/collective/communicator.h | 11 ++++++ src/collective/in_memory_communicator.h | 7 ++++ src/collective/in_memory_handler.cc | 35 +++++++++++++++++-- src/collective/in_memory_handler.h | 11 ++++++ src/collective/noop_communicator.h | 1 + src/collective/rabit_communicator.h | 6 ++++ .../collective/test_in_memory_communicator.cc | 12 +++++++ .../cpp/plugin/test_federated_communicator.cc | 25 +++++++++++++ tests/cpp/plugin/test_federated_server.cc | 31 ++++++++++++++-- 14 files changed, 187 insertions(+), 4 deletions(-) diff --git a/plugin/federated/federated.proto b/plugin/federated/federated.proto index 136687109716..d8ef5bd92f43 100644 --- a/plugin/federated/federated.proto +++ b/plugin/federated/federated.proto @@ -6,6 +6,7 @@ syntax = "proto3"; package xgboost.federated; service Federated { + rpc Allgather(AllgatherRequest) returns (AllgatherReply) {} rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {} rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {} } @@ -30,6 +31,17 @@ enum ReduceOperation { BITWISE_XOR = 5; } +message AllgatherRequest { + // An incrementing counter that is unique to each round to operations. + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; +} + +message AllgatherReply { + bytes receive_buffer = 1; +} + message AllreduceRequest { // An incrementing counter that is unique to each round to operations. uint64 sequence_number = 1; diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h index 3d0cdb729bf5..2b4637339199 100644 --- a/plugin/federated/federated_client.h +++ b/plugin/federated/federated_client.h @@ -46,6 +46,25 @@ class FederatedClient { }()}, rank_{rank} {} + std::string Allgather(std::string const &send_buffer) { + AllgatherRequest request; + request.set_sequence_number(sequence_number_++); + request.set_rank(rank_); + request.set_send_buffer(send_buffer); + + AllgatherReply reply; + grpc::ClientContext context; + context.set_wait_for_ready(true); + grpc::Status status = stub_->Allgather(&context, request, &reply); + + if (status.ok()) { + return reply.receive_buffer(); + } else { + std::cout << status.error_code() << ": " << status.error_message() << '\n'; + throw std::runtime_error("Allgather RPC failed"); + } + } + std::string Allreduce(std::string const &send_buffer, DataType data_type, ReduceOperation reduce_operation) { AllreduceRequest request; diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index 856ed6aac75f..7acd8a82932d 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -126,6 +126,17 @@ class FederatedCommunicator : public Communicator { */ bool IsFederated() const override { return true; } + /** + * \brief Perform in-place allgather. + * \param send_receive_buffer Buffer for both sending and receiving data. + * \param size Number of bytes to be gathered. + */ + void AllGather(void *send_receive_buffer, std::size_t size) override { + std::string const send_buffer(reinterpret_cast(send_receive_buffer), size); + auto const received = client_->Allgather(send_buffer); + received.copy(reinterpret_cast(send_receive_buffer), size); + } + /** * \brief Perform in-place allreduce. * \param send_receive_buffer Buffer for both sending and receiving data. diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index ec0b451a94e1..b16d347805c7 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -14,6 +14,13 @@ namespace xgboost { namespace federated { +grpc::Status FederatedService::Allgather(grpc::ServerContext* context, + AllgatherRequest const* request, AllgatherReply* reply) { + handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(), + reply->mutable_receive_buffer(), request->sequence_number(), request->rank()); + return grpc::Status::OK; +} + grpc::Status FederatedService::Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, AllreduceReply* reply) { handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(), diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h index 3a5abc4c9ea8..7738248ea729 100644 --- a/plugin/federated/federated_server.h +++ b/plugin/federated/federated_server.h @@ -14,6 +14,9 @@ class FederatedService final : public Federated::Service { public: explicit FederatedService(int const world_size) : handler_{world_size} {} + grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, + AllgatherReply* reply) override; + grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, AllreduceReply* reply) override; diff --git a/src/collective/communicator.h b/src/collective/communicator.h index 87c1c875e995..885a8d438d6e 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -122,6 +122,17 @@ class Communicator { /** @brief Whether the communicator is running in federated mode. */ virtual bool IsFederated() const = 0; + /** + * @brief Gathers data from all processes and distributes it to all processes. + * + * This assumes all ranks have the same size, and input data has been sliced into the + * corresponding position. + * + * @param send_receive_buffer Buffer storing the data. + * @param size Size of the data in bytes. + */ + virtual void AllGather(void *send_receive_buffer, std::size_t size) = 0; + /** * @brief Combines values from all processes and distributes the result back to all processes. * diff --git a/src/collective/in_memory_communicator.h b/src/collective/in_memory_communicator.h index c1c5d4493059..f41029af1dea 100644 --- a/src/collective/in_memory_communicator.h +++ b/src/collective/in_memory_communicator.h @@ -60,6 +60,13 @@ class InMemoryCommunicator : public Communicator { bool IsDistributed() const override { return true; } bool IsFederated() const override { return false; } + void AllGather(void* in_out, std::size_t size) override { + std::string output; + handler_.Allgather(static_cast(in_out), size, &output, sequence_number_++, + GetRank()); + output.copy(static_cast(in_out), size); + } + void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override { auto const bytes = size * GetTypeSize(data_type); std::string output; diff --git a/src/collective/in_memory_handler.cc b/src/collective/in_memory_handler.cc index 09518fd964fa..d8a86ec556f4 100644 --- a/src/collective/in_memory_handler.cc +++ b/src/collective/in_memory_handler.cc @@ -9,6 +9,32 @@ namespace xgboost { namespace collective { +/** + * @brief Functor for allgather. + */ +class AllgatherFunctor { + public: + std::string const name{"Allgather"}; + + AllgatherFunctor(int world_size, int rank) : world_size_{world_size}, rank_{rank} {} + + void operator()(char const* input, std::size_t bytes, std::string* buffer) const { + if (buffer->empty()) { + // Copy the input if this is the first request. + buffer->assign(input, bytes); + } else { + // Splice the input into the common buffer. + auto const per_rank = bytes / world_size_; + auto const index = rank_ * per_rank; + buffer->replace(index, per_rank, input + index, per_rank); + } + } + + private: + int world_size_; + int rank_; +}; + /** * @brief Functor for allreduce. */ @@ -17,7 +43,7 @@ class AllreduceFunctor { std::string const name{"Allreduce"}; AllreduceFunctor(DataType dataType, Operation operation) - : data_type_(dataType), operation_(operation) {} + : data_type_{dataType}, operation_{operation} {} void operator()(char const* input, std::size_t bytes, std::string* buffer) const { if (buffer->empty()) { @@ -128,7 +154,7 @@ class BroadcastFunctor { public: std::string const name{"Broadcast"}; - BroadcastFunctor(int rank, int root) : rank_(rank), root_(root) {} + BroadcastFunctor(int rank, int root) : rank_{rank}, root_{root} {} void operator()(char const* input, std::size_t bytes, std::string* buffer) const { if (rank_ == root_) { @@ -167,6 +193,11 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, int) { cv_.notify_all(); } +void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output, + std::size_t sequence_number, int rank) { + Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank}); +} + void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output, std::size_t sequence_number, int rank, DataType data_type, Operation op) { diff --git a/src/collective/in_memory_handler.h b/src/collective/in_memory_handler.h index 3ab2d9a0bb50..4182c7b3ddb2 100644 --- a/src/collective/in_memory_handler.h +++ b/src/collective/in_memory_handler.h @@ -53,6 +53,17 @@ class InMemoryHandler { */ void Shutdown(uint64_t sequence_number, int rank); + /** + * @brief Perform allgather. + * @param input The input buffer. + * @param bytes Number of bytes in the input buffer. + * @param output The output buffer. + * @param sequence_number Call sequence number. + * @param rank Index of the worker. + */ + void Allgather(char const* input, std::size_t bytes, std::string* output, + std::size_t sequence_number, int rank); + /** * @brief Perform allreduce. * @param input The input buffer. diff --git a/src/collective/noop_communicator.h b/src/collective/noop_communicator.h index cad6da029530..8a3eae9a5b83 100644 --- a/src/collective/noop_communicator.h +++ b/src/collective/noop_communicator.h @@ -17,6 +17,7 @@ class NoOpCommunicator : public Communicator { NoOpCommunicator() : Communicator(1, 0) {} bool IsDistributed() const override { return false; } bool IsFederated() const override { return false; } + void AllGather(void *, std::size_t) {} void AllReduce(void *, std::size_t, DataType, Operation) override {} void Broadcast(void *, std::size_t, int) override {} std::string GetProcessorName() override { return ""; } diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index 3c16fde77e43..9b79624a2718 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -55,6 +55,12 @@ class RabitCommunicator : public Communicator { bool IsFederated() const override { return false; } + void AllGather(void *send_receive_buffer, std::size_t size) override { + auto const per_rank = size / GetWorldSize(); + auto const index = per_rank * GetRank(); + rabit::Allgather(static_cast(send_receive_buffer), size, index, per_rank, per_rank); + } + void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, Operation op) override { switch (data_type) { diff --git a/tests/cpp/collective/test_in_memory_communicator.cc b/tests/cpp/collective/test_in_memory_communicator.cc index 1e4f6521ffc8..071005717ef7 100644 --- a/tests/cpp/collective/test_in_memory_communicator.cc +++ b/tests/cpp/collective/test_in_memory_communicator.cc @@ -24,6 +24,16 @@ class InMemoryCommunicatorTest : public ::testing::Test { } } + static void Allgather(int rank) { + InMemoryCommunicator comm{kWorldSize, rank}; + char buffer[kWorldSize] = {'a', 'b', 'c'}; + buffer[rank] = '0' + rank; + comm.AllGather(buffer, kWorldSize); + for (auto i = 0; i < kWorldSize; i++) { + EXPECT_EQ(buffer[i], '0' + i); + } + } + static void AllreduceMax(int rank) { InMemoryCommunicator comm{kWorldSize, rank}; int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank}; @@ -147,6 +157,8 @@ TEST(InMemoryCommunicatorSimpleTest, IsDistributed) { EXPECT_TRUE(comm.IsDistributed()); } +TEST_F(InMemoryCommunicatorTest, Allgather) { Verify(&Allgather); } + TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); } TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); } diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 51d258f02b57..f5d72e5f4972 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -28,6 +28,11 @@ namespace collective { class FederatedCommunicatorTest : public ::testing::Test { public: + static void VerifyAllgather(int rank, const std::string& server_address) { + FederatedCommunicator comm{kWorldSize, rank, server_address}; + CheckAllgather(comm, rank); + } + static void VerifyAllreduce(int rank, const std::string& server_address) { FederatedCommunicator comm{kWorldSize, rank, server_address}; CheckAllreduce(comm); @@ -56,6 +61,15 @@ class FederatedCommunicatorTest : public ::testing::Test { server_thread_->join(); } + static void CheckAllgather(FederatedCommunicator &comm, int rank) { + int buffer[kWorldSize] = {0, 0, 0}; + buffer[rank] = rank; + comm.AllGather(buffer, sizeof(buffer)); + for (auto i = 0; i < kWorldSize; i++) { + EXPECT_EQ(buffer[i], i); + } + } + static void CheckAllreduce(FederatedCommunicator &comm) { int buffer[] = {1, 2, 3, 4, 5}; comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum); @@ -144,6 +158,17 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) { EXPECT_TRUE(comm.IsDistributed()); } +TEST_F(FederatedCommunicatorTest, Allgather) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back( + std::thread(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_)); + } + for (auto &thread : threads) { + thread.join(); + } +} + TEST_F(FederatedCommunicatorTest, Allreduce) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc index 61828975b71b..fa9c272d2903 100644 --- a/tests/cpp/plugin/test_federated_server.cc +++ b/tests/cpp/plugin/test_federated_server.cc @@ -4,13 +4,13 @@ #include #include +#include #include #include -#include -#include "helpers.h" #include "federated_client.h" #include "federated_server.h" +#include "helpers.h" namespace { @@ -26,6 +26,11 @@ namespace xgboost { class FederatedServerTest : public ::testing::Test { public: + static void VerifyAllgather(int rank, const std::string& server_address) { + federated::FederatedClient client{server_address, rank}; + CheckAllgather(client, rank); + } + static void VerifyAllreduce(int rank, const std::string& server_address) { federated::FederatedClient client{server_address, rank}; CheckAllreduce(client); @@ -39,6 +44,7 @@ class FederatedServerTest : public ::testing::Test { static void VerifyMixture(int rank, const std::string& server_address) { federated::FederatedClient client{server_address, rank}; for (auto i = 0; i < 10; i++) { + CheckAllgather(client, rank); CheckAllreduce(client); CheckBroadcast(client, rank); } @@ -62,6 +68,17 @@ class FederatedServerTest : public ::testing::Test { server_thread_->join(); } + static void CheckAllgather(federated::FederatedClient& client, int rank) { + int data[kWorldSize] = {0, 0, 0}; + data[rank] = rank; + std::string send_buffer(reinterpret_cast(data), sizeof(data)); + auto reply = client.Allgather(send_buffer); + auto const* result = reinterpret_cast(reply.data()); + for (auto i = 0; i < kWorldSize; i++) { + EXPECT_EQ(result[i], i); + } + } + static void CheckAllreduce(federated::FederatedClient& client) { int data[] = {1, 2, 3, 4, 5}; std::string send_buffer(reinterpret_cast(data), sizeof(data)); @@ -88,6 +105,16 @@ class FederatedServerTest : public ::testing::Test { std::unique_ptr server_; }; +TEST_F(FederatedServerTest, Allgather) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank, server_address_)); + } + for (auto& thread : threads) { + thread.join(); + } +} + TEST_F(FederatedServerTest, Allreduce) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { From ce50d94ac28808cf2f772a964b752dcd0fc1a3d2 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 8 Feb 2023 15:54:05 -0800 Subject: [PATCH 2/2] fix clang tidy error --- src/collective/noop_communicator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/collective/noop_communicator.h b/src/collective/noop_communicator.h index 8a3eae9a5b83..28a0a1cada4d 100644 --- a/src/collective/noop_communicator.h +++ b/src/collective/noop_communicator.h @@ -17,7 +17,7 @@ class NoOpCommunicator : public Communicator { NoOpCommunicator() : Communicator(1, 0) {} bool IsDistributed() const override { return false; } bool IsFederated() const override { return false; } - void AllGather(void *, std::size_t) {} + void AllGather(void *, std::size_t) override {} void AllReduce(void *, std::size_t, DataType, Operation) override {} void Broadcast(void *, std::size_t, int) override {} std::string GetProcessorName() override { return ""; }