Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work with IPv6 in the new tracker. #10125

Merged
merged 5 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions include/xgboost/collective/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -436,28 +436,38 @@ class TCPSocket {
* \brief Accept new connection, returns a new TCP socket for the new connection.
*/
TCPSocket Accept() {
HandleT newfd = accept(Handle(), nullptr, nullptr);
SockAddress addr;
TCPSocket newsock;
auto rc = this->Accept(&newsock, &addr);
SafeColl(rc);
return newsock;
}

[[nodiscard]] Result Accept(TCPSocket *out, SockAddress *addr) {
#if defined(_WIN32)
auto interrupt = WSAEINTR;
#else
auto interrupt = EINTR;
#endif
if (newfd == InvalidSocket() && system::LastError() != interrupt) {
system::ThrowAtError("accept");
}
TCPSocket newsock{newfd};
return newsock;
}

[[nodiscard]] Result Accept(TCPSocket *out, SockAddrV4 *addr) {
struct sockaddr_in caddr;
socklen_t caddr_len = sizeof(caddr);
HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
if (newfd == InvalidSocket()) {
return system::FailWithCode("Failed to accept.");
if (this->Domain() == SockDomain::kV4) {
struct sockaddr_in caddr;
socklen_t caddr_len = sizeof(caddr);
HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
if (newfd == InvalidSocket() && system::LastError() != interrupt) {
return system::FailWithCode("Failed to accept.");
}
*addr = SockAddress{SockAddrV4{caddr}};
*out = TCPSocket{newfd};
} else {
struct sockaddr_in6 caddr;
socklen_t caddr_len = sizeof(caddr);
HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
if (newfd == InvalidSocket() && system::LastError() != interrupt) {
return system::FailWithCode("Failed to accept.");
}
*addr = SockAddress{SockAddrV6{caddr}};
*out = TCPSocket{newfd};
}
*addr = SockAddrV4{caddr};
*out = TCPSocket{newfd};
return Success();
}

Expand Down
11 changes: 7 additions & 4 deletions src/collective/coll.cc
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include "coll.h"

#include <algorithm> // for min, max, copy_n
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int64_t
#include <functional> // for bit_and, bit_or, bit_xor, plus
#include <string> // for string
#include <type_traits> // for is_floating_point_v, is_same_v
#include <utility> // for move

Expand Down Expand Up @@ -56,6 +57,8 @@ bool constexpr IsFloatingPointV() {
return cpu_impl::RingAllreduce(comm, data, erased_fn, type);
};

std::string msg{"Floating point is not supported for bit wise collective operations."};

auto rc = DispatchDType(type, [&](auto t) {
using T = decltype(t);
switch (op) {
Expand All @@ -70,21 +73,21 @@ bool constexpr IsFloatingPointV() {
}
case Op::kBitwiseAND: {
if constexpr (IsFloatingPointV<T>()) {
return Fail("Invalid type.");
return Fail(msg);
} else {
return fn(std::bit_and<>{}, t);
}
}
case Op::kBitwiseOR: {
if constexpr (IsFloatingPointV<T>()) {
return Fail("Invalid type.");
return Fail(msg);
} else {
return fn(std::bit_or<>{}, t);
}
}
case Op::kBitwiseXOR: {
if constexpr (IsFloatingPointV<T>()) {
return Fail("Invalid type.");
return Fail(msg);
} else {
return fn(std::bit_xor<>{}, t);
}
Expand Down
25 changes: 17 additions & 8 deletions src/collective/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
} << [&] {
return next->NonBlocking(true);
} << [&] {
SockAddrV4 addr;
SockAddress addr;
return listener->Accept(prev.get(), &addr);
} << [&] { return prev->NonBlocking(true); };
} << [&] {
return prev->NonBlocking(true);
};
if (!rc.OK()) {
return rc;
}
Expand Down Expand Up @@ -157,10 +159,13 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
}

for (std::int32_t r = 0; r < comm.Rank(); ++r) {
SockAddrV4 addr;
auto peer = std::shared_ptr<TCPSocket>(TCPSocket::CreatePtr(comm.Domain()));
rc = std::move(rc) << [&] { return listener->Accept(peer.get(), &addr); }
<< [&] { return peer->RecvTimeout(timeout); };
rc = std::move(rc) << [&] {
SockAddress addr;
return listener->Accept(peer.get(), &addr);
} << [&] {
return peer->RecvTimeout(timeout);
};
if (!rc.OK()) {
return rc;
}
Expand All @@ -187,7 +192,9 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
nccl_path_{std::move(nccl_path)} {
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
CHECK(rc.OK()) << rc.Report();
if (!rc.OK()) {
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
}
}

#if !defined(XGBOOST_USE_NCCL)
Expand Down Expand Up @@ -247,18 +254,20 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
// get ring neighbors
std::string snext;
tracker.Recv(&snext);
if (!rc.OK()) {
return Fail("Failed to receive the rank for the next worker.", std::move(rc));
}
auto jnext = Json::Load(StringView{snext});

proto::PeerInfo ninfo{jnext};

// get the rank of this worker
this->rank_ = BootstrapPrev(ninfo.rank, world);
this->tracker_.rank = rank_;

std::vector<std::shared_ptr<TCPSocket>> workers;
rc = ConnectWorkers(*this, &listener, lport, ninfo, timeout, retry, &workers);
if (!rc.OK()) {
return rc;
return Fail("Failed to connect to other workers.", std::move(rc));
}

CHECK(this->channels_.empty());
Expand Down
41 changes: 25 additions & 16 deletions src/collective/tracker.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#if defined(__unix__) || defined(__APPLE__)
#include <netdb.h> // gethostbyname
Expand Down Expand Up @@ -27,12 +27,14 @@
#include "tracker.h"
#include "xgboost/collective/result.h" // for Result, Fail, Success
#include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ...
#include "xgboost/json.h"
#include "xgboost/json.h" // for Json

namespace xgboost::collective {
Tracker::Tracker(Json const& config)
: n_workers_{static_cast<std::int32_t>(
RequiredArg<Integer const>(config, "n_workers", __func__))},
: sortby_{static_cast<SortBy>(
OptionalArg<Integer const>(config, "sortby", static_cast<Integer::Int>(SortBy::kHost)))},
n_workers_{
static_cast<std::int32_t>(RequiredArg<Integer const>(config, "n_workers", __func__))},
port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
Expand All @@ -56,13 +58,15 @@ Result Tracker::WaitUntilReady() const {
return Success();
}

RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddress addr)
: sock_{std::move(sock)} {
std::int32_t rank{0};
Json jcmd;
std::int32_t port{0};

rc_ = Success() << [&] { return proto::Magic{}.Verify(&sock_); } << [&] {
rc_ = Success() << [&] {
return proto::Magic{}.Verify(&sock_);
} << [&] {
return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
} << [&] {
std::string cmd;
Expand All @@ -83,28 +87,33 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
}
return Success();
} << [&] {
auto host = addr.Addr();
info_ = proto::PeerInfo{host, port, rank};
if (addr.IsV4()) {
auto host = addr.V4().Addr();
info_ = proto::PeerInfo{host, port, rank};
} else {
auto host = addr.V6().Addr();
info_ = proto::PeerInfo{host, port, rank};
}
return Success();
};
}

RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
std::string self;
auto rc = collective::GetHostAddress(&self);
auto host = OptionalArg<String>(config, "host", self);
host_ = OptionalArg<String>(config, "host", self);

host_ = host;
listener_ = TCPSocket::Create(SockDomain::kV4);
rc = listener_.Bind(host, &this->port_);
CHECK(rc.OK()) << rc.Report();
auto addr = MakeSockAddress(xgboost::StringView{host_}, 0);
listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6);
rc = listener_.Bind(host_, &this->port_);
SafeColl(rc);
listener_.Listen();
}

Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
auto& workers = *p_workers;

std::sort(workers.begin(), workers.end(), WorkerCmp{});
std::sort(workers.begin(), workers.end(), WorkerCmp{this->sortby_});

std::vector<std::thread> bootstrap_threads;
for (std::int32_t r = 0; r < n_workers_; ++r) {
Expand Down Expand Up @@ -224,7 +233,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {

while (state.ShouldContinue()) {
TCPSocket sock;
SockAddrV4 addr;
SockAddress addr;
this->ready_ = true;
auto rc = listener_.Accept(&sock, &addr);
if (!rc.OK()) {
Expand Down Expand Up @@ -291,7 +300,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {

[[nodiscard]] Json RabitTracker::WorkerArgs() const {
auto rc = this->WaitUntilReady();
CHECK(rc.OK()) << rc.Report();
SafeColl(rc);

Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host_};
Expand Down
23 changes: 18 additions & 5 deletions src/collective/tracker.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <chrono> // for seconds
Expand Down Expand Up @@ -36,6 +36,16 @@ namespace xgboost::collective {
* signal an error to the tracker and the tracker will notify other workers.
*/
class Tracker {
protected:
// How to sort the workers, either by host name or by task ID. When using a multi-GPU
// setting, multiple workers can occupy the same host, in which case one should sort
// workers by task. Due to compatibility reason, the task ID is not always available, so
// we use host as the default.
enum class SortBy : std::int8_t {
kHost = 0,
kTask = 1,
} sortby_;

protected:
std::int32_t n_workers_{0};
std::int32_t port_{-1};
Expand Down Expand Up @@ -76,7 +86,7 @@ class RabitTracker : public Tracker {
Result rc_;

public:
explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr);
explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddress addr);
WorkerProxy(WorkerProxy const& that) = delete;
WorkerProxy(WorkerProxy&& that) = default;
WorkerProxy& operator=(WorkerProxy const&) = delete;
Expand All @@ -96,11 +106,14 @@ class RabitTracker : public Tracker {

void Send(StringView value) { this->sock_.Send(value); }
};
// provide an ordering for workers, this helps us get deterministic topology.
// Provide an ordering for workers, this helps us get deterministic topology.
struct WorkerCmp {
SortBy sortby;
explicit WorkerCmp(SortBy sortby) : sortby{sortby} {}

[[nodiscard]] bool operator()(WorkerProxy const& lhs, WorkerProxy const& rhs) {
auto const& lh = lhs.Host();
auto const& rh = rhs.Host();
auto const& lh = sortby == Tracker::SortBy::kHost ? lhs.Host() : lhs.TaskID();
auto const& rh = sortby == Tracker::SortBy::kHost ? rhs.Host() : rhs.TaskID();

if (lh != rh) {
return lh < rh;
Expand Down
Loading