A FusionDefinition wrapper that takes/produces DTensors. #3703
Conversation
!test
Cc @jjsjann123
!test
cc @syed-ahmed FYI, I made nvFuser return output shardings as well as local tensors.
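A minimal sketch of how a caller could combine the returned local tensors with the reported output shardings to rebuild DTensors. The helper name wrap_outputs and the output_axes argument are hypothetical; it assumes a 1-D mesh and that each output's sharding is reported as the sharded axis, or -1 for replicated, as in the placements snippet reviewed further down.

from torch.distributed.tensor import DTensor, Replicate, Shard

def wrap_outputs(local_tensors, output_axes, mesh):
    # Pair each rank-local output with the sharding nvFuser reports for it
    # and rebuild a DTensor on the (1-D) mesh. axis == -1 means replicated;
    # otherwise the output is sharded along `axis` of the logical tensor.
    dtensors = []
    for local, axis in zip(local_tensors, output_axes):
        placement = Replicate() if axis == -1 else Shard(axis)
        dtensors.append(DTensor.from_local(local, mesh, [placement]))
    return dtensors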
!test
LGTM.
# Bind each process to the GPU matching its rank.
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Global input: one row per device along the dimension to be sharded.
in_tensor = torch.randn(num_devices, 4)
QQ: are we still required in nvFuser to have the shard dimension's size equal to the number of shards?
We are gradually adding DID loop split, so eventually this won't be a requirement. See test_communication.py and the several test_*_loop_split tests in test_multidevice.py recently added by @Priya2698.
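To make the current constraint concrete, a small illustrative sketch (the shapes are invented, not taken from the PR): a dimension sharded across N devices must currently have extent N, so each rank holds a slice of extent 1 along it.

import torch

num_shards = 4
# Dim 0 is sharded across 4 ranks, so its extent must currently be 4.
global_tensor = torch.randn(num_shards, 4)
# Rank r holds the extent-1 slice at index r.
local_shards = [global_tensor[r:r + 1] for r in range(num_shards)]
assert all(shard.shape == (1, 4) for shard in local_shards)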
placements: list[Placement] = []
for parallel_type in [nvfuser.ParallelType.mesh_x]:
    # axis_sharded_on returns -1 when the tensor is not sharded on this
    # parallel type, i.e. it is replicated along that mesh dimension.
    axis: int = out_tensor.axis_sharded_on(parallel_type)
    placements.append(Replicate() if axis == -1 else Shard(axis))
I guess it doesn't matter here since we only have a 1-D device mesh for now, but the logic doesn't feel right in this for loop with for parallel_type in [nvfuser.ParallelType.mesh_x]:. I.e., I think we need to check the rank of out_tensor.mesh, which is only 1-D, so we can't really write a future-proof thing here.
^^^ I realize those aren't really constructive comments; I'm just trying to point out that we might need to put more thought into how we want to expose out_tensor.mesh. 😜
Agreed. When nvFuser supports a >1-D mesh, we may want DeviceMesh to hold a torch.Tensor, similar to https://github.com/pytorch/pytorch/blob/3917053f63b75f14e3cb2f53805fad4ade5363df/torch/distributed/device_mesh.py#L420.
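A rough sketch of that idea, purely hypothetical and mirroring the linked PyTorch DeviceMesh: backing the mesh with a torch.Tensor of device IDs would give one mesh dimension per potential sharding, so the placements loop above could iterate over mesh dimensions instead of the single ParallelType.mesh_x.

from dataclasses import dataclass

import torch

@dataclass
class DeviceMesh:
    # Hypothetical: an n-D tensor of device IDs, e.g.
    # torch.arange(4).reshape(2, 2) for a 2x2 mesh.
    mesh: torch.Tensor

    @property
    def ndim(self) -> int:
        return self.mesh.ndim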
LGTM. Just some comments on FusionDefinitionWrapper.
!test
!test
This is a proof of concept for integrating nvFuser's model parallelism into the framework.
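For context, a hedged end-to-end usage sketch from the caller's side. Only the name FusionDefinitionWrapper comes from this PR; the import path, call signature, and define_fusion callback are assumptions, so the wrapper call is left commented out.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# Assumed import; the real location is whatever this PR defines.
# from nvfuser import FusionDefinitionWrapper

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
world_size = dist.get_world_size()

mesh = init_device_mesh("cuda", (world_size,))
# Global (world_size, 4) tensor sharded on dim 0: one row per rank.
in_dtensor = DTensor.from_local(
    torch.randn(1, 4, device="cuda"), mesh, [Shard(0)]
)

# Hypothetical call: the wrapper takes DTensors, runs the fusion on each
# rank's local shard, and returns outputs wrapped back into DTensors.
# (out_dtensor,) = FusionDefinitionWrapper(define_fusion)(in_dtensor)

dist.destroy_process_group()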