Collect output shardings from DistributedTensor
wujingyue committed Jan 19, 2025
1 parent a57a73d commit ee9dbeb
Showing 2 changed files with 20 additions and 13 deletions.
3 changes: 3 additions & 0 deletions nvfuser/__init__.py
@@ -87,6 +87,9 @@ def mesh(self) -> DeviceMesh:
     def axis_sharded_on(self, parallel_type: ParallelType) -> int:
         return self._dtensor.axis_sharded_on(parallel_type)
 
+    def local(self) -> torch.Tensor:
+        return self._dtensor.local()
+
 
 class FusionDefinition(_C._FusionDefinition):
     def __init__(self, id=None, max_length=1024):
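A minimal usage sketch for the new accessor (illustrative only, not part of the diff; `out` and its origin are assumed): given an nvfuser.DistributedTensor, such as the ones model.execute(...) can return in the test change below, local() exposes the torch.Tensor shard owned by the calling rank, alongside the existing mesh and axis_sharded_on accessors.

# Sketch only -- `out` is assumed to be an nvfuser.DistributedTensor.
shard = out.local()                      # torch.Tensor: this rank's shard of the output
num_devices = out.mesh.size()            # size of the device mesh the output lives on
axis = out.axis_sharded_on(nvfuser.ParallelType.mesh_x)  # sharded tensor axis, or -1 if replicated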
30 changes: 17 additions & 13 deletions tests/python/test_dtensor.py
@@ -10,7 +10,7 @@
 from collections.abc import Iterable
 from nvfuser import DataType, FusionDefinition
 from torch.distributed.tensor import DTensor
-from torch.distributed.tensor.placement_types import Placement, Shard
+from torch.distributed.tensor.placement_types import Placement, Shard, Replicate
 from typing import Callable, cast
 
 
@@ -74,14 +74,19 @@ def multidevice_schedule(self):
         model = Model()
         out_tensors = model.execute(in_tensors)
 
-        out_dtensors = []
-        for out_tensor in out_tensors:
-            # FIXME: we should collect output meshes/placements from nvFuser.
-            out_dtensor = DTensor.from_local(
-                out_tensor, in_dtensors[0].device_mesh, in_dtensors[0].placements
-            )
-            out_dtensors.append(out_dtensor)
-        return out_dtensors
+        for i, out_tensor in enumerate(out_tensors):
+            if isinstance(out_tensor, nvfuser.DistributedTensor):
+                mesh = dist.device_mesh.init_device_mesh(
+                    "cuda", [out_tensor.mesh.size()]
+                )
+                placements: list[Placement] = []
+                for parallel_type in [nvfuser.ParallelType.mesh_x]:
+                    axis: int = out_tensor.axis_sharded_on(parallel_type)
+                    placements.append(Replicate() if axis == -1 else Shard(axis))
+                out_tensors[i] = DTensor.from_local(
+                    out_tensor.local(), mesh, placements
+                )
+        return out_tensors
 
 
 @pytest.mark.mpi
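For reference, the loop above can be read as a standalone conversion helper. This is an illustrative sketch distilled from the diff, not code in the repository: the name dtensor_from_nvfuser is made up, and it assumes the nvfuser.DistributedTensor accessors (mesh.size(), axis_sharded_on(), and the new local()) behave exactly as the test uses them.

import torch.distributed as dist
import nvfuser
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard


def dtensor_from_nvfuser(out_tensor: "nvfuser.DistributedTensor") -> DTensor:
    """Hypothetical helper mirroring the loop body in the test above."""
    # Rebuild a 1-D CUDA device mesh of the size nvFuser ran on.
    mesh = dist.device_mesh.init_device_mesh("cuda", [out_tensor.mesh.size()])

    # axis_sharded_on() reports which tensor axis is sharded on the given
    # parallel type; -1 means the output is replicated on that parallel type.
    placements: list[Placement] = []
    for parallel_type in [nvfuser.ParallelType.mesh_x]:
        axis = out_tensor.axis_sharded_on(parallel_type)
        placements.append(Replicate() if axis == -1 else Shard(axis))

    # local() (added in nvfuser/__init__.py above) is this rank's shard.
    return DTensor.from_local(out_tensor.local(), mesh, placements)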
@@ -104,8 +109,7 @@ def define_fusion(fd: FusionDefinition):
     mesh = dist.device_mesh.init_device_mesh("cuda", [num_devices])
     in_dtensor = dist.tensor.distribute_tensor(in_tensor, mesh, [Shard(0)])
 
-    out_dtensors = op([in_dtensor])
-
-    assert len(out_dtensors) == 1
-    out_dtensor = out_dtensors[0]
+    (out_dtensor,) = op([in_dtensor])
     torch.testing.assert_close(out_dtensor.to_local(), in_dtensor.to_local() + 1)
+    assert out_dtensor.device_mesh == in_dtensor.device_mesh
+    assert out_dtensor.placements == in_dtensor.placements