diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc
index 1764971900c9e..49893bac9eb45 100644
--- a/paddle/fluid/pybind/auto_parallel_py.cc
+++ b/paddle/fluid/pybind/auto_parallel_py.cc
@@ -213,6 +213,10 @@ void BindAutoParallel(py::module *m) {
       *m, "PToRReshardFunction", ReshardFunction)
       .def(py::init<>());
 
+  py::class_<phi::distributed::PToRReshardFunctionCrossMesh>(
+      *m, "PToRReshardFunctionCrossMesh", ReshardFunction)
+      .def(py::init<>());
+
   py::class_<phi::distributed::SToSReshardFunction>(
       *m, "SToSReshardFunction", ReshardFunction)
       .def(py::init<>());
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc
index e0205b98c8c5c..34b05fdd70307 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc
@@ -19,6 +19,7 @@
 #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h"
 #include "paddle/phi/kernels/all_reduce_kernel.h"
 #include "paddle/phi/kernels/elementwise_divide_kernel.h"
 #include "paddle/phi/kernels/full_kernel.h"
@@ -91,5 +92,46 @@ void PToRReshardFunction::Eval(DeviceContext* dev_ctx,
   SetDistProps(out, in.dims(), out_dist_attr);
 }
 
+bool PToRReshardFunctionCrossMesh::IsSuitable(
+    const DistTensor& in, const TensorDistAttr& out_dist_attr) {
+  const auto& in_dist_attr = in.dist_attr();
+
+  RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_partial());
+  RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_replicated());
+
+  const auto& in_process_mesh = in_dist_attr.process_mesh();
+  const auto& out_process_mesh = out_dist_attr.process_mesh();
+
+  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1);
+  RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1);
+  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh != out_process_mesh);
+
+  return true;
+}
+
+void PToRReshardFunctionCrossMesh::Eval(phi::DeviceContext* dev_ctx,
+                                        const DistTensor& in,
+                                        const TensorDistAttr& out_dist_attr,
+                                        DistTensor* out) {
+  VLOG(3) << "Call PToRReshardFunctionCrossMesh Eval";
+  const auto& out_process_mesh = out_dist_attr.process_mesh();
+
+  DistTensor tmp_result;
+
+  SameStatusReshardFunction same_status_func;
+  TensorDistAttr tmp_dist_attr = in.dist_attr();
+  tmp_dist_attr.set_process_mesh(out_process_mesh);
+  same_status_func.Eval(dev_ctx, in, tmp_dist_attr, &tmp_result);
+
+  PToRReshardFunction p_to_r_func;
+  PADDLE_ENFORCE(
+      p_to_r_func.IsSuitable(tmp_result, out_dist_attr),
+      phi::errors::InvalidArgument(
+          "Invoke the p to r reshard function is not valid from %s to %s.",
+          tmp_result.dist_attr(),
+          out_dist_attr));
+  p_to_r_func.Eval(dev_ctx, tmp_result, out_dist_attr, out);
+}
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.h
index 746baacf25a51..8ff729348f153 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.h
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.h
@@ -35,5 +35,18 @@ class PToRReshardFunction final : public ReshardFunction {
   std::string Name() override { return "PToRReshard"; }
 };
 
+class PToRReshardFunctionCrossMesh final : public ReshardFunction {
+ public:
+  bool IsSuitable(const DistTensor& in,
+                  const TensorDistAttr& out_dist_attr) override;
+
+  void Eval(DeviceContext* dev_ctx,
+            const DistTensor& in,
+            const TensorDistAttr& out_dist_attr,
+            DistTensor* out) override;
+
+  std::string Name() override { return "PToRReshardCrossMesh"; }
+};
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function_registry.cc b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function_registry.cc
index c69a35f774429..5ea7507b947a8 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function_registry.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_function_registry.cc
@@ -58,6 +58,7 @@ REGISTER_RESHARD_FUNC(RToSReshardFunction);
 REGISTER_RESHARD_FUNC(RToSReshardFunctionCrossMesh);
 REGISTER_RESHARD_FUNC(RToPReshardFunction);
 REGISTER_RESHARD_FUNC(PToRReshardFunction);
+REGISTER_RESHARD_FUNC(PToRReshardFunctionCrossMesh);
 REGISTER_RESHARD_FUNC(PToSReshardFunction);
 REGISTER_RESHARD_FUNC(SToSReshardFunction);
 REGISTER_RESHARD_FUNC(SameStatusReshardFunction);
diff --git a/test/auto_parallel/reshard_p_to_r.py b/test/auto_parallel/reshard_p_to_r.py
index 9bfce03b83868..2aae0ac7233b0 100644
--- a/test/auto_parallel/reshard_p_to_r.py
+++ b/test/auto_parallel/reshard_p_to_r.py
@@ -21,7 +21,7 @@
 from paddle.framework import core
 
 
-class TestReshardSToR:
+class TestReshardPToR:
     def __init__(self):
         self._shape = eval(os.getenv("shape"))
         self._dtype = os.getenv("dtype")
@@ -47,4 +47,4 @@ def run_test_case(self):
 
 
 if __name__ == '__main__':
-    TestReshardSToR().run_test_case()
+    TestReshardPToR().run_test_case()
diff --git a/test/auto_parallel/reshard_p_to_r_cross_mesh.py b/test/auto_parallel/reshard_p_to_r_cross_mesh.py
new file mode 100644
index 0000000000000..0ded27e369d2e
--- /dev/null
+++ b/test/auto_parallel/reshard_p_to_r_cross_mesh.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+from paddle.base import core
+
+
+class TestReshardPToRCrossMesh:
+    def __init__(self):
+        self._shape = eval(os.getenv("shape"))
+        self._dtype = os.getenv("dtype")
+        self._seeds = eval(os.getenv("seeds"))
+        self._shard = eval(os.getenv("shard"))
+        self._backend = os.getenv("backend")
+        self._in_mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+        self._out_mesh = dist.ProcessMesh([1, 0], dim_names=["x"])
+
+    def run_test_case(self):
+        if self._backend == "cpu":
+            paddle.set_device("cpu")
+            place = paddle.CPUPlace()
+        elif self._backend == "gpu":
+            place = paddle.CUDAPlace(dist.get_rank())
+
+        dev_ctx = core.DeviceContext.create(place)
+        a = paddle.ones(self._shape)
+
+        input_tensor = dist.shard_tensor(
+            a, self._in_mesh, [dist.Partial(dist.ReduceType.kRedSum)]
+        )
+        out = dist.reshard(input_tensor, self._out_mesh, [dist.Replicate()])
+
+        assert np.equal(out.shape, input_tensor.shape).all()
+        np.testing.assert_equal(out._local_value().numpy(), a.numpy())
+
+
+if __name__ == '__main__':
+    TestReshardPToRCrossMesh().run_test_case()
diff --git a/test/auto_parallel/reshard_s_to_p.py b/test/auto_parallel/reshard_s_to_p.py
index f842b8ade0604..3614d4a46c72c 100644
--- a/test/auto_parallel/reshard_s_to_p.py
+++ b/test/auto_parallel/reshard_s_to_p.py
@@ -20,7 +20,7 @@
 import paddle.distributed as dist
 
 
-class TestReshardSToR:
+class TestReshardSToP:
     def __init__(self):
         self._shape = eval(os.getenv("shape"))
         self._dtype = os.getenv("dtype")
@@ -56,4 +56,4 @@ def run_test_case(self):
 
 
 if __name__ == '__main__':
-    TestReshardSToR().run_test_case()
+    TestReshardSToP().run_test_case()
diff --git a/test/auto_parallel/test_reshard_p_to_r.py b/test/auto_parallel/test_reshard_p_to_r.py
index d4e6ca39fc888..b76f1479397b3 100644
--- a/test/auto_parallel/test_reshard_p_to_r.py
+++ b/test/auto_parallel/test_reshard_p_to_r.py
@@ -17,19 +17,20 @@
 import collective.test_communication_api_base as test_base
 
 
-class TestReshardSToR(test_base.CommunicationTestDistBase):
+class TestReshardPToR(test_base.CommunicationTestDistBase):
     def setUp(self):
         super().setUp(num_of_devices=2, timeout=120)
         self._default_envs = {
             "shape": "(10, 20)",
             "dtype": "float32",
-            "seeds": str(self._seeds),
+            "seeds": "2023",
         }
         self._changeable_envs = {
+            "shard": ["0", "1"],
             "backend": ["cpu", "gpu"],
         }
 
-    def test_reshard_s_to_r(self):
+    def test_reshard_p_to_r(self):
         envs_list = test_base.gen_product_envs_list(
             self._default_envs, self._changeable_envs
         )
@@ -39,6 +40,17 @@ def test_reshard_s_to_r(self):
             user_defined_envs=envs,
         )
 
+    def test_reshard_p_to_r_cross_mesh(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            if envs["backend"] != "cpu":
+                self.run_test_case(
+                    "reshard_p_to_r_cross_mesh.py",
+                    user_defined_envs=envs,
+                )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/auto_parallel/test_reshard_p_to_s.py b/test/auto_parallel/test_reshard_p_to_s.py
index bd55e5355479b..7c65627a3a3cc 100644
--- a/test/auto_parallel/test_reshard_p_to_s.py
+++ b/test/auto_parallel/test_reshard_p_to_s.py
@@ -17,7 +17,7 @@
 import collective.test_communication_api_base as test_base
 
 
-class TestReshardSToR(test_base.CommunicationTestDistBase):
+class TestReshardPToS(test_base.CommunicationTestDistBase):
     def setUp(self):
         super().setUp(num_of_devices=2, timeout=120)
         self._default_envs = {
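Usage note (not part of the diff): a minimal sketch of the behavior this change enables, mirroring the added reshard_p_to_r_cross_mesh.py test. It assumes a two-rank run under a distributed launcher (e.g. paddle.distributed.launch) on a build that includes this PR.

    import paddle
    import paddle.distributed as dist

    # Two 1-D meshes over the same two ranks but with different orderings,
    # so the reshard has to cross meshes: a same-status reshard moves the
    # local partial values onto the destination mesh, then the ordinary
    # P->R all-reduce produces the replicated result.
    in_mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    out_mesh = dist.ProcessMesh([1, 0], dim_names=["x"])

    a = paddle.ones((10, 20))
    # Input is partial(sum) on in_mesh.
    x = dist.shard_tensor(a, in_mesh, [dist.Partial(dist.ReduceType.kRedSum)])
    # Output is replicated on out_mesh; this is the case the new
    # PToRReshardFunctionCrossMesh is registered to handle.
    out = dist.reshard(x, out_mesh, [dist.Replicate()])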