
[AutoParallel] Add paddle.distributed.shard_layer api #57604

Merged: 29 commits, Sep 26, 2023

Commits (29)
bfca775
def dtensor_from_fn first edition
yangxiaoyu14 Aug 23, 2023
d20a032
dtensor_from_fn first edition
yangxiaoyu14 Aug 23, 2023
9070579
Merge branch 'develop' of https://github.com/yangxiaoyu14/Paddle into…
yangxiaoyu14 Aug 23, 2023
9f0fbe8
shard_layer api and utest(temporarily unavailable)
yangxiaoyu14 Aug 31, 2023
45f61c9
merge conflict
yangxiaoyu14 Aug 31, 2023
06424d1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yangxiaoyu14 Sep 1, 2023
c9cafaa
shard_layer API and unit test preliminary complete
yangxiaoyu14 Sep 1, 2023
9253ab2
complete the sample code modification according to ZhongKai's suggestion
yangxiaoyu14 Sep 1, 2023
42f0c0e
modify according to the review
yangxiaoyu14 Sep 4, 2023
0e43b0e
modify according to LiangGe's review
yangxiaoyu14 Sep 6, 2023
e480ea8
Not approved yet, temporarily stored
yangxiaoyu14 Sep 11, 2023
a8f69fb
Not approved yet, temporarily store
yangxiaoyu14 Sep 11, 2023
0987e23
waiting for tensor to param
yangxiaoyu14 Sep 13, 2023
3a631a6
20230913
yangxiaoyu14 Sep 13, 2023
0a6d33f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yangxiaoyu14 Sep 14, 2023
4755eb1
Complete the modifications according to Weihang's review
yangxiaoyu14 Sep 14, 2023
0483485
resolve conflict with develop
chenwhql Sep 18, 2023
5f63cec
polish shard_layer api impl and doc
chenwhql Sep 19, 2023
9af38fa
add shard layer test
chenwhql Sep 19, 2023
0bd32b4
resolve conflict with develop
chenwhql Sep 19, 2023
9d42bdd
rewrite unittest
chenwhql Sep 21, 2023
33466ae
revert needless change
chenwhql Sep 21, 2023
36b09e0
polish doc
chenwhql Sep 21, 2023
73958c0
add unittest for coverage
chenwhql Sep 22, 2023
22ac971
add static branch and test
chenwhql Sep 22, 2023
4b871b7
polish en doc
chenwhql Sep 22, 2023
bcf194b
polish test details
chenwhql Sep 22, 2023
c911620
verify doc test demo
chenwhql Sep 22, 2023
00dd04e
Update python/paddle/distributed/auto_parallel/api.py
chenwhql Sep 25, 2023
2 changes: 2 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -68,6 +68,7 @@
from .auto_parallel.api import shard_tensor # noqa: F401
from .auto_parallel.api import dtensor_from_fn # noqa: F401
from .auto_parallel.api import reshard # noqa: F401
from .auto_parallel.api import shard_layer # noqa: F401

from .fleet import BoxPSDataset # noqa: F401

@@ -130,4 +131,5 @@
"shard_tensor",
"dtensor_from_fn",
"reshard",
"shard_layer",
]
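
With this change, shard_layer is exposed at the paddle.distributed top level alongside shard_tensor, dtensor_from_fn, and reshard. A minimal import check, assuming a Paddle build that includes this PR:

import paddle.distributed as dist

# The new API sits next to the existing DistTensor construction APIs.
assert callable(dist.shard_layer)
assert callable(dist.shard_tensor)
assert callable(dist.reshard)
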
164 changes: 164 additions & 0 deletions python/paddle/distributed/auto_parallel/api.py
@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable

import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.base.framework import EagerParamBase
from paddle.distributed.auto_parallel.interface import (
shard_tensor as shard_tensor_static,
@@ -24,6 +27,8 @@
# Some APIs have the same name with the previous APIs implementation, which are
# a temporary state, and the APIs here will eventually be used.

# Part1: Shard attributes related APIs


class DistAttr(core.TensorDistAttr):
"""
@@ -83,6 +88,9 @@ def sharding_specs(self):
return self._sharding_specs


# Part2: DistTensor construction related APIs


def shard_tensor(
data, dtype=None, place=None, stop_gradient=True, dist_attr=None
):
@@ -180,6 +188,9 @@ def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
return shard_tensor(tensor, dist_attr=dist_attr)


# Part3: Data conversion related APIs


def reshard(dist_tensor, dist_attr):
"""
Reshard a distributed ``paddle.Tensor`` with given distributed attributes.
@@ -223,3 +234,156 @@ def reshard(dist_tensor, dist_attr):
raise RuntimeError(
"paddle.dist.reshard only support dynamic graph now. It will be supported for static graph later."
)


def shard_layer(
layer: nn.Layer,
process_mesh: dist.ProcessMesh,
Contributor:

Does shard_layer need to support the cross-mesh case? If not, should replicate_layer_params_and_buffers check whether the Tensor's ProcessMesh is valid?

Contributor Author:

In principle we should not restrict whether the mesh is crossed, but when a user calls shard_layer across meshes, reshard no longer needs to be called; the expected user behavior here still needs discussion.

That said, there is no need to check the specific mesh state here: shard_layer is a wrapper around shard_tensor, so any check belongs in shard_tensor, which should raise an error when the mesh is invalid.

Contributor:

Got it.

shard_fn: Callable = None,
input_fn: Callable = None,
output_fn: Callable = None,
) -> nn.Layer:
"""
Converts all of the layer's parameters to DistTensor parameters according to
the specified `shard_fn`. It can also control the conversion of the layer's
input or output by specifying `input_fn` and `output_fn` (i.e. convert the
input to `paddle.Tensor` with DistTensor, and convert the output back to
`paddle.Tensor` with DenseTensor).

The `shard_fn` should have the following signature:

def shard_fn(layer_name, layer, process_mesh) -> None

The `input_fn` should have the following signature:

def input_fn(inputs, process_mesh) -> list(paddle.Tensor)

In general, the return value of `input_fn` is a `paddle.Tensor` with DistTensor.

The `output_fn` should have the following signature:

def output_fn(outputs, process_mesh) -> list(paddle.Tensor)

In general, the return value of `output_fn` is a `paddle.Tensor` with DenseTensor.

Args:
layer (paddle.nn.Layer): The Layer object to be sharded.
process_mesh (paddle.distributed.ProcessMesh): The `ProcessMesh` information
on which to place the input `layer`.
shard_fn (Callable): The function to shard layer parameters across
the `process_mesh`. If not specified, by default we replicate
all parameters of the layer across the `process_mesh`.
input_fn (Callable): Specify how the input of the layer is sharded.
The `input_fn` will be registered for the Layer as a `forward pre-hook`.
By default we do not shard the input.
output_fn (Callable): Specify how the output of the layer is sharded or
converted back to `paddle.Tensor` with DenseTensor.
The `output_fn` will be registered for the Layer as a `forward post-hook`.
By default we do not shard or convert the output.
Returns:
Layer: A layer that contains parameters/buffers
that are all `paddle.Tensor` with DistTensor

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.distributed as dist

>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

>>> class MLP(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
... self.fc1 = paddle.nn.Linear(8, 8)
... self.fc2 = paddle.nn.Linear(8, 8)
...
... def forward(self, input):
... return self.fc2(self.fc1(input))

>>> def shard_fn(layer_name, layer, process_mesh):
... dist_attr = dist.DistAttr(mesh=process_mesh, sharding_specs=['x', None])
... if layer_name == 'fc1':
... layer.weight = dist.shard_tensor(layer.weight, dist_attr=dist_attr)

>>> layer = MLP()
>>> layer = dist.shard_layer(layer, mesh, shard_fn)
>>> print(layer)

>>> # This case needs to be executed in a multi-card environment
>>> # export CUDA_VISIBLE_DEVICES=0,1
>>> # python -m paddle.distributed.launch {test_case}.py
"""
# Ensure that process_mesh is not an empty object
if process_mesh is None:
raise ValueError("The argument `process_mesh` cannot be empty.")

# Check the legality of process_mesh
if not isinstance(process_mesh, dist.ProcessMesh):
raise ValueError(
"The argument `process_mesh` is not `dist.ProcessMesh` type."
)

def replicate_layer_params_and_buffers(
layer: nn.Layer, mesh: dist.ProcessMesh
) -> None:
for key, param in layer._parameters.items():
if param is not None and not param.is_dist():
Contributor:

If a parameter is not a dist tensor, it is converted to replicated here. If we later add a convert-to-replicated step in the execution mechanism, will this still be needed?

Contributor Author:

If the downstream APIs all convert to replicated in place, this could be removed with the same effect.

The two do not conflict, though: converting up front is logically cleaner, since an in-place conversion still implicitly modifies the user's input.

replicated_dist_attr = dist.DistAttr(
mesh=mesh,
sharding_specs=[None for _ in range(len(param.shape))],
)
layer.add_parameter(
key,
shard_tensor(param, dist_attr=replicated_dist_attr),
)
else:
# do nothing, the dist parameters have already been sharded by shard_fn
pass
for key, buffer in layer._buffers.items():
if buffer is not None and not buffer.is_dist():
replicated_dist_attr = dist.DistAttr(
mesh=mesh,
sharding_specs=[None for _ in range(len(buffer.shape))],
)
layer.register_buffer(
key,
shard_tensor(buffer, dist_attr=replicated_dist_attr),
)
else:
# do nothing, the dist buffers have already been sharded by shard_fn
pass

if paddle.in_dynamic_mode():
if shard_fn is None:
# if shard_fn is not specified, by default replicate
# all layer's parameters and buffers
for name, sublayers in layer.named_sublayers(include_self=True):
replicate_layer_params_and_buffers(sublayers, process_mesh)
else:
# apply shard_fn to sublayers, including self
for name, sublayers in layer.named_sublayers(include_self=True):
shard_fn(name, sublayers, process_mesh)
# shard_fn may not deal with all parameters and buffers;
# the parameters and buffers that are not sharded by shard_fn
# still need to be sharded as replicated
replicate_layer_params_and_buffers(sublayers, process_mesh)

# register input_fn as layer's forward pre hook
if input_fn is not None:
layer.register_forward_pre_hook(
lambda _, inputs: input_fn(inputs, process_mesh)
)
# register output_fn as layer's forward post hook
if output_fn is not None:
layer.register_forward_post_hook(
lambda _, inputs, outputs: output_fn(outputs, process_mesh)
)

return layer
else:
# TODO(chenweihang): Support static mode branch later.
raise NotImplementedError(
"`paddle.distributed.shard_layer` only supports dynamic graph mode "
"now. It will be supported for static graph mode later."
)
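
For reference, the sketch below strings the new API together the way the unit test does: no shard_fn, so all parameters fall back to replicated placement, while input_fn is registered as a forward pre-hook and output_fn as a forward post-hook. This is a minimal sketch rather than part of the PR; it assumes the dist.DistAttr / dist.shard_tensor APIs exactly as added here, and since every sharding spec is None (replicated) it can also run on a single card.

import paddle
import paddle.distributed as dist
from paddle import nn

# Minimal sketch (not part of the PR): default replication plus input/output hooks.
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])


class MLP(nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(self.fc1(x))


def input_fn(inputs, process_mesh):
    # Registered as a forward pre-hook: convert the dense input into a
    # replicated DistTensor on the mesh.
    dist_attr = dist.DistAttr(mesh=process_mesh, sharding_specs=[None, None])
    return dist.shard_tensor(inputs[0], dist_attr=dist_attr)


def output_fn(outputs, process_mesh):
    # Registered as a forward post-hook: bring the DistTensor output back
    # to a plain dense paddle.Tensor.
    return paddle.to_tensor(outputs.numpy())


# No shard_fn: all parameters and buffers are replicated across the mesh.
layer = dist.shard_layer(MLP(), mesh, input_fn=input_fn, output_fn=output_fn)
out = layer(paddle.randn([4, 8]))
print(out.is_dense())  # True: output_fn converted the result back to dense
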
1 change: 1 addition & 0 deletions test/auto_parallel/CMakeLists.txt
@@ -194,6 +194,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_dist_tensor MODULES test_dist_tensor)
py_test_modules(test_api_dist_branch MODULES test_api_dist_branch)
py_test_modules(test_shard_tensor_api MODULES test_shard_tensor_api)
py_test_modules(test_shard_layer_api MODULES test_shard_layer_api)
py_test_modules(test_cost_interface MODULES test_cost_interface)
# End of unittests WITH single card WITHOUT timeout

174 changes: 174 additions & 0 deletions test/auto_parallel/test_shard_layer_api.py
@@ -0,0 +1,174 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
import paddle.distributed as dist
from paddle import nn


# TODO(chenweihang): test for paddle nn Layer API
class DemoLayer(nn.Layer):
def __init__(self, num_features):
super().__init__()
self.w0 = self.create_parameter(shape=[num_features, num_features])
self.w1 = self.create_parameter(shape=[num_features, num_features])

def forward(self, x):
y = paddle.matmul(x, self.w0)
z = paddle.matmul(y, self.w1)
return z


class MyLayer(nn.Layer):
def __init__(self, num_features, num_layers):
super().__init__()
self.seq = nn.Sequential(
*[DemoLayer(num_features) for _ in range(num_layers)]
)

def forward(self, x):
return self.seq(x)


class TestShardLayer(unittest.TestCase):
def setUp(self):
self.mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
Contributor:

This unit test is not started with launch; can it still use two cards?

Contributor Author:

The mesh is not actually used for sharding here: the specs below are all None, so at runtime every tensor is replicated and the sharding step is skipped. That is why it also runs on a single card; it mainly exercises the overall flow.

self.num_features = 10
self.num_layers = 10

def test_shard_layer_base(self):
layer = MyLayer(self.num_features, self.num_layers)

def shard_fn(layer_name, layer, process_mesh):
if isinstance(layer, nn.Linear):
for name, param in layer.named_parameters():
if 'weight' in name:
dist_param = dist.shard_tensor(
param,
dist_attr=dist.DistAttr(
mesh=process_mesh, sharding_specs=[None, None]
),
)
else:
dist_param = dist.shard_tensor(
param,
dist_attr=dist.DistAttr(
mesh=process_mesh, sharding_specs=[None]
),
)
layer.add_parameter(name, dist_param)

# test shard parameters
sharded_params_layer = dist.shard_layer(layer, self.mesh, shard_fn)

for param in sharded_params_layer.parameters():
self.assertTrue(param.is_dist())
for x in param.dist_attr.dims_mapping:
self.assertEqual(x, -1)

# test shard buffers
test_buffer = paddle.randn([10])
layer.register_buffer("test_buffer", test_buffer, persistable=True)
sharded_buffers_layer = dist.shard_layer(layer, self.mesh, shard_fn)
self.assertTrue(sharded_buffers_layer.test_buffer.is_dist())
self.assertEqual(
sharded_buffers_layer.test_buffer.dist_attr.dims_mapping, [-1]
)

def test_shard_layer_input_fn_and_output_fn(self):
layer = MyLayer(self.num_features, self.num_layers)

def input_fn(inputs, process_mesh):
return dist.shard_tensor(
inputs[0], dist_attr=dist.DistAttr(process_mesh, [None, None])
)

def output_fn(outputs, process_mesh):
assert outputs.is_dist()
# TODO(chenweihang): replace by dist.unshard_dtensor later
return paddle.to_tensor(outputs.numpy())

# test shard parameters
replicate_params_layer = dist.shard_layer(
layer, self.mesh, input_fn=input_fn, output_fn=output_fn
)

x = paddle.randn([5, self.num_features])
dense_out = replicate_params_layer(x)
self.assertTrue(dense_out.is_dense())

for param in replicate_params_layer.parameters():
self.assertTrue(param.is_dist())
for x in param.dist_attr.dims_mapping:
self.assertEqual(x, -1)

# test shard buffers
test_buffer = paddle.randn([10])
layer.register_buffer("test_buffer", test_buffer, persistable=True)
sharded_buffers_layer = dist.shard_layer(
layer, self.mesh, input_fn=input_fn, output_fn=output_fn
)
self.assertTrue(sharded_buffers_layer.test_buffer.is_dist())
self.assertEqual(
sharded_buffers_layer.test_buffer.dist_attr.dims_mapping, [-1]
)

def test_process_mesh_argument_error(self):
layer = MyLayer(self.num_features, self.num_layers)

exception = None
try:
dist.shard_layer(layer, None)
except ValueError as ex:
self.assertIn(
"The argument `process_mesh` cannot be empty",
str(ex),
)
exception = ex
self.assertIsNotNone(exception)

exception = None
try:
dist_attr = dist.DistAttr(
mesh=self.mesh, sharding_specs=[None, None]
)
dist.shard_layer(layer, dist_attr)
except ValueError as ex:
self.assertIn(
"The argument `process_mesh` is not `dist.ProcessMesh` type",
str(ex),
)
exception = ex
self.assertIsNotNone(exception)

def test_shard_layer_static_mode(self):
paddle.enable_static()
layer = MyLayer(self.num_features, self.num_layers)

exception = None
try:
dist.shard_layer(layer, self.mesh)
except NotImplementedError as ex:
self.assertIn(
"`paddle.distributed.shard_layer` only supports dynamic graph mode now",
str(ex),
)
exception = ex
self.assertIsNotNone(exception)


if __name__ == '__main__':
unittest.main()
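
For contrast with the replicated specs used throughout the test above, here is a sketch of a shard_fn that actually shards a parameter along the mesh dimension, following the shard_layer docstring example; it assumes two visible devices and a launch via python -m paddle.distributed.launch.

import paddle
import paddle.distributed as dist
from paddle import nn

# Sketch following the shard_layer docstring example: fc1's weight is sharded
# along mesh dim "x"; anything shard_fn does not touch is replicated by
# shard_layer's fallback pass. Assumes two devices under distributed launch.
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])


class MLP(nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(self.fc1(x))


def shard_fn(layer_name, layer, process_mesh):
    if layer_name == 'fc1':
        dist_attr = dist.DistAttr(mesh=process_mesh, sharding_specs=['x', None])
        layer.weight = dist.shard_tensor(layer.weight, dist_attr=dist_attr)


layer = dist.shard_layer(MLP(), mesh, shard_fn)
for name, param in layer.named_parameters():
    # fc1.weight should report a sharded dims_mapping; the rest are replicated.
    print(name, param.is_dist(), param.dist_attr.dims_mapping)
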