ProcessGroupBabyNCCL: support multiple streams and use event on start #91

Merged 1 commit on Jan 31, 2025
183 changes: 137 additions & 46 deletions torchft/process_group.py
@@ -19,17 +19,18 @@
import logging
import queue
import threading
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
@@ -58,9 +59,9 @@
BroadcastOptions,
ReduceOp,
Work,
_world,
)
from torch.futures import Future
from torch.utils._pytree import tree_any

if TYPE_CHECKING:
from torchft.manager import Manager
@@ -586,29 +587,52 @@ def __init__(
self._timeout = timeout

def wait(self, timeout: Optional[timedelta] = None) -> bool:
self._pg._assert_alive()

self._tx.put(("wait", self._op_id), timeout=self._timeout)
assert _get(self._rx, self._timeout) == self._op_id
op_id, event = cast(
Tuple[int, Optional[torch.cuda.Event]],
_get(self._rx, timeout or self._timeout),
)
assert op_id == self._op_id
if event is not None:
event.wait()
return True

def synchronize(self) -> None:
# TODO: No one seems to use this and NCCL wait already only waits the
# stream and is non-blocking on the CPU side so no real need for a
# separate call.
raise NotImplementedError("not implemented")

def get_future(self) -> Future[object]:
return self._pg._get_future(self._op_id)

def __del__(self) -> None:
self._tx.put(("del", self._op_id), timeout=self._timeout)


class _BabyWorkNCCL(_BabyWork):
    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
        # pyre-fixme[23]: unable to unpack into 2 values
        op_id, event = _get(self._rx, self._timeout)
        assert op_id == self._op_id
        assert isinstance(event, torch.cuda.Event)

        # Wait on Event makes the stream wait but not the CPU thread.
        event.wait()

        return True


def _is_any_cuda(obj: object) -> bool:
    """
    Returns true if any of the tensors in the object are CUDA tensors.

    Supports lists, tuples, dicts, and tensors.
    """
    return tree_any(lambda obj: isinstance(obj, torch.Tensor) and obj.is_cuda, obj)
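As a point of reference (not part of the diff), here is how the new helper behaves on nested containers. It assumes _is_any_cuda is defined as above; the CUDA branch only runs on a machine with a GPU.

import torch

# _is_any_cuda uses tree_any, so arbitrarily nested lists/tuples/dicts of
# tensors are inspected uniformly.
cpu_only = {"grads": [torch.zeros(2), torch.zeros(3)]}
assert not _is_any_cuda(cpu_only)

if torch.cuda.is_available():
    mixed = (torch.zeros(2), {"x": torch.zeros(3, device="cuda")})
    assert _is_any_cuda(mixed)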

@dataclass
class _OpMetadata:
work: Work
stream: Optional[torch.cuda.Stream]

@contextmanager
def set_stream(self) -> Generator[None, None, None]:
if self.stream is not None:
with torch.cuda.stream(self.stream):
yield
else:
yield


class ProcessGroupBaby(ProcessGroup):
@@ -617,11 +641,8 @@ class ProcessGroupBaby(ProcessGroup):
subprocess. Since it's running in a subprocess all tensors need to be in
shared memory or will be moved to shared memory. CUDA tensors are implicitly
shareable and don't need any changes.

"""

WORK_CLASS: Type[_BabyWork] = _BabyWork

def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
super().__init__(0, 1)

@@ -679,7 +700,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

self._p = ctx.Process(
target=self._worker,
args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
args=(
store_addr,
rank,
world_size,
self._tx,
self._rx,
self._future_queue,
),
daemon=True,
)
self._p.start()
@@ -716,23 +744,70 @@ def _worker(
return
tx.put(None)

work = {}
streams: Dict[str, torch.cuda.Stream] = {}
work: Dict[int, _OpMetadata] = {}
next_op_id: int = 0

while True:
op = rx.get()
cmd = op[0]
if cmd == "func":
func_name, args, kwargs = op[1:]
args = _PickleSafeOptions.unsafe_args(args)
fn = getattr(pg, func_name)
work[next_op_id] = fn(*args, **kwargs)
func_name, args, kwargs, stream_device, stream_id, event = op[1:]

# To avoid potential deadlocks we need to preserve the
# stream/synchronization behavior of the parent process.
# We allocate one Stream per stream_id to make sure that we
# don't accidentally introduce cross stream synchronization
# points.
if stream_id is not None:
stream_key = f"{stream_device}/{stream_id}"
if stream_key not in streams:
streams[stream_key] = torch.cuda.Stream(
device=stream_device
)
Contributor: Are we going to have zombie streams if there are multiple failures? Will this cause memory leakage?

Member (Author): My understanding is that streams are specific to the CUDA context/process, so they will be cleaned up just fine when the subprocess gets killed.

stream = streams[stream_key]
else:
stream = None

with (
torch.cuda.stream(stream)
if stream is not None
else nullcontext()
):
# Make the stream wait on the cuda event to make sure we
# don't start the operation until the tensor is ready.
if event is not None:
event.wait()

args = _PickleSafeOptions.unsafe_args(args)
fn = getattr(pg, func_name)
work[next_op_id] = _OpMetadata(
work=fn(*args, **kwargs),
stream=stream,
)
tx.put(next_op_id)
next_op_id += 1
elif cmd == "wait":
op_id: int = op[1]
work[op_id].wait()
tx.put(op_id)

metadata = work[op_id]

with metadata.set_stream():
# With WorkNCCL this makes the stream wait, not the CPU, when
# no timeout is passed.
metadata.work.wait()

# Register event on the stream that we can pass to the main
# process.
event = (
torch.cuda.current_stream().record_event(
torch.cuda.Event(interprocess=True)
)
if metadata.stream is not None
else None
)

tx.put((op_id, event))
elif cmd == "del":
op_id: int = op[1]
del work[op_id]
@@ -746,23 +821,8 @@ def callback(fut: Future[object]) -> None:
except Exception as e:
future_queue.put((op_id, _FUTURE_EXCEPTION, e))

work[op_id].get_future().add_done_callback(callback)
work[op_id].work.get_future().add_done_callback(callback)
tx.put(op_id)
elif cmd == "synchronize":
# CUDA only, use events instead of waiting on CPU
op_id = op[1]

# With WorkNCCL this makes the stream wait not the CPU when
# no timeout is passed.
work[op_id].wait()

# Register event on the stream that we can pass to the main
# process.
event = torch.cuda.Event(interprocess=True)
event.record()

del work[op_id]
tx.put((op_id, event))
elif cmd == "num_active_work":
tx.put(len(work))
else:
@@ -771,6 +831,7 @@ def callback(fut: Future[object]) -> None:
except Exception as e:
logger.exception("worker errored")
tx.put(e)
raise

def _future_handler(self, future_queue: mp.Queue) -> None:
try:
@@ -792,6 +853,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
logger.exception(f"got unexpected error in future handler: {e}")

def _get_future(self, op_id: int) -> Future[object]:
self._assert_alive()

with self._futures_lock:
fut = Future() # pyre-fixme[29]: is not a function
self._futures[op_id] = fut
@@ -804,22 +867,52 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
return fut

def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
self._assert_alive()

rx = self._rx
tx = self._tx
assert rx is not None
assert tx is not None

is_cuda = _is_any_cuda(args)

stream_device = torch.cuda.current_stream().device if is_cuda else None
stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
event = (
torch.cuda.current_stream().record_event(
torch.cuda.Event(interprocess=True)
)
if is_cuda
else None
)

tx.put(
("func", func, _PickleSafeOptions.safe_args(args), kwargs),
(
"func",
func,
_PickleSafeOptions.safe_args(args),
kwargs,
stream_device,
stream_id,
event,
),
timeout=self._timeout,
)

op_id = _get(rx, self._timeout)
assert isinstance(op_id, int), f"invalid return {op_id}"

return self.WORK_CLASS(
pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
)
return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)

def _assert_alive(self) -> None:
"""
Assert that the process group is alive. This is used to ensure that
operations are not performed on a dead process group and any errors are surfaced.
"""
p = self._p
assert p is not None
if not p.is_alive():
raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")

def allreduce(
self,
@@ -952,8 +1045,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
tensors may leak in the current PyTorch implementation. TODO fix
"""

WORK_CLASS = _BabyWorkNCCL

@classmethod
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
# pyre-fixme[16]: no attribute ProcessGroupNCCL
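To summarize the stream/event handshake this diff introduces, here is a minimal single-process sketch. It is illustrative only: the real implementation crosses a process boundary using interprocess=True events and torch.multiprocessing queues, and it assumes a CUDA device is available.

import torch

parent_stream = torch.cuda.current_stream()
worker_stream = torch.cuda.Stream()  # stands in for streams[f"{device}/{stream_id}"]

x = torch.ones(1024, device="cuda")

# Parent side (_run_func): record a "tensor is ready" event on the issuing stream.
ready = parent_stream.record_event(torch.cuda.Event())

with torch.cuda.stream(worker_stream):
    # Worker side ("func"): don't start until the parent's pending kernels finish.
    ready.wait()
    y = x * 2  # stand-in for the collective launched by fn(*args, **kwargs)
    # Worker side ("wait"): record a completion event to hand back to the parent.
    done = worker_stream.record_event(torch.cuda.Event())

# Parent side (_BabyWork.wait): make the parent's stream wait on the completion
# event without blocking the CPU thread.
done.wait()
print(y.sum().item())  # safe to consume once the ordering above is in place

The per-stream_id mapping in the worker exists so that operations issued on different parent streams land on different worker streams and do not serialize against each other.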
34 changes: 31 additions & 3 deletions torchft/process_group_test.py
@@ -266,6 +266,31 @@ def test_baby_gloo_apis(self) -> None:

self.assertEqual(a.num_active_work(), 0)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
@skipUnless(torch.cuda.is_available(), "needs CUDA")
def test_baby_nccl_apis(self) -> None:
# use device 1 if there are >= 2 GPUs, otherwise device 0
device_id = 1 % torch.cuda.device_count()
torch.cuda.set_device(device_id)

store = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

store_addr = f"localhost:{store.port}/prefix"

a = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
a.configure(store_addr, 0, 1)

_test_pg(a, torch.randn((2, 3), device="cuda"))

torch.cuda.synchronize()

# force collection to ensure no BabyWork objects remain
gc.collect()

self.assertEqual(a.num_active_work(), 0)

def test_dummy(self) -> None:
pg = ProcessGroupDummy(0, 1)
m = nn.Linear(3, 4)
@@ -282,12 +307,15 @@ def test_baby_nccl_2gpu(self) -> None:
store_addr: str = f"localhost:{store.port}/prefix"

def run(rank: int) -> Tuple[torch.Tensor, Work]:
a = ProcessGroupBabyNCCL()
a = ProcessGroupBabyNCCL(
timeout=timedelta(seconds=10.0),
)
a.configure(store_addr, rank, 2)

self.assertEqual(a.size(), 2)

at = torch.tensor([rank + 1], device=f"cuda:{rank}")
# We test using set_device to ensure stream device is correct.
torch.cuda.set_device(rank)
at = torch.tensor([rank + 1], device="cuda")

a_work = a.allreduce([at], ReduceOp.SUM)
return at, a_work
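For context, a minimal end-to-end usage sketch modeled on test_baby_nccl_apis above: single rank, assumes a CUDA device, and the store address/prefix are illustrative.

from datetime import timedelta

import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.process_group import ProcessGroupBabyNCCL

store = TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

pg = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
pg.configure(f"localhost:{store.port}/prefix", rank=0, world_size=1)

t = torch.ones(2, 3, device="cuda")
work = pg.allreduce([t], ReduceOp.SUM)
work.wait()  # makes the current CUDA stream wait on the event returned by the worker
torch.cuda.synchronize()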