diff --git a/tests/python/test_dtensor.py b/tests/python/test_dtensor.py index ac4c64ff69d..dfc36170ecf 100644 --- a/tests/python/test_dtensor.py +++ b/tests/python/test_dtensor.py @@ -85,7 +85,7 @@ def multidevice_schedule(self): @pytest.mark.mpi -def test_execute_with_dtensors(setup_process_group): +def test_dtensor_plus_one(setup_process_group): def define_fusion(fd: FusionDefinition): inp = fd.define_tensor( (-1, -1), contiguity=(False, False), dtype=DataType.Float