TorchRec 2D Parallel #2554
Conversation
This pull request was exported from Phabricator. Differential Revision: D61643328
Summary:
In this diff we introduce a new parallelism strategy for scaling recommendation model training, called 2D parallel. In this case, we scale model parallel through data parallel, hence the name. This diff enables the pathway to scaling training to 4k+ GPUs.
Our new entry point, DMPCollection, subclasses DMP and is meant to be a drop-in replacement for integrating 2D parallelism into distributed training. By setting the total number of GPUs to train across and the number of GPUs to locally shard across (i.e., one replication group), users can train their models in the same training loop, but now over a larger number of GPUs.
The current implementation shards the model such that, for a given shard, its replicated shards lie on the ranks within the same node. This significantly improves the performance of the all-reduce communication (parameter sync) by utilizing intra-node bandwidth. Under this scheme the supported sharding types are RW, CW, and GRID; TWRW is not supported, since it can no longer take advantage of intra-node bandwidth in the 2D scheme.
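To make the drop-in usage concrete, here is a minimal sketch of wrapping a model with DMPCollection. This is not code from this PR: the keyword arguments `world_size` and `sharding_group_size`, and the import path, are assumptions inferred from the description above rather than the confirmed constructor signature.

```python
# Minimal sketch (assumptions, not verbatim from this PR): wrap a TorchRec
# model with DMPCollection instead of DMP. The kwargs `world_size` and
# `sharding_group_size` are inferred from the summary, not the confirmed API.
import torch
import torch.distributed as dist
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed.model_parallel import DMPCollection  # import path assumed

dist.init_process_group(backend="nccl")
device = torch.device("cuda", torch.cuda.current_device())

# A toy embedding model; any TorchRec model is wrapped the same way.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1",
            embedding_dim=64,
            num_embeddings=1_000_000,
            feature_names=["f1"],
        )
    ],
    device=torch.device("meta"),
)

# Before (plain model parallel over all ranks):
#   model = DistributedModelParallel(module=ebc, device=device)
# After (2D parallel): shard within a group of 4 ranks and replicate that
# sharding across the remaining ranks.
model = DMPCollection(
    module=ebc,
    device=device,
    world_size=dist.get_world_size(),  # total number of GPUs to train across
    sharding_group_size=4,             # GPUs to locally shard across (one group)
)
# The training loop itself is unchanged from the DMP case.
```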
Example Use Case:
Consider a setup with 2 nodes, each with 4 GPUs. The sharding groups could be:
- Group 0, DMP 0: [0, 2, 4, 6]
- Group 1, DMP 1: [1, 3, 5, 7]
Each group receives an identical sharding plan for its local world size and ranks. If we have one table sharded in each DMP, with one shard on each rank in the group, each shard in DMP 0 will have a duplicate shard on its corresponding rank in DMP 1. The replication groups would be: [0, 1], [2, 3], [4, 5], [6, 7].
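The grouping in this example can be derived with simple rank arithmetic. The following standalone snippet (illustrative only, not code from this PR) reproduces the sharding and replication groups for the 2-node, 4-GPU-per-node layout above.

```python
# Illustration only: derive the example's sharding and replication groups
# from the world size and the number of GPUs to shard across locally.
world_size = 8            # 2 nodes x 4 GPUs
sharding_group_size = 4   # ranks per sharding (DMP) group
num_groups = world_size // sharding_group_size  # replication factor = 2

# Sharding groups: ranks strided by the replication factor, so that a shard's
# replicas land on adjacent ranks within the same node.
sharding_groups = [
    list(range(g, world_size, num_groups)) for g in range(num_groups)
]
# -> [[0, 2, 4, 6], [1, 3, 5, 7]]

# Replication groups: consecutive ranks that hold copies of the same shard.
replication_groups = [
    list(range(r, r + num_groups)) for r in range(0, world_size, num_groups)
]
# -> [[0, 1], [2, 3], [4, 5], [6, 7]]

print(sharding_groups)
print(replication_groups)
```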
NOTE: We have to pass the global process group to the DDPWrapper; otherwise, the optimizer will not be applied to some of the unsharded parameters, which will result in numerically inaccurate results.

Reviewed By: dstaay-fb

Differential Revision: D61643328