
Adds more collectives to ProcessGroups #108

Merged: 21 commits merged into pytorch:main on Feb 12, 2025

Conversation


@allenwang28 allenwang28 commented Feb 11, 2025

What does this PR do?

Continuation of work for #97.

This PR:

  • Adds more collectives (a rough usage sketch follows this list):
    • allreduce_coalesced
    • alltoall_base
    • barrier
    • reduce_scatter
    • send/recv
  • Extends process_group_test.py to accommodate the above collectives
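
For orientation, here is a rough, hypothetical sketch (not the actual test code in this PR) of how these collectives are driven through a c10d-style ProcessGroup object; option types and the exact torchft wrapper signatures may differ:

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import (
    AllreduceCoalescedOptions,
    AllToAllOptions,
    ReduceScatterOptions,
)


def exercise_new_collectives(pg: dist.ProcessGroup, rank: int, world_size: int) -> None:
    """Drive the collectives added in this PR against an already-initialized group.

    Assumes ``pg`` has at least two ranks and a backend that implements all of
    these collectives; tensors here live on CPU and would need to be moved to
    the appropriate device for a GPU backend such as NCCL. Every rank must call
    this function.
    """
    t = torch.ones(4)

    # allreduce_coalesced: reduce several tensors with a single call
    pg.allreduce_coalesced([t, t.clone()], AllreduceCoalescedOptions()).wait()

    # alltoall_base: each rank exchanges equal-sized chunks with every other rank
    out = torch.empty(4 * world_size)
    pg.alltoall_base(out, torch.ones(4 * world_size), [], [], AllToAllOptions()).wait()

    # reduce_scatter: reduce a list of inputs, each rank keeps one output shard
    shard = torch.empty(4)
    inputs = [torch.ones(4) for _ in range(world_size)]
    pg.reduce_scatter([shard], [inputs], ReduceScatterOptions()).wait()

    # barrier, then a simple send/recv pair between ranks 0 and 1
    pg.barrier().wait()
    if rank == 0:
        pg.send([t], 1, 0).wait()
    elif rank == 1:
        pg.recv([t], 0, 0).wait()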

Note - the process_group_test adds several new tests, increasing the test time from 16s to 22s. Not sure how avoidable this is - a prior version of this PR took 36s!

Concerns and possible follow-up actions

Missing ops in backends

Notably allgather_into_tensor_coalesced and reduce_scatter_tensor_coalesced were not added in this PR.

They were part of the plan, but I started seeing errors such as:

E       AttributeError: 'torch._C._distributed_c10d.ProcessGroupNCCL' object has no attribute 'allgather_into_tensor_coalesced'

This was confusing, but sure enough it's true: if we run dir(pg) we see:

['NCCLConfig', 'Options', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_add_ephemeral_timeout', '_allgather_base', '_end_coalescing', '_get_backend_name', '_get_sequence_number_for_group', '_group_end', '_group_start', '_is_initialized', '_pybind11_conduit_v1_', '_reduce_scatter_base', '_set_default_timeout', '_set_sequence_number_for_group', '_shutdown', '_start_coalescing', '_verify_work_timeout', 'abort', 'allgather', 'allgather_coalesced', 'allreduce', 'allreduce_coalesced', 'alltoall', 'alltoall_base', 'barrier', 'bound_device_id', 'broadcast', 'comm_split_count', 'deregister_mem_pool', 'eager_connect_single_device', 'gather', 'monitored_barrier', 'name', 'options', 'perform_nocolor_split', 'rank', 'recv', 'recv_anysource', 'reduce', 'reduce_scatter', 'register_mem_pool', 'scatter', 'send', 'size', 'supports_splitting', 'uid']

After some digging I realized it is because ProcessGroupNCCL (and the other pg backends) inherit the collectives defined here, which do not include allgather_into_tensor_coalesced or reduce_scatter_tensor_coalesced. This is confusing since we know that, e.g., allgather_into_tensor_coalesced is implemented for the ProcessGroupNCCL backend.
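
A minimal way to reproduce the observation (a hypothetical snippet, assuming a CUDA-enabled PyTorch build with NCCL and at least one GPU; the store host/port are arbitrary):

import torch.distributed as dist

# Build a single-rank ProcessGroupNCCL directly, only to inspect its Python attribute surface.
store = dist.TCPStore("localhost", 29501, world_size=1, is_master=True)
pg = dist.ProcessGroupNCCL(store, 0, 1)

for name in (
    "allreduce_coalesced",               # exposed through the base ProcessGroup bindings
    "reduce_scatter",
    "allgather_into_tensor_coalesced",   # implemented by the NCCL backend in C++, but not bound here
    "reduce_scatter_tensor_coalesced",
):
    print(f"{name}: {hasattr(pg, name)}")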

There are likely several possible resolutions for this, so I will defer it to a future decision.

@facebook-github-bot facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Feb 11, 2025
@@ -122,6 +137,122 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2]
return works


def _test_multi_pg(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
    """
    Helper function to test a set of collective operations in settings with multiple
Member

IMO I'd split these up into individual collective tests, so that issues can be more easily isolated and users can clearly understand which collectives are supported on which fault tolerant process groups.

Contributor Author

Yeah I agree with the sentiment! I previously tried something like this in #103 with _test_pg, but it made the execution time much longer since there were now (# of collectives) times as many process groups that needed to be spun up and torn down. That approach also added complexity that I wasn't comfortable with.

Ideally I could represent it like this:

class MultiPgTest:
  def test_all_gather(self):
    tensor_list = ...
    ...
  
  def test_allreduce(self):
    ...

but I'm not sure how I could re-use process groups in this way. Open to any suggestions though!

Member

@allenwang28 Couldn't you just create the process groups & subprocesses in a setUpClass method, which would allow them to be re-used across unit tests?

TestDistBackend and MultiProcessTestCase in PT-D might provide some pointers for doing something like this: https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/distributed/distributed_test.py#L566
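
A rough sketch of that suggestion (hypothetical class name and port; a single-rank gloo group stands in for the real multi-rank setup just to show the reuse pattern):

import unittest

import torch
import torch.distributed as dist


class ReusedPGCollectiveTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        # Build the process group once and share it across all test methods.
        # A real test would spawn one subprocess per rank and do this in each rank.
        store = dist.TCPStore("localhost", 29502, world_size=1, is_master=True)
        cls.pg = dist.ProcessGroupGloo(store, 0, 1)

    def test_allreduce(self) -> None:
        t = torch.ones(4)
        self.pg.allreduce([t]).wait()
        torch.testing.assert_close(t, torch.ones(4))

    def test_barrier(self) -> None:
        self.pg.barrier().wait()


if __name__ == "__main__":
    unittest.main()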


FYI, there is a test class called MultiProcContinuousTest, defined in the same test utils file as MultiProcessTestCase, that shares a PG across test instances. It requires having main defined differently for that test file and isn't currently compatible with having MultiProcessTestCase instances in the same file, but it is in use in a number of PT-D tests because it saves a lot of time.

Member

Another option would be to use subTest to cleanly distinguish between the different scenarios being tested

I don't love the MultiProc tests since they're non-standard and pretty unintuitive in how they work
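
For reference, the subTest pattern being suggested looks roughly like this (a hypothetical, self-contained illustration; the lambdas stand in for real collective calls against a shared pg):

import unittest

import torch


class SubTestSketch(unittest.TestCase):
    def test_collectives(self) -> None:
        # Each scenario gets its own subTest label, so a failure report names the
        # collective that broke without needing a separate process group per test.
        scenarios = {
            "allreduce": lambda: torch.ones(4) * 2,
            "barrier": lambda: torch.ones(4),
        }
        for name, run in scenarios.items():
            with self.subTest(collective=name):
                self.assertTrue(torch.is_tensor(run()))


if __name__ == "__main__":
    unittest.main()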

Contributor Author

I tried adding subTest to _test_multi_pg to see if it would delineate the scenarios being tested, but unfortunately it doesn't seem to play well with ThreadPoolExecutor. subTest could be a good approach for _test_pg though.

# Test send/recv
if rank == 0:
    send_tensor = tensor.clone()
    send_work = pg.send([send_tensor], 1, 0)
Member

Could we also test the actual fault tolerance, i.e. sender fails, sender succeeds but receiver fails, and verify that an appropriate exception / timeout is returned back to the user process?

And ideally, this is then retried with success after the PG gets reconfigured.

Member

We do have some additional tests that specifically test failure scenarios but it's not a bad idea to test timeouts here

@allenwang28 (Contributor Author)

@rohan-varma - these are great suggestions and I agree that these enhancements would improve the testing within torchft. However this is out of scope for this specific PR so I have created #109 to track this request. Thank you for the great feedback!

@allenwang28 allenwang28 marked this pull request as ready for review February 12, 2025 21:48
@allenwang28 allenwang28 requested a review from d4l3k February 12, 2025 21:56
@allenwang28 allenwang28 requested a review from wconstab February 12, 2025 22:16
@d4l3k d4l3k (Member) left a comment

Code looks good aside from the small nits other folks already pointed out.


@allenwang28 allenwang28 merged commit ca8c540 into pytorch:main Feb 12, 2025
6 checks passed
@allenwang28 allenwang28 deleted the coll branch February 12, 2025 23:22