
A FusionDefinition wrapper that takes/produces DTensors. #3703

Merged · 4 commits merged into main on Jan 22, 2025

Conversation

wujingyue (Collaborator) commented Jan 14, 2025:

This is a proof of concept for integrating nvFuser's model parallelism into the framework.
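
For orientation, a minimal usage sketch follows. It is hypothetical rather than code from this PR: FusionDefinitionWrapper is the wrapper quoted in the reviewer guide below, while the fusion body (define_add), the mesh setup, and the torch.distributed.tensor import paths are assumptions.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

from nvfuser import DataType, FusionDefinition


def define_add(fd: FusionDefinition) -> None:
    # A plain fusion definition with no multidevice_schedule; the wrapper derives
    # the schedule from the input DTensors' meshes and placements.
    t0 = fd.define_tensor(shape=[-1, -1], dtype=DataType.Float)
    t1 = fd.define_tensor(shape=[-1, -1], dtype=DataType.Float)
    fd.add_output(fd.ops.add(t0, t1))


# Assumed launch: one process per GPU, e.g. via torchrun.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

num_devices = dist.get_world_size()
mesh = init_device_mesh("cuda", (num_devices,))

x = distribute_tensor(torch.randn(num_devices, 4), mesh, [Shard(0)])
y = distribute_tensor(torch.randn(num_devices, 4), mesh, [Shard(0)])

op = FusionDefinitionWrapper(define_add)
(z,) = op([x, y])  # z is a DTensor whose placement mirrors nvFuser's output sharding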

wujingyue marked this pull request as a draft on January 14, 2025 06:18
wujingyue (Collaborator, Author) commented:

!test

wujingyue (Collaborator, Author) commented:

Cc @jjsjann123

wujingyue added commits that referenced this pull request on Jan 14, 2025
github-actions bot commented Jan 16, 2025:

PR Reviewer Guide 🔍

(Review updated until commit 40ce41b)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Potential Logic Change

The FusionDefinitionWrapper class introduces a new way of defining a fusion without multidevice_schedule: the schedule is derived from the input DTensors' meshes and placements instead. This change may have implications for the existing logic and should be reviewed thoroughly.

class FusionDefinitionWrapper:
    def __init__(self, define_fusion: Callable[[FusionDefinition], None]):
        """Wraps a function that defines a fusion without `multidevice_schedule`."""
        self._define_fusion = define_fusion

    def _create_fusion_definition(
        self, in_dtensors: Iterable[DTensor]
    ) -> FusionDefinition:
        define_fn = self._define_fusion

        class Model(FusionDefinition):
            def definition(self) -> None:
                define_fn(self)

            def _find_tensor_by_index(self, index: int) -> nvfuser.Tensor:
                for t in self.sched.tensors():
                    if t.index == index:
                        return t
                return None

            def multidevice_schedule(self) -> None:
                for in_tensor_index, in_dtensor in zip(self.inputs(), in_dtensors):
                    in_tensor = self._find_tensor_by_index(in_tensor_index)

                    # Set the device mesh.
                    assert (
                        in_dtensor.device_mesh.ndim == 1
                    ), "nvFuser's Python API only supports 1D meshes."
                    mesh = nvfuser.DeviceMesh(in_dtensor.device_mesh.mesh.tolist())

                    self.sched._set_device_mesh(in_tensor, mesh)

                    # Parallelize.
                    assert len(in_dtensor.placements) == 1, "Expect a 1D mesh"
                    placement: Placement = in_dtensor.placements[0]
                    if placement.is_shard():
                        dim = cast(Shard, placement).dim
                        self.sched.parallelize(
                            in_tensor, dim, nvfuser.ParallelType.mesh_x
                        )

        return Model()

    def __call__(self, in_dtensors: Iterable[DTensor]) -> list[DTensor]:
        fusion_def = self._create_fusion_definition(in_dtensors)

        in_tensors = [in_dtensor.to_local() for in_dtensor in in_dtensors]
        out_tensors = fusion_def.execute(in_tensors)

        for i, out_tensor in enumerate(out_tensors):
            if isinstance(out_tensor, nvfuser.DistributedTensor):
                mesh = dist.device_mesh.init_device_mesh(
                    "cuda", (out_tensor.mesh.size,)
                )
                placements: list[Placement] = []
                for parallel_type in [nvfuser.ParallelType.mesh_x]:
                    axis: int = out_tensor.axis_sharded_on(parallel_type)
                    placements.append(Replicate() if axis == -1 else Shard(axis))
                out_tensors[i] = DTensor.from_local(out_tensor.local, mesh, placements)
        return out_tensors

Function Signature Change

The FusionDefinitionWrapper class adds a new method, _create_fusion_definition, which changes how fusion definitions are created. This may affect existing function signatures and should be reviewed.

(The method is quoted in full within the FusionDefinitionWrapper class above.)

Potential Performance Impact

The FusionDefinitionWrapper class introduces a new way of parallelizing tensors, which may have performance implications. This change should be reviewed to ensure it does not introduce performance regressions.

(The parallelization loop is quoted in multidevice_schedule within the class above.)

wujingyue changed the title from "A custom op that wraps a FusionDefinition and takes/produces DTensors." to "A FusionDefinition wrapper that takes/produces DTensors." on Jan 16, 2025
wujingyue changed the base branch from main to wjy/dist on January 19, 2025 08:03
wujingyue marked this pull request as ready for review on January 19, 2025 17:40
wujingyue (Collaborator, Author) commented:

!test

wujingyue (Collaborator, Author) commented Jan 21, 2025:

cc @syed-ahmed FYI, I made nvFuser return output shardings as well as local tensors
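
To illustrate what this gives the caller, a small hedged sketch, reusing op, x, and y from the hypothetical example near the top of the thread: the wrapper's outputs come back as DTensors whose placements reflect the sharding nvFuser actually produced, so the framework does not have to guess.

(z,) = op([x, y])

# z carries both the local shard and the sharding metadata.
print(z.placements)        # e.g. (Shard(dim=0),) if the output stayed sharded
print(z.to_local().shape)  # shape of this rank's local shard
z_full = z.full_tensor()   # all-gather into an unsharded tensor when needed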

Base automatically changed from wjy/dist to main on January 22, 2025 04:47
wujingyue (Collaborator, Author) commented:

!test

1 similar comment
wujingyue (Collaborator, Author) commented:

!test

jjsjann123 (Collaborator) left a comment:

LGTM.

Review thread on tests/python/test_dtensor.py (outdated, resolved):
rank = dist.get_rank()
torch.cuda.set_device(rank)

in_tensor = torch.randn(num_devices, 4)
Collaborator:

QQ: does nvFuser still require the sharded dimension to have a size equal to the number of shards?

wujingyue (Collaborator, Author) commented Jan 22, 2025:

We are gradually adding DID loop split, so eventually this won't be a requirement. See test_communication.py and the several test_*_loop_splits in test_multidevice.py recently added by @Priya2698.
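
A small sketch of the constraint under discussion, reusing the assumed setup (num_devices, mesh, distribute_tensor) from the hypothetical example near the top of the thread; the relaxed case is labeled as what DID loop split is expected to allow, not something this PR enables.

# Today: the sharded extent equals the mesh size, so each rank holds exactly one slice.
ok = distribute_tensor(torch.randn(num_devices, 4), mesh, [Shard(0)])

# With DID loop split (work in progress), a sharded extent that is a multiple of the
# mesh size should also become schedulable, e.g.:
# later = distribute_tensor(torch.randn(num_devices * 3, 4), mesh, [Shard(0)])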

Review thread on the output-placement loop in FusionDefinitionWrapper.__call__:

placements: list[Placement] = []
for parallel_type in [nvfuser.ParallelType.mesh_x]:
    axis: int = out_tensor.axis_sharded_on(parallel_type)
    placements.append(Replicate() if axis == -1 else Shard(axis))

Collaborator:

I guess it doesn't matter here, since we only have a 1D device_mesh for now.

But the logic in this loop, for parallel_type in [nvfuser.ParallelType.mesh_x]:, doesn't feel right. I.e., I think we need to check the rank of out_tensor.mesh, which is only 1-D, so we can't really write something future-proof here.

^^^ I realize those aren't really constructive comments; I'm just trying to point out that we might need to put more thought into how we want to expose out_tensor.mesh. 😜

wujingyue (Collaborator, Author) commented:

Agreed. When nvFuser supports >1D meshes, we may want DeviceMesh to hold a torch.Tensor, similar to https://github.com/pytorch/pytorch/blob/3917053f63b75f14e3cb2f53805fad4ade5363df/torch/distributed/device_mesh.py#L420.
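
For reference, a short sketch of what this looks like on the PyTorch DTensor side today, assuming four ranks and the public torch.distributed.tensor APIs; this is illustration only, not nvFuser code.

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

# A 2x2 mesh over 4 ranks; DeviceMesh.mesh is a torch.Tensor holding the rank layout.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
print(mesh_2d.mesh)  # tensor of shape (2, 2) with the global ranks

# One placement per mesh dimension: shard rows over "dp", replicate over "tp".
w = distribute_tensor(torch.randn(8, 8), mesh_2d, [Shard(0), Replicate()])
print(w.placements)  # (Shard(dim=0), Replicate())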

rdspring1 (Collaborator) left a comment:

LGTM. Just some comments on FusionDefinitionWrapper.

Review thread on tests/python/test_dtensor.py (outdated, resolved).
wujingyue (Collaborator, Author) commented:

!test

wujingyue (Collaborator, Author) commented:

!test

wujingyue merged commit 92fdf42 into main on Jan 22, 2025 (24 of 25 checks passed)
wujingyue deleted the wjy/dtensor branch on January 22, 2025 22:16