[AutoParallel] Add paddle.distributed.shard_layer api #57604
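For orientation, here is a minimal sketch of the new API's default behavior, pieced together from the docstring and the unit tests in the diff below (the layer and shapes are illustrative): calling shard_layer without a shard_fn replicates every parameter and buffer of the layer across the given mesh.

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
layer = paddle.nn.Linear(8, 8)

# With shard_fn omitted, shard_layer converts every parameter and buffer
# of the layer to a DistTensor replicated across the mesh.
layer = dist.shard_layer(layer, mesh)

for name, param in layer.named_parameters():
    # replicated parameters report dims_mapping entries of -1
    print(name, param.is_dist(), param.dist_attr.dims_mapping)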
@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable

import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.base.framework import EagerParamBase
from paddle.distributed.auto_parallel.interface import (
    shard_tensor as shard_tensor_static,

@@ -24,6 +27,8 @@
# Some APIs have the same name with the previous APIs implementation, which are
# a temporary state, and the APIs here will eventually be used.

# Part1: Shard attributes related APIs


class DistAttr(core.TensorDistAttr):
    """

@@ -83,6 +88,9 @@ def sharding_specs(self):
        return self._sharding_specs


# Part2: DistTensor construction related APIs


def shard_tensor(
    data, dtype=None, place=None, stop_gradient=True, dist_attr=None
):

@@ -180,6 +188,9 @@ def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
    return shard_tensor(tensor, dist_attr=dist_attr)


# Part3: Data conversion related APIs


def reshard(dist_tensor, dist_attr):
    """
    Reshard a distributed ``paddle.Tensor`` with given distributed attributes.

@@ -223,3 +234,156 @@ def reshard(dist_tensor, dist_attr):
        raise RuntimeError(
            "paddle.dist.reshard only support dynamic graph now. It will be supported for static graph later."
        )


def shard_layer(
    layer: nn.Layer,
    process_mesh: dist.ProcessMesh,
    shard_fn: Callable = None,
    input_fn: Callable = None,
    output_fn: Callable = None,
) -> nn.Layer:
    """
    Converts all layer's parameters to DistTensor parameters according to
    the `shard_fn` specified. It could also control the conversion of input
    or output of the layer by specifying the `input_fn` and `output_fn`.
    (i.e. convert the input to `paddle.Tensor` with DistTensor, convert output
    back to `paddle.Tensor` with DenseTensor.)

    The `shard_fn` should have the following signature:

        def shard_fn(layer_name, layer, process_mesh) -> None

    The `input_fn` should have the following signature:

        def input_fn(inputs, process_mesh) -> list(paddle.Tensor)

    In general, the type of `input_fn` return value is paddle.Tensor with DistTensor.

    The `output_fn` should have the following signature:

        def output_fn(outputs, process_mesh) -> list(paddle.Tensor)

    In general, the type of `output_fn` return value is paddle.Tensor with DenseTensor.

    Args:
        layer (paddle.nn.Layer): The Layer object to be sharded.
        process_mesh (paddle.distributed.ProcessMesh): The `ProcessMesh` information
            on which to place the input `layer`.
        shard_fn (Callable): The function to shard layer parameters across
            the `process_mesh`. If not specified, by default we replicate
            all parameters of the layer across the `process_mesh`.
        input_fn (Callable): Specify how the input of the layer is sharded.
            The `input_fn` will be registered for the Layer as a `forward pre-hook`.
            By default we do not shard the input.
        output_fn (Callable): Specify how the output of the layer is sharded or
            converted back to `paddle.Tensor` with DenseTensor.
            The `output_fn` will be registered for the Layer as a `forward post-hook`.
            By default we do not shard or convert the output.
    Returns:
        Layer: A layer that contains parameters/buffers
            that are all `paddle.Tensor` with DistTensor

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import paddle.distributed as dist

            >>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

            >>> class MLP(paddle.nn.Layer):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.fc1 = paddle.nn.Linear(8, 8)
            ...         self.fc2 = paddle.nn.Linear(8, 8)
            ...
            ...     def forward(self, input):
            ...         return self.fc2(self.fc1(input))

            >>> def shard_fn(layer_name, layer, process_mesh):
            ...     dist_attr = dist.DistAttr(mesh=process_mesh, sharding_specs=['x', None])
            ...     if layer_name == 'fc1':
            ...         layer.weight = dist.shard_tensor(layer.weight, dist_attr=dist_attr)

            >>> layer = MLP()
            >>> layer = dist.shard_layer(layer, mesh, shard_fn)
            >>> print(layer)

            >>> # This case needs to be executed in a multi-card environment
            >>> # export CUDA_VISIBLE_DEVICES=0,1
            >>> # python -m paddle.distributed.launch {test_case}.py
    """
    # Ensure that process_mesh is not an empty object
    if process_mesh is None:
        raise ValueError("The argument `process_mesh` cannot be empty.")

    # Check the legality of process_mesh
    if not isinstance(process_mesh, dist.ProcessMesh):
        raise ValueError(
            "The argument `process_mesh` is not `dist.ProcessMesh` type."
        )

    def replicate_layer_params_and_buffers(
        layer: nn.Layer, mesh: dist.ProcessMesh
    ) -> None:
        for key, param in layer._parameters.items():
            if param is not None and not param.is_dist():

Review comment: If a parameter is not a dist tensor, it is converted to replicated here. If we later add a convert-to-replicated step to the execution mechanism, will this still be needed?

Reply: If the downstream APIs all convert to replicated in place, this could be removed and the effect would be the same. The two do not conflict, though; converting up front is logically a bit cleaner, since an in-place conversion implicitly modifies the user's input.

                replicated_dist_attr = dist.DistAttr(
                    mesh=mesh,
                    sharding_specs=[None for _ in range(len(param.shape))],
                )
                layer.add_parameter(
                    key,
                    shard_tensor(param, dist_attr=replicated_dist_attr),
                )
            else:
                # do nothing, the dist parameters have already been sharded by shard_fn
                pass
        for key, buffer in layer._buffers.items():
            if buffer is not None and not buffer.is_dist():
                replicated_dist_attr = dist.DistAttr(
                    mesh=mesh,
                    sharding_specs=[None for _ in range(len(buffer.shape))],
                )
                layer.register_buffer(
                    key,
                    shard_tensor(buffer, dist_attr=replicated_dist_attr),
                )
            else:
                # do nothing, the dist buffers have already been sharded by shard_fn
                pass

    if paddle.in_dynamic_mode():
        if shard_fn is None:
            # if shard_fn not specified, by default replicate
            # all layer's parameters and buffers
            for name, sublayers in layer.named_sublayers(include_self=True):
                replicate_layer_params_and_buffers(sublayers, process_mesh)
        else:
            # apply shard_fn to sublayers, including self
            for name, sublayers in layer.named_sublayers(include_self=True):
                shard_fn(name, sublayers, process_mesh)
                # shard_fn may not deal with all parameters and buffers,
                # the parameters and buffers that are not sharded by shard_fn
                # still need to be converted to replicated
                replicate_layer_params_and_buffers(sublayers, process_mesh)

        # register input_fn as layer's forward pre hook
        if input_fn is not None:
            layer.register_forward_pre_hook(
                lambda _, inputs: input_fn(inputs, process_mesh)
            )
        # register output_fn as layer's forward post hook
        if output_fn is not None:
            layer.register_forward_post_hook(
                lambda _, inputs, outputs: output_fn(outputs, process_mesh)
            )

        return layer
    else:
        # TODO(chenweihang): Support static mode branch later.
        raise NotImplementedError(
            "`paddle.distributed.shard_layer` only supports dynamic graph mode "
            "now. It will be supported for static graph mode later."
        )
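The docstring example above only exercises shard_fn. Since input_fn and output_fn are registered as forward pre-/post-hooks, a minimal sketch of that path, assuming the same two-rank mesh, might look as follows; the test file later in this PR uses essentially the same pattern, with the .numpy() round trip standing in for the dist.unshard_dtensor mentioned in its TODO.

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
layer = paddle.nn.Linear(8, 8)

def input_fn(inputs, process_mesh):
    # forward pre-hook: convert the dense input to a replicated DistTensor
    return dist.shard_tensor(
        inputs[0],
        dist_attr=dist.DistAttr(mesh=process_mesh, sharding_specs=[None, None]),
    )

def output_fn(outputs, process_mesh):
    # forward post-hook: bring the output back to a dense paddle.Tensor
    return paddle.to_tensor(outputs.numpy())

layer = dist.shard_layer(layer, mesh, input_fn=input_fn, output_fn=output_fn)

out = layer(paddle.randn([4, 8]))
print(out.is_dense())  # the post-hook returned a dense tensor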
@@ -0,0 +1,174 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
import paddle.distributed as dist
from paddle import nn


# TODO(chenweihang): test for paddle nn Layer API
class DemoLayer(nn.Layer):
    def __init__(self, num_features):
        super().__init__()
        self.w0 = self.create_parameter(shape=[num_features, num_features])
        self.w1 = self.create_parameter(shape=[num_features, num_features])

    def forward(self, x):
        y = paddle.matmul(x, self.w0)
        z = paddle.matmul(y, self.w1)
        return z


class MyLayer(nn.Layer):
    def __init__(self, num_features, num_layers):
        super().__init__()
        self.seq = nn.Sequential(
            *[DemoLayer(num_features) for _ in range(num_layers)]
        )

    def forward(self, x):
        return self.seq(x)


class TestShardLayer(unittest.TestCase):
    def setUp(self):
        self.mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

Review comment: This unit test is not started with launch; can it still use two cards?

Reply: The mesh is not actually used to split anything here; the specs below are all None, so at runtime the tensors are replicated and the splitting step is skipped. It can therefore also run on a single card; the test mainly exercises the overall flow.

        self.num_features = 10
        self.num_layers = 10

    def test_shard_layer_base(self):
        layer = MyLayer(self.num_features, self.num_layers)

        def shard_fn(layer_name, layer, process_mesh):
            if isinstance(layer, nn.Linear):
                for name, param in layer.named_parameters():
                    if 'weight' in name:
                        dist_param = dist.shard_tensor(
                            param,
                            dist_attr=dist.DistAttr(
                                mesh=process_mesh, sharding_specs=[None, None]
                            ),
                        )
                    else:
                        dist_param = dist.shard_tensor(
                            param,
                            dist_attr=dist.DistAttr(
                                mesh=process_mesh, sharding_specs=[None]
                            ),
                        )
                    layer.add_parameter(name, dist_param)

        # test shard parameters
        sharded_params_layer = dist.shard_layer(layer, self.mesh, shard_fn)

        for param in sharded_params_layer.parameters():
            self.assertTrue(param.is_dist())
            for x in param.dist_attr.dims_mapping:
                self.assertEqual(x, -1)

        # test shard buffers
        test_buffer = paddle.randn([10])
        layer.register_buffer("test_buffer", test_buffer, persistable=True)
        sharded_buffers_layer = dist.shard_layer(layer, self.mesh, shard_fn)
        self.assertTrue(sharded_buffers_layer.test_buffer.is_dist())
        self.assertEqual(
            sharded_buffers_layer.test_buffer.dist_attr.dims_mapping, [-1]
        )

    def test_shard_layer_input_fn_and_output_fn(self):
        layer = MyLayer(self.num_features, self.num_layers)

        def input_fn(inputs, process_mesh):
            return dist.shard_tensor(
                inputs[0], dist_attr=dist.DistAttr(process_mesh, [None, None])
            )

        def output_fn(outputs, process_mesh):
            assert outputs.is_dist()
            # TODO(chenweihang): replace by dist.unshard_dtensor later
            return paddle.to_tensor(outputs.numpy())

        # test shard parameters
        replicate_params_layer = dist.shard_layer(
            layer, self.mesh, input_fn=input_fn, output_fn=output_fn
        )

        x = paddle.randn([5, self.num_features])
        dense_out = replicate_params_layer(x)
        self.assertTrue(dense_out.is_dense())

        for param in replicate_params_layer.parameters():
            self.assertTrue(param.is_dist())
            for x in param.dist_attr.dims_mapping:
                self.assertEqual(x, -1)

        # test shard buffers
        test_buffer = paddle.randn([10])
        layer.register_buffer("test_buffer", test_buffer, persistable=True)
        sharded_buffers_layer = dist.shard_layer(
            layer, self.mesh, input_fn=input_fn, output_fn=output_fn
        )
        self.assertTrue(sharded_buffers_layer.test_buffer.is_dist())
        self.assertEqual(
            sharded_buffers_layer.test_buffer.dist_attr.dims_mapping, [-1]
        )

    def test_process_mesh_argument_error(self):
        layer = MyLayer(self.num_features, self.num_layers)

        exception = None
        try:
            dist.shard_layer(layer, None)
        except ValueError as ex:
            self.assertIn(
                "The argument `process_mesh` cannot be empty",
                str(ex),
            )
            exception = ex
        self.assertIsNotNone(exception)

        exception = None
        try:
            dist_attr = dist.DistAttr(
                mesh=self.mesh, sharding_specs=[None, None]
            )
            dist.shard_layer(layer, dist_attr)
        except ValueError as ex:
            self.assertIn(
                "The argument `process_mesh` is not `dist.ProcessMesh` type",
                str(ex),
            )
            exception = ex
        self.assertIsNotNone(exception)

    def test_shard_layer_static_mode(self):
        paddle.enable_static()
        layer = MyLayer(self.num_features, self.num_layers)

        exception = None
        try:
            dist.shard_layer(layer, self.mesh)
        except NotImplementedError as ex:
            self.assertIn(
                "`paddle.distributed.shard_layer` only supports dynamic graph mode now",
                str(ex),
            )
            exception = ex
        self.assertIsNotNone(exception)


if __name__ == '__main__':
    unittest.main()
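As the review thread above notes, this test keeps every sharding_spec as None, so it runs on a single card and only walks through the flow. A sketch of a variant that actually splits a weight along the mesh axis, adapted from the docstring example and launched on two cards as that example suggests (the script name is hypothetical):

# save as e.g. shard_layer_demo.py (hypothetical name), then run:
#   export CUDA_VISIBLE_DEVICES=0,1
#   python -m paddle.distributed.launch shard_layer_demo.py
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

class MLP(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc1 = paddle.nn.Linear(8, 8)
        self.fc2 = paddle.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(self.fc1(x))

def shard_fn(layer_name, layer, process_mesh):
    # split fc1's weight along mesh axis "x"; shard_layer converts every
    # parameter/buffer that shard_fn leaves untouched to replicated
    if layer_name == 'fc1':
        dist_attr = dist.DistAttr(mesh=process_mesh, sharding_specs=['x', None])
        layer.weight = dist.shard_tensor(layer.weight, dist_attr=dist_attr)

layer = dist.shard_layer(MLP(), mesh, shard_fn)
for name, param in layer.named_parameters():
    # fc1.weight maps its first dim to the mesh axis; the others stay -1
    print(name, param.dist_attr.dims_mapping)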
Review comment: Does `shard_layer` need to support the cross-mesh case? If it does not, does `replicate_layer_params_and_buffers` need to check whether the Tensor's `ProcessMesh` is legal?

Reply: In theory we should not restrict cross-mesh use, although when a user calls shard_layer across meshes there is no longer a need to call reshard; the expected user behavior here still needs discussion. That said, there is no need to check the concrete state of the mesh here: shard_layer is a wrapper around shard_tensor, so if a check is needed it is more appropriate to do it in shard_tensor, which will raise an error if the mesh is invalid.

Reply: Understood.