Skip to content

Commit

Permalink
[Disco] Cross-group and p2p send/receive primitives (#17191)
Browse files Browse the repository at this point in the history
This PR introduces the disco CCL primitives for cross-group
and p2p communication.

Specifically, we introduce the send/receive primitives for one group
to send a buffer to its next group, where every worker in the first
group sends the buffer to the corresponding worker in the second
group. The p2p communication refers to the send/receive operations
to/from a target global worker.
  • Loading branch information
MasterJH5574 authored Jul 24, 2024
1 parent 9a07870 commit ae1be53
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 4 deletions.
24 changes: 24 additions & 0 deletions include/tvm/runtime/disco/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,30 @@ TVM_DLL void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv
* \param buffer The buffer to be received
*/
TVM_DLL void RecvFromWorker0(NDArray buffer);
/*!
 * \brief Send a buffer to the corresponding worker in the next group.
 * Worker `i` sends to worker `i + group_size` (the same in-group position,
 * one group ahead). An error is thrown if the worker is already in the last
 * group. Must be paired with RecvFromPrevGroup on the receiving worker.
 * \param buffer The sending buffer.
 */
TVM_DLL void SendToNextGroup(NDArray buffer);
/*!
 * \brief Receive a buffer from the corresponding worker in the previous group.
 * Worker `i` receives from worker `i - group_size`. An error is thrown if the
 * worker is already in the first group. Must be paired with SendToNextGroup
 * on the sending worker.
 * \param buffer The receiving buffer.
 */
TVM_DLL void RecvFromPrevGroup(NDArray buffer);
/*!
 * \brief Send a buffer to the target receiver worker (globally across all groups).
 * The receiver id must be a valid global worker id different from the caller's.
 * \param buffer The sending buffer.
 * \param receiver_id The global receiver worker id.
 */
TVM_DLL void SendToWorker(NDArray buffer, int receiver_id);
/*!
 * \brief Receive a buffer from the target sender worker (globally across all groups).
 * The sender id must be a valid global worker id different from the caller's.
 * \param buffer The receiving buffer.
 * \param sender_id The global sender worker id.
 */
TVM_DLL void RecvFromWorker(NDArray buffer, int sender_id);
/*! \brief Get the local worker id */
TVM_DLL int WorkerId();
/*!
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relax/frontend/nn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,16 +549,16 @@ def __init__(self, modules: List[Module]):
def __iter__(self):
    """Iterate over the contained submodules in order."""
    return iter(self.modules)

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Module:
    """Return the submodule stored at position ``idx``."""
    return self.modules[idx]

def __setitem__(self, idx, module):
def __setitem__(self, idx: int, module: Module) -> None:
    """Replace the submodule at position ``idx`` with ``module``."""
    self.modules[idx] = module

def __len__(self):
    """Return the number of submodules in this ModuleList."""
    return len(self.modules)

def append(self, module):
def append(self, module: Module):
    """Add a module to the end of the ModuleList.

    Parameters
    ----------
    module : Module
        The module to append.
    """
    self.modules.append(module)

Expand Down
16 changes: 16 additions & 0 deletions src/runtime/disco/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,18 @@ void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv) {

void RecvFromWorker0(NDArray buffer) { GetCCLFunc("recv_from_worker0")(buffer); }

void SendToNextGroup(NDArray buffer) {
  // Dispatch to the CCL backend (e.g. NCCL) registered for this worker.
  GetCCLFunc("send_to_next_group")(buffer);
}

void RecvFromPrevGroup(NDArray buffer) {
  // Dispatch to the CCL backend (e.g. NCCL) registered for this worker.
  GetCCLFunc("recv_from_prev_group")(buffer);
}

// Dispatch a p2p send to the CCL backend registered for this worker.
void SendToWorker(NDArray buffer, int receiver_id) { GetCCLFunc("send_to_worker")(buffer, receiver_id); }

// Dispatch a p2p receive to the CCL backend registered for this worker.
void RecvFromWorker(NDArray buffer, int sender_id) { GetCCLFunc("recv_from_worker")(buffer, sender_id); }

int WorkerId() {
  // Each disco worker thread records its own id in thread-local state.
  return DiscoWorker::ThreadLocal()->worker_id;
}

void SyncWorker() {
Expand Down Expand Up @@ -136,6 +148,10 @@ TVM_REGISTER_GLOBAL("runtime.disco.broadcast_from_worker0").set_body_typed(Broad
// FFI registrations exposing the disco builtins (collectives and p2p
// send/recv) under the generic "runtime.disco.*" namespace.
TVM_REGISTER_GLOBAL("runtime.disco.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.gather_to_worker0").set_body_typed(GatherToWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker0").set_body_typed(RecvFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.send_to_next_group").set_body_typed(SendToNextGroup);
TVM_REGISTER_GLOBAL("runtime.disco.recv_from_prev_group").set_body_typed(RecvFromPrevGroup);
TVM_REGISTER_GLOBAL("runtime.disco.send_to_worker").set_body_typed(SendToWorker);
TVM_REGISTER_GLOBAL("runtime.disco.recv_from_worker").set_body_typed(RecvFromWorker);
// worker_id is wrapped in a ShapeTuple so it can cross the disco RPC boundary.
TVM_REGISTER_GLOBAL("runtime.disco.worker_id").set_body_typed([]() -> ShapeTuple {
  return ShapeTuple({WorkerId()});
});
Expand Down
86 changes: 86 additions & 0 deletions src/runtime/disco/nccl/nccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,57 @@ void RecvFromWorker0(NDArray buffer) {
NCCL_CALL(ncclGroupEnd());
}

void SendToNextGroup(NDArray buffer) {
  // Send the whole buffer to the worker holding the same in-group position in
  // the next group, i.e. global rank `worker_id + group_size`.
  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
  deviceStream_t stream = ctx->GetDefaultStream();
  int self_id = ctx->worker->worker_id;
  int workers_per_group = ctx->worker->num_workers / ctx->worker->num_groups;
  int dst = self_id + workers_per_group;
  CHECK_LT(dst, ctx->worker->num_workers)
      << "The current group is already the last group and there is no such a next group.";
  NCCL_CALL(ncclGroupStart());
  NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
                     dst, ctx->global_comm, stream));
  NCCL_CALL(ncclGroupEnd());
}

void RecvFromPrevGroup(NDArray buffer) {
  // Receive the buffer from the worker holding the same in-group position in
  // the previous group, i.e. global rank `worker_id - group_size`.
  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
  deviceStream_t stream = ctx->GetDefaultStream();
  int self_id = ctx->worker->worker_id;
  int workers_per_group = ctx->worker->num_workers / ctx->worker->num_groups;
  int src = self_id - workers_per_group;
  CHECK_GE(src, 0)
      << "The current group is already the first group and there is no such a previous group.";
  NCCL_CALL(ncclGroupStart());
  NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
                     src, ctx->global_comm, stream));
  NCCL_CALL(ncclGroupEnd());
}

void SendToWorker(NDArray buffer, int receiver_id) {
  // p2p send of the whole buffer to an arbitrary global worker id, which must
  // be in range and different from the caller.
  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
  deviceStream_t stream = ctx->GetDefaultStream();
  int self_id = ctx->worker->worker_id;
  CHECK(receiver_id >= 0 && receiver_id < ctx->worker->num_workers)
      << "Invalid receiver id " << receiver_id << ". The world size is "
      << ctx->worker->num_workers;
  CHECK_NE(self_id, receiver_id) << "Cannot send to worker itself.";
  NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
                     receiver_id, ctx->global_comm, stream));
}

void RecvFromWorker(NDArray buffer, int sender_id) {
  // p2p receive of the whole buffer from an arbitrary global worker id, which
  // must be in range and different from the caller.
  CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
  deviceStream_t stream = ctx->GetDefaultStream();
  int self_id = ctx->worker->worker_id;
  CHECK(sender_id >= 0 && sender_id < ctx->worker->num_workers)
      << "Invalid sender id " << sender_id << ". The world size is " << ctx->worker->num_workers;
  CHECK_NE(self_id, sender_id) << "Cannot receive from the worker itself.";
  NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
                     sender_id, ctx->global_comm, stream));
}

void SyncWorker() {
CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
ICHECK(ctx->worker != nullptr);
Expand Down Expand Up @@ -284,8 +335,43 @@ TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".gather_to_worker0")
.set_body_typed(GatherToWorker0);
// FFI registrations binding the NCCL/RCCL implementations under the
// backend-specific "runtime.disco.<ccl>.*" namespace (TVM_DISCO_CCL_NAME is
// "nccl" or "rccl" depending on the build).
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker0")
    .set_body_typed(RecvFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_next_group")
    .set_body_typed(SendToNextGroup);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_prev_group")
    .set_body_typed(RecvFromPrevGroup);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".send_to_worker")
    .set_body_typed(SendToWorker);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".recv_from_worker")
    .set_body_typed(RecvFromWorker);
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".sync_worker").set_body_typed(SyncWorker);

// Test helper: every worker in group 0 sends `buffer` to its counterpart in
// group 1, which receives into `buffer`. Requires exactly 4 workers split
// into 2 groups.
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME
                    ".test_send_to_next_group_recv_from_prev_group")
    .set_body_typed([](NDArray buffer) {
      CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
      CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4.";
      // Fix: the check is on the number of groups, so the error message must
      // say "number of groups", not "group size" (the old message was misleading).
      CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the number of groups to be 2.";
      int group_size = ctx->worker->num_workers / ctx->worker->num_groups;
      int group_id = ctx->worker->worker_id / group_size;
      if (group_id == 0) {
        tvm::runtime::nccl::SendToNextGroup(buffer);
      } else {
        tvm::runtime::nccl::RecvFromPrevGroup(buffer);
      }
    });

// Test helper for cross-group p2p: worker 2 sends `buffer` to worker 0, which
// receives into `buffer`; all other workers are no-ops. Requires exactly 4
// workers split into 2 groups.
TVM_REGISTER_GLOBAL("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0")
    .set_body_typed([](NDArray buffer) {
      CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get();
      CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4.";
      // Fix: the check is on the number of groups, so the error message must
      // say "number of groups", not "group size" (the old message was misleading).
      CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the number of groups to be 2.";
      if (ctx->worker->worker_id == 2) {
        tvm::runtime::nccl::SendToWorker(buffer, 0);
      } else if (ctx->worker->worker_id == 0) {
        tvm::runtime::nccl::RecvFromWorker(buffer, 2);
      }
    });

} // namespace nccl
} // namespace runtime
} // namespace tvm
40 changes: 39 additions & 1 deletion tests/python/disco/test_ccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import tvm
import tvm.testing
from tvm import dlight as dl
from tvm import get_global_func
from tvm import relax as rx
from tvm.runtime import disco as di
from tvm.runtime.relax_vm import VirtualMachine
from tvm.script import relax as R
from tvm import get_global_func

_all_session_kinds = [di.ThreadedSession, di.ProcessSession]
_ccl = [get_global_func("runtime.disco.compiled_ccl")()]
Expand Down Expand Up @@ -391,6 +391,44 @@ def test_group_gather(session_kind, ccl, capfd):
), "No warning messages should be generated from disco.Session.gather_to_worker0"


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_send_to_next_group_receive_from_prev_group(session_kind, ccl):
    """Workers in group 0 send their buffers to group 1; verify the payloads arrive intact."""
    devices = [0, 1, 2, 3]
    sess = session_kind(num_workers=len(devices), num_groups=2)
    sess.init_ccl(ccl, *devices)

    # Distinct payloads on the two senders (workers 0 and 1 form group 0).
    payload_w0 = np.arange(12, dtype="float32").reshape(3, 4)
    payload_w1 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
    remote_arr = sess.empty((3, 4), "float32")
    remote_arr.debug_copy_from(0, payload_w0)
    remote_arr.debug_copy_from(1, payload_w1)
    func_name = "runtime.disco." + ccl + ".test_send_to_next_group_recv_from_prev_group"
    sess.get_global_func(func_name)(remote_arr)

    # Workers 2 and 3 (group 1) must now hold the buffers sent by workers 0 and 1.
    received_w2 = remote_arr.debug_get_from_remote(2).numpy()
    received_w3 = remote_arr.debug_get_from_remote(3).numpy()
    np.testing.assert_equal(received_w2, payload_w0)
    np.testing.assert_equal(received_w3, payload_w1)


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_worker2_send_to_worker0(session_kind, ccl):
    """Worker 2 sends a buffer to worker 0 via the p2p primitives; verify it arrives intact."""
    devices = [0, 1, 2, 3]
    sess = session_kind(num_workers=len(devices), num_groups=2)
    sess.init_ccl(ccl, *devices)

    payload = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
    remote_arr = sess.empty((3, 4), "float32")
    remote_arr.debug_copy_from(2, payload)
    func_name = "runtime.disco." + ccl + ".test_worker2_sends_to_worker0"
    sess.get_global_func(func_name)(remote_arr)

    # Worker 0 must now hold the buffer that worker 2 sent.
    received = remote_arr.debug_get_from_remote(0).numpy()
    np.testing.assert_equal(received, payload)


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_mlp(session_kind, ccl): # pylint: disable=too-many-locals
Expand Down

0 comments on commit ae1be53

Please sign in to comment.