train_ddp, process_group: fixes so CUDA works e2e
d4l3k committed Nov 3, 2024
1 parent 5d2e55f commit 808e6fb
Showing 2 changed files with 96 additions and 6 deletions.
94 changes: 90 additions & 4 deletions torchft/process_group.py
@@ -8,6 +8,7 @@
import logging
from typing import Type, List, Optional, Callable, Tuple
from datetime import timedelta
import threading

from torch.futures import Future
from torch.distributed import (
@@ -26,6 +27,11 @@

logger = logging.getLogger(__name__)

# TODO: use non-string sentinels, which are cheaper
_QUEUE_CLOSE = "queue_close"
_FUTURE_RESULT = "fut_result"
_FUTURE_EXCEPTION = "fut_exception"


def _get(queue: mp.Queue, timeout) -> object:
v = queue.get(timeout=timeout)
@@ -208,9 +214,17 @@ def getBackendName(self):


class BabyWork(Work):
def __init__(self, tx: mp.Queue, rx: mp.Queue, op_id: int, timeout: float):
def __init__(
self,
pg: "ProcessGroupBaby",
tx: mp.Queue,
rx: mp.Queue,
op_id: int,
timeout: float,
):
super().__init__()

self._pg = pg
self._tx = tx
self._rx = rx
self._op_id = op_id
@@ -221,6 +235,9 @@ def wait(self) -> bool:
assert _get(self._rx, self._timeout) == self._op_id
return True

def get_future(self) -> Future:
return self._pg._get_future(self._op_id)


class BabyWorkNCCL(BabyWork):
def wait(self) -> bool:
@@ -255,6 +272,8 @@ def __init__(self, timeout: float = 60.0) -> None:
self._p = None
self._tx = None
self._rx = None
self._future_queue = None
self._future_thread = None

self._timeout = timeout

@@ -264,20 +283,46 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

self._world_size = world_size

if self._tx is not None:
self._tx.close()
if self._rx is not None:
self._rx.close()
if self._future_queue is not None:
self._future_queue.put(_QUEUE_CLOSE)
self._future_queue.close()

ctx = mp.get_context("spawn")
self._tx = ctx.Queue()
self._rx = ctx.Queue()

# futures need a thread to fire their callbacks
self._future_queue = ctx.Queue()
# this lock needs to be held when manipulating _futures
self._futures_lock = threading.Lock()
self._futures = {}
self._future_thread = threading.Thread(
target=self._future_handler,
args=(self._future_queue,),
daemon=True,
)
self._future_thread.start()

self._p = ctx.Process(
target=self._worker,
args=(store_addr, rank, world_size, self._tx, self._rx),
args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
daemon=True,
)
self._p.start()

@classmethod
def _worker(
cls, store_addr: str, rank: int, world_size: int, rx: mp.Queue, tx: mp.Queue
cls,
store_addr: str,
rank: int,
world_size: int,
rx: mp.Queue,
tx: mp.Queue,
future_queue: mp.Queue,
) -> None:
try:
store = create_store(store_addr)
@@ -300,6 +345,18 @@ def _worker(
work[op_id].wait()
del work[op_id]
tx.put(op_id)
elif cmd == "future":
op_id = op[1]

def callback(fut: Future):
try:
fut.wait()
future_queue.put((op_id, _FUTURE_RESULT, None))
except Exception as e:
future_queue.put((op_id, _FUTURE_EXCEPTION, e))

work[op_id].get_future().add_done_callback(callback)
tx.put(op_id)
elif cmd == "synchronize":
# CUDA only, use events instead of waiting on CPU
op_id = op[1]
@@ -322,12 +379,41 @@ def _worker(
logger.exception("worker errored")
tx.put(e)

def _future_handler(self, future_queue: mp.Queue) -> None:
try:
while True:
cmd = future_queue.get()
if cmd == _QUEUE_CLOSE:
break
op_id, mode, data = cmd
with self._futures_lock:
fut = self._futures[op_id]
del self._futures[op_id]
if mode == _FUTURE_RESULT:
fut.set_result(data)
elif mode == _FUTURE_EXCEPTION:
fut.set_exception(data)
else:
raise ValueError(f"unknown mode {mode}")
except Exception as e:
logger.exception(f"got unexpected error in future handler: {e}")

def _get_future(self, op_id: int) -> Future:
with self._futures_lock:
fut = Future()
self._futures[op_id] = fut
self._tx.put(("future", op_id), timeout=self._timeout)

assert _get(self._rx, self._timeout) == op_id
# TODO: return correct tensor instead of None
return fut

def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
self._tx.put(("func", func, args, kwargs), timeout=self._timeout)
op_id = _get(self._rx, self._timeout)
assert isinstance(op_id, int), f"invalid return {op_id}"
return self.WORK_CLASS(
tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout
pg=self, tx=self._tx, rx=self._rx, op_id=op_id, timeout=self._timeout
)

def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
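The main addition in torchft/process_group.py above is a future-handler: the parent process keeps a dict of pending torch.futures.Future objects guarded by a lock, a daemon thread drains a multiprocessing queue, and the worker subprocess feeds that queue with (op_id, _FUTURE_RESULT/_FUTURE_EXCEPTION, data) tuples so the matching future gets completed. As a standalone illustration of that pattern, here is a minimal sketch using only the standard library — concurrent.futures.Future stands in for torch.futures.Future, and every name below (FutureRouter, register, result_queue) is hypothetical rather than part of torchft:

import threading
import multiprocessing as mp
from concurrent.futures import Future

_QUEUE_CLOSE = "queue_close"

class FutureRouter:
    """Completes locally held Futures from results arriving on an mp.Queue."""

    def __init__(self) -> None:
        ctx = mp.get_context("spawn")
        self._queue = ctx.Queue()
        self._lock = threading.Lock()  # guards _futures, as in ProcessGroupBaby
        self._futures = {}
        self._thread = threading.Thread(target=self._handler, daemon=True)
        self._thread.start()

    def register(self, op_id: int) -> Future:
        # Called on the requesting side before the worker reports completion.
        with self._lock:
            fut = Future()
            self._futures[op_id] = fut
        return fut

    def result_queue(self) -> mp.Queue:
        # The worker side puts (op_id, value_or_exception) or the close sentinel here.
        return self._queue

    def close(self) -> None:
        self._queue.put(_QUEUE_CLOSE)
        self._thread.join()

    def _handler(self) -> None:
        # Daemon thread: complete the pending future for each op_id that arrives.
        while True:
            item = self._queue.get()
            if item == _QUEUE_CLOSE:
                break
            op_id, value = item
            with self._lock:
                fut = self._futures.pop(op_id)
            if isinstance(value, Exception):
                fut.set_exception(value)
            else:
                fut.set_result(value)

if __name__ == "__main__":
    router = FutureRouter()
    fut = router.register(op_id=0)
    router.result_queue().put((0, None))  # normally done by the worker process
    fut.result(timeout=5)                 # unblocks once the handler sets the result
    router.close()

The lock matters because register() runs on the caller's thread while the handler thread pops entries; the sentinel gives the handler a clean shutdown path, mirroring _QUEUE_CLOSE in the diff above.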
8 changes: 6 additions & 2 deletions torchft/process_group_test.py
@@ -37,6 +37,7 @@ def test_gloo(self) -> None:

a_work = pg.allreduce([at], ReduceOp.SUM)
a_work.wait()
a_work.get_future().wait()

m = nn.Linear(3, 4)
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -58,6 +59,7 @@ def test_nccl(self) -> None:
at = torch.tensor([2], device=device)
a_work = pg.allreduce([at], ReduceOp.SUM)
a_work.wait()
a_work.get_future().wait()

m = nn.Linear(3, 4).to(device)
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -95,7 +97,9 @@ def test_baby_gloo(self) -> None:
b_work = b.allreduce([bt], ReduceOp.SUM)

a_work.wait()
b_work.wait()
fut = b_work.get_future()

fut.wait()

torch.testing.assert_close(at, bt)

@@ -130,6 +134,6 @@ def test_baby_nccl(self) -> None:
b_work = b.allreduce([bt], ReduceOp.SUM)

a_work.wait()
b_work.wait()
b_work.get_future().wait()

torch.testing.assert_close(at, bt)
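Since Work.get_future() now returns a torch.futures.Future, a caller can also chain a continuation instead of blocking, which is roughly what DDP's reducer does with its allreduce buckets. A hedged sketch, assuming a pg created and configured the same way as in test_baby_gloo above (the callback name is made up for illustration):

import torch
from torch.distributed import ReduceOp

# Assumes `pg` is a ProcessGroupBabyGloo already configured as in test_baby_gloo.
t = torch.tensor([2.0])
work = pg.allreduce([t], ReduceOp.SUM)

def on_done(fut: torch.futures.Future) -> torch.Tensor:
    # With ProcessGroupBaby the future's value is currently None
    # (see the TODO in _get_future), so return the reduced tensor instead.
    return t

chained = work.get_future().then(on_done)
chained.wait()  # t should now hold the summed result, as test_baby_gloo asserts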
