[Reshard] Support p to r on cross mesh (PaddlePaddle#59621)
* fix: typo

* fix: typo

* feat: reshard p2r
HermitSun authored and SigureMo committed Dec 5, 2023
1 parent 3353d28 commit 0496cd4
Showing 9 changed files with 134 additions and 8 deletions.
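At the Python level, the new capability is exercised through dist.reshard: a tensor that is partial (pending sum) on one 1-D process mesh can now be reshared to a replicated tensor on a different 1-D mesh. A minimal sketch of the user-facing call, mirroring the new reshard_p_to_r_cross_mesh.py test added below (it assumes a two-rank distributed launch):

import paddle
import paddle.distributed as dist

in_mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
out_mesh = dist.ProcessMesh([1, 0], dim_names=["x"])

a = paddle.ones((10, 20))
# Partial(kRedSum) marks each rank's copy as an unreduced partial sum.
x = dist.shard_tensor(a, in_mesh, [dist.Partial(dist.ReduceType.kRedSum)])
# Cross-mesh p-to-r: move the tensor to out_mesh, then reduce to a replicated tensor.
out = dist.reshard(x, out_mesh, [dist.Replicate()])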
4 changes: 4 additions & 0 deletions paddle/fluid/pybind/auto_parallel_py.cc
@@ -217,6 +217,10 @@ void BindAutoParallel(py::module *m) {
*m, "PToRReshardFunction", ReshardFunction)
.def(py::init<>());

py::class_<phi::distributed::PToRReshardFunctionCrossMesh>(
*m, "PToRReshardFunctionCrossMesh", ReshardFunction)
.def(py::init<>());

py::class_<phi::distributed::SToSReshardFunction>(
*m, "SToSReshardFunction", ReshardFunction)
.def(py::init<>());
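The new binding follows the pattern of its neighbors: PToRReshardFunctionCrossMesh is exposed to Python with only a default constructor, just like the existing PToRReshardFunction and SToSReshardFunction bindings around it.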
@@ -19,6 +19,7 @@
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h"
#include "paddle/phi/kernels/all_reduce_kernel.h"
#include "paddle/phi/kernels/elementwise_divide_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
@@ -91,5 +92,46 @@ void PToRReshardFunction::Eval(DeviceContext* dev_ctx,
SetDistProps(out, in.dims(), out_dist_attr);
}

bool PToRReshardFunctionCrossMesh::IsSuitable(
const DistTensor& in, const TensorDistAttr& out_dist_attr) {
const auto& in_dist_attr = in.dist_attr();

RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_partial());
RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_replicated());

const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& out_process_mesh = out_dist_attr.process_mesh();

RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1);
RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1);
RESHARD_SHORTCUT_IF_FALSE(in_process_mesh != out_process_mesh);

return true;
}

void PToRReshardFunctionCrossMesh::Eval(phi::DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) {
VLOG(3) << "Call PToRReshardFunctionCrossMesh Eval";
const auto& out_process_mesh = out_dist_attr.process_mesh();

DistTensor tmp_result;

SameStatusReshardFunction same_status_func;
TensorDistAttr tmp_dist_attr = in.dist_attr();
tmp_dist_attr.set_process_mesh(out_process_mesh);
same_status_func.Eval(dev_ctx, in, tmp_dist_attr, &tmp_result);

PToRReshardFunction p_to_r_func;
PADDLE_ENFORCE(
p_to_r_func.IsSuitable(tmp_result, out_dist_attr),
phi::errors::InvalidArgument(
"Invoke the p to r reshard function is not valid from %s to %s.",
tmp_result.dist_attr(),
out_dist_attr));
p_to_r_func.Eval(dev_ctx, tmp_result, out_dist_attr, out);
}

} // namespace distributed
} // namespace phi
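The cross-mesh variant only claims the job when the input is partial, the requested output is replicated, both process meshes are one-dimensional, and the two meshes actually differ; the equal-mesh case remains with the existing PToRReshardFunction. Its Eval composes two existing steps: a SameStatusReshardFunction first moves the still-partial tensor from the source mesh to the destination mesh (only the process_mesh in the dist_attr changes), and the ordinary single-mesh PToRReshardFunction then performs the reduction on the destination mesh. The PADDLE_ENFORCE check guards against the intermediate result not satisfying the p-to-r function's IsSuitable conditions.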
@@ -35,5 +35,18 @@ class PToRReshardFunction final : public ReshardFunction {
std::string Name() override { return "PToRReshard"; }
};

class PToRReshardFunctionCrossMesh final : public ReshardFunction {
public:
bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;

void Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) override;

std::string Name() override { return "PToRReshardCrossMesh"; }
};

} // namespace distributed
} // namespace phi
@@ -59,6 +59,7 @@ REGISTER_RESHARD_FUNC(RToSReshardFunctionCrossMesh);
REGISTER_RESHARD_FUNC(RToPReshardFunction);
REGISTER_RESHARD_FUNC(RToPReshardFunctionCrossMesh);
REGISTER_RESHARD_FUNC(PToRReshardFunction);
REGISTER_RESHARD_FUNC(PToRReshardFunctionCrossMesh);
REGISTER_RESHARD_FUNC(PToSReshardFunction);
REGISTER_RESHARD_FUNC(SToSReshardFunction);
REGISTER_RESHARD_FUNC(SameStatusReshardFunction);
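REGISTER_RESHARD_FUNC makes the new function visible to the mechanism that picks a reshard implementation by probing each registered function's IsSuitable; because the cross-mesh variant additionally requires the input and output meshes to differ, it does not overlap with the same-mesh PToRReshardFunction registered just above it.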
4 changes: 2 additions & 2 deletions test/auto_parallel/reshard_p_to_r.py
@@ -21,7 +21,7 @@
from paddle.framework import core


class TestReshardSToR:
class TestReshardPToR:
def __init__(self):
self._shape = eval(os.getenv("shape"))
self._dtype = os.getenv("dtype")
@@ -47,4 +47,4 @@ def run_test_case(self):


if __name__ == '__main__':
TestReshardSToR().run_test_case()
TestReshardPToR().run_test_case()
54 changes: 54 additions & 0 deletions test/auto_parallel/reshard_p_to_r_cross_mesh.py
@@ -0,0 +1,54 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
import paddle.distributed as dist
from paddle.base import core


class TestReshardPToRCrossMesh:
def __init__(self):
self._shape = eval(os.getenv("shape"))
self._dtype = os.getenv("dtype")
self._seeds = eval(os.getenv("seeds"))
self._shard = eval(os.getenv("shard"))
self._backend = os.getenv("backend")
self._in_mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
self._out_mesh = dist.ProcessMesh([1, 0], dim_names=["x"])

def run_test_case(self):
if self._backend == "cpu":
paddle.set_device("cpu")
place = paddle.CPUPlace()
elif self._backend == "gpu":
place = paddle.CUDAPlace(dist.get_rank())

dev_ctx = core.DeviceContext.create(place)
a = paddle.ones(self._shape)

input_tensor = dist.shard_tensor(
a, self._in_mesh, [dist.Partial(dist.ReduceType.kRedSum)]
)
out = dist.reshard(input_tensor, self._out_mesh, [dist.Replicate()])

assert np.equal(out.shape, input_tensor.shape).all()
np.testing.assert_equal(out._local_value().numpy(), a.numpy())


if __name__ == '__main__':
TestReshardPToRCrossMesh().run_test_case()
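Note that this script is driven by environment variables (shape, dtype, seeds, shard, backend) rather than run directly; the companion test_reshard_p_to_r.py below launches it on two devices with shape "(10, 20)", dtype "float32", and seeds "2023", and only exercises the cross-mesh case for the non-CPU backend.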
4 changes: 2 additions & 2 deletions test/auto_parallel/reshard_s_to_p.py
@@ -20,7 +20,7 @@
import paddle.distributed as dist


class TestReshardSToR:
class TestReshardSToP:
def __init__(self):
self._shape = eval(os.getenv("shape"))
self._dtype = os.getenv("dtype")
@@ -56,4 +56,4 @@ def run_test_case(self):


if __name__ == '__main__':
TestReshardSToR().run_test_case()
TestReshardSToP().run_test_case()
18 changes: 15 additions & 3 deletions test/auto_parallel/test_reshard_p_to_r.py
@@ -17,19 +17,20 @@
import collective.test_communication_api_base as test_base


class TestReshardSToR(test_base.CommunicationTestDistBase):
class TestReshardPToR(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=120)
self._default_envs = {
"shape": "(10, 20)",
"dtype": "float32",
"seeds": str(self._seeds),
"seeds": "2023",
}
self._changeable_envs = {
"shard": ["0", "1"],
"backend": ["cpu", "gpu"],
}

def test_reshard_s_to_r(self):
def test_reshard_p_to_r(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
@@ -39,6 +40,17 @@ def test_reshard_s_to_r(self):
user_defined_envs=envs,
)

def test_reshard_p_to_r_cross_mesh(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
if envs["backend"] != "cpu":
self.run_test_case(
"reshard_p_to_r_cross_mesh.py",
user_defined_envs=envs,
)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion test/auto_parallel/test_reshard_p_to_s.py
@@ -17,7 +17,7 @@
import collective.test_communication_api_base as test_base


class TestReshardSToR(test_base.CommunicationTestDistBase):
class TestReshardPToS(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=120)
self._default_envs = {