Adds more collectives to ProcessGroups #108
Conversation
```python
@@ -122,6 +137,122 @@ def check_tensors(arg: Any) -> None:  # pyre-ignore[2]
    return works


def _test_multi_pg(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
    """
    Helper function to test a set of collective operations in settings with multiple
```
IMO I'd split these up into individual collective tests, so that issues can be isolated more easily and users can clearly see which collectives are supported on which fault-tolerant process groups.
Yeah, I agree with the sentiment! I previously tried something like this in #103 with `_test_pg`, but it made the execution time much longer, since there were then (# of collectives) times as many process groups to spin up and tear down. That approach also introduced more complexity than I was comfortable with.
Ideally I could represent it like this:

```python
class MultiPgTest:
    def test_all_gather(self):
        tensor_list = ...
        ...

    def test_allreduce(self):
        ...
```
but I'm not sure how I could re-use process groups in this way. Open to any suggestions though!
@allenwang28 Couldn't you just create the process groups & subprocesses in a `setUpClass` method, which would allow them to be re-used across unit tests?
`TestDistBackend` and `MultiProcessTestCase` in PT-D might provide some pointers for doing something like this: https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/distributed/distributed_test.py#L566
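
For illustration, here is a rough sketch of that idea. Everything below is a hypothetical harness, not torchft's actual test setup: it uses in-process gloo groups over a `HashStore` with threads standing in for subprocesses.

```python
# Sketch: build per-rank process groups once in setUpClass, then reuse
# them across tests instead of spinning them up per collective.
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
import unittest

import torch
import torch.distributed as dist


class MultiPgTest(unittest.TestCase):
    WORLD_SIZE = 2

    @classmethod
    def setUpClass(cls) -> None:
        store = dist.HashStore()  # in-process store, no ports needed
        cls.executor = ThreadPoolExecutor(max_workers=cls.WORLD_SIZE)

        def make_pg(rank: int) -> dist.ProcessGroup:
            return dist.ProcessGroupGloo(
                store, rank, cls.WORLD_SIZE, timedelta(seconds=30)
            )

        # Rendezvous happens inside the constructor, so the groups must be
        # created concurrently (sequential construction would deadlock).
        cls.pgs = list(cls.executor.map(make_pg, range(cls.WORLD_SIZE)))

    @classmethod
    def tearDownClass(cls) -> None:
        cls.executor.shutdown()

    def _run_on_all_ranks(self, fn) -> None:
        futures = [
            self.executor.submit(fn, self.pgs[rank], rank)
            for rank in range(self.WORLD_SIZE)
        ]
        for fut in futures:
            fut.result()  # re-raise any per-rank failure

    def test_allreduce(self) -> None:
        def body(pg: dist.ProcessGroup, rank: int) -> None:
            t = torch.full((4,), float(rank + 1))
            pg.allreduce([t]).wait()  # default reduce op is SUM
            torch.testing.assert_close(t, torch.full((4,), 3.0))

        self._run_on_all_ranks(body)


if __name__ == "__main__":
    unittest.main()
```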
FYI there is a test class called `MultiProcContinuousTest` defined in the same test utils file as `MultiProcessTestCase` that shares a PG across test instances. It requires having `main` defined differently for that test file and currently isn't compatible with having `MultiProcessTestCase` instances in the same file, but it is in use in a number of PT-D tests because it saves a lot of time.
Another option would be to use `subTest` to cleanly distinguish between the different scenarios being tested.
I don't love the MultiProc tests since they're non-standard and pretty unintuitive in how they work.
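
For example, a minimal sketch of the `subTest` idea (using a world-size-1 gloo group so it runs standalone; the shapes are just stand-ins for different scenarios):

```python
# Sketch: subTest labels each scenario within one test method, so a
# failure report names the scenario without tearing down the group.
from datetime import timedelta
import unittest

import torch
import torch.distributed as dist


class SubTestSketch(unittest.TestCase):
    def test_allreduce_variants(self) -> None:
        pg = dist.ProcessGroupGloo(dist.HashStore(), 0, 1, timedelta(seconds=10))
        for shape in ((4,), (2, 3), (8, 1)):
            with self.subTest(shape=shape):
                t = torch.ones(shape)
                pg.allreduce([t]).wait()
                # world size 1: allreduce should leave the tensor unchanged
                torch.testing.assert_close(t, torch.ones(shape))
```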
I tried adding `subTest` into `_test_multi_pg` to see if it would delineate the scenarios being tested, but unfortunately it doesn't seem to play well with `ThreadPoolExecutor`. `subTest` could be a good approach for `_test_pg`, though.
```python
# Test send/recv
if rank == 0:
    send_tensor = tensor.clone()
    send_work = pg.send([send_tensor], 1, 0)
```
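
For context, the matching receive on the other rank presumably continues the snippet above along these lines (a sketch, not the exact lines from the diff):

```python
elif rank == 1:
    recv_tensor = torch.zeros_like(tensor)
    recv_work = pg.recv([recv_tensor], 0, 0)  # recv from rank 0, tag 0
    recv_work.wait()
    # After completion, recv_tensor should equal the tensor sent by rank 0.
```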
Could we also test the actual fault tolerance, i.e. the sender fails, or the sender succeeds but the receiver fails, and verify that an appropriate exception / timeout is returned to the user process?
And ideally, verify that the operation is then retried with success after the PG gets reconfigured.
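
As a rough illustration of the "sender fails" case (a hedged sketch using plain gloo primitives in one threaded process; this is an assumption about the harness, not torchft's actual failure tests):

```python
# Sketch: the receiver should surface a timeout/error instead of
# hanging when its peer never sends.
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta

import torch
import torch.distributed as dist

store = dist.HashStore()
TIMEOUT = timedelta(seconds=2)


def receiver() -> None:
    pg = dist.ProcessGroupGloo(store, 0, 2, TIMEOUT)
    t = torch.zeros(4)
    work = pg.recv([t], 1, 0)
    try:
        work.wait()
        raise AssertionError("expected a failure: the sender never sent")
    except RuntimeError as e:
        print(f"receiver saw the expected error: {e}")


def dead_sender() -> None:
    pg = dist.ProcessGroupGloo(store, 1, 2, TIMEOUT)
    # Simulate a failed sender: join the group, then stay silent until
    # the receiver's timeout has fired.
    time.sleep(3)


with ThreadPoolExecutor(max_workers=2) as ex:
    for fut in [ex.submit(receiver), ex.submit(dead_sender)]:
        fut.result()
```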
We do have some additional tests that specifically exercise failure scenarios, but it's not a bad idea to test timeouts here.
@rohan-varma - these are great suggestions, and I agree that these enhancements would improve the testing within torchft. However, this is out of scope for this specific PR, so I have created #109 to track the request. Thank you for the great feedback!
Code looks good aside from the small nits other folks already pointed out.
What does this PR do?

Continuation of work for #97.

This PR:
- adds more collectives to the `ProcessGroup` implementations (per the PR title)
- updates `process_group_test.py` to accommodate the above collectives

Note - `process_group_test` adds several new tests, increasing the test time from 16s to 22s. Not sure how avoidable this is - a prior version of this PR took 36s!

Concerns and possible follow-up actions
Missing ops in backends

Notably, `allgather_into_tensor_coalesced` and `reduce_scatter_tensor_coalesced` were not added in this PR. Adding them was part of the plan, but I started seeing errors reporting that the op does not exist on the process group. This was confusing, but sure enough it is true: inspecting `dir(pg)` shows neither op.

After some digging I realized it was because `ProcessGroupNCCL` (and other pg backends) inherit the collectives that are defined here, which do not include `allgather_into_tensor_coalesced` or `reduce_scatter_tensor_coalesced`. This is confusing, since we know that e.g. `allgather_into_tensor_coalesced` is implemented for the `ProcessGroupNCCL` backend.

There are likely several possible resolutions for this, so I will defer the decision to a future follow-up.