diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py
index cf7c20eacaee1..7f641a5e6fa54 100644
--- a/python/paddle/distributed/__init__.py
+++ b/python/paddle/distributed/__init__.py
@@ -68,6 +68,7 @@
 from .auto_parallel.api import shard_tensor  # noqa: F401
 from .auto_parallel.api import dtensor_from_fn  # noqa: F401
 from .auto_parallel.api import reshard  # noqa: F401
+from .auto_parallel.api import shard_layer  # noqa: F401
 
 from .fleet import BoxPSDataset  # noqa: F401
 
@@ -130,4 +131,5 @@
     "shard_tensor",
     "dtensor_from_fn",
     "reshard",
+    "shard_layer",
 ]
diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index 567998aeeab7e..54ef7711620bf 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Callable
 import paddle
+import paddle.distributed as dist
+from paddle import nn
 from paddle.base.framework import EagerParamBase
 from paddle.distributed.auto_parallel.interface import (
     shard_tensor as shard_tensor_static,
 )
@@ -24,6 +27,8 @@
 # Some APIs have the same name with the previous APIs implementation, which are
 # a temporary state, and the APIs here will eventually be used.
 
+# Part1: Shard attributes related APIs
+
 
 class DistAttr(core.TensorDistAttr):
     """
@@ -83,6 +88,9 @@ def sharding_specs(self):
         return self._sharding_specs
 
 
+# Part2: DistTensor construction related APIs
+
+
 def shard_tensor(
     data, dtype=None, place=None, stop_gradient=True, dist_attr=None
 ):
@@ -180,6 +188,9 @@ def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
     return shard_tensor(tensor, dist_attr=dist_attr)
 
 
+# Part3: Data conversion related APIs
+
+
 def reshard(dist_tensor, dist_attr):
     """
     Reshard a distributed ``paddle.Tensor`` with given distributed attributes.
@@ -223,3 +234,156 @@
         raise RuntimeError(
             "paddle.dist.reshard only support dynamic graph now. It will be supported for static graph later."
         )
+
+
+def shard_layer(
+    layer: nn.Layer,
+    process_mesh: dist.ProcessMesh,
+    shard_fn: Callable = None,
+    input_fn: Callable = None,
+    output_fn: Callable = None,
+) -> nn.Layer:
+    """
+    Converts all of the layer's parameters to DistTensor parameters according to
+    the specified `shard_fn`. It can also control the conversion of the layer's
+    input or output by specifying `input_fn` and `output_fn`
+    (i.e. convert the input to `paddle.Tensor` with DistTensor, and convert the
+    output back to `paddle.Tensor` with DenseTensor).
+
+    The `shard_fn` should have the following signature:
+
+        def shard_fn(layer_name, layer, process_mesh) -> None
+
+    The `input_fn` should have the following signature:
+
+        def input_fn(inputs, process_mesh) -> list(paddle.Tensor)
+
+    In general, the return value of `input_fn` is `paddle.Tensor` with DistTensor.
+
+    The `output_fn` should have the following signature:
+
+        def output_fn(outputs, process_mesh) -> list(paddle.Tensor)
+
+    In general, the return value of `output_fn` is `paddle.Tensor` with DenseTensor.
+
+    Args:
+        layer (paddle.nn.Layer): The Layer object to be sharded.
+        process_mesh (paddle.distributed.ProcessMesh): The `ProcessMesh` information
+            on which the input `layer` will be placed.
+        shard_fn (Callable): The function used to shard the layer's parameters across
+            the `process_mesh`. If not specified, by default we replicate
+            all parameters of the layer across the `process_mesh`.
+        input_fn (Callable): Specifies how the input of the layer is sharded.
+            The `input_fn` will be registered for the Layer as a `forward pre-hook`.
+            By default we do not shard the input.
+        output_fn (Callable): Specifies how the output of the layer is sharded or
+            converted back to `paddle.Tensor` with DenseTensor.
+            The `output_fn` will be registered for the Layer as a `forward post-hook`.
+            By default we do not shard or convert the output.
+
+    Returns:
+        Layer: A layer that contains parameters/buffers
+            that are all `paddle.Tensor` with DistTensor.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> import paddle.distributed as dist
+
+            >>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+
+            >>> class MLP(paddle.nn.Layer):
+            ...     def __init__(self):
+            ...         super().__init__()
+            ...         self.fc1 = paddle.nn.Linear(8, 8)
+            ...         self.fc2 = paddle.nn.Linear(8, 8)
+            ...
+            ...     def forward(self, input):
+            ...         return self.fc2(self.fc1(input))
+
+            >>> def shard_fn(layer_name, layer, process_mesh):
+            ...     dist_attr = dist.DistAttr(mesh=process_mesh, sharding_specs=['x', None])
+            ...     if layer_name == 'fc1':
+            ...         layer.weight = dist.shard_tensor(layer.weight, dist_attr=dist_attr)
+
+            >>> layer = MLP()
+            >>> layer = dist.shard_layer(layer, mesh, shard_fn)
+            >>> print(layer)
+
+            >>> # This case needs to be executed in a multi-card environment
+            >>> # export CUDA_VISIBLE_DEVICES=0,1
+            >>> # python -m paddle.distributed.launch {test_case}.py
+    """
+    # Ensure that process_mesh is not an empty object
+    if process_mesh is None:
+        raise ValueError("The argument `process_mesh` cannot be empty.")
+
+    # Check the legality of process_mesh
+    if not isinstance(process_mesh, dist.ProcessMesh):
+        raise ValueError(
+            "The argument `process_mesh` is not `dist.ProcessMesh` type."
+        )
+
+    def replicate_layer_params_and_buffers(
+        layer: nn.Layer, mesh: dist.ProcessMesh
+    ) -> None:
+        for key, param in layer._parameters.items():
+            if param is not None and not param.is_dist():
+                replicated_dist_attr = dist.DistAttr(
+                    mesh=mesh,
+                    sharding_specs=[None for _ in range(len(param.shape))],
+                )
+                layer.add_parameter(
+                    key,
+                    shard_tensor(param, dist_attr=replicated_dist_attr),
+                )
+            else:
+                # do nothing, the dist parameters have already been
+                # sharded by shard_fn
+                pass
+        for key, buffer in layer._buffers.items():
+            if buffer is not None and not buffer.is_dist():
+                replicated_dist_attr = dist.DistAttr(
+                    mesh=mesh,
+                    sharding_specs=[None for _ in range(len(buffer.shape))],
+                )
+                layer.register_buffer(
+                    key,
+                    shard_tensor(buffer, dist_attr=replicated_dist_attr),
+                )
+            else:
+                # do nothing, the dist buffers have already been
+                # sharded by shard_fn
+                pass
+
+    if paddle.in_dynamic_mode():
+        if shard_fn is None:
+            # if shard_fn is not specified, by default replicate
+            # all of the layer's parameters and buffers
+            for name, sublayers in layer.named_sublayers(include_self=True):
+                replicate_layer_params_and_buffers(sublayers, process_mesh)
+        else:
+            # apply shard_fn to sublayers, including self
+            for name, sublayers in layer.named_sublayers(include_self=True):
+                shard_fn(name, sublayers, process_mesh)
+                # shard_fn may not deal with all parameters and buffers,
+                # the parameters and buffers that are not sharded by shard_fn
+                # still need to be sharded as replicated
+                replicate_layer_params_and_buffers(sublayers, process_mesh)
+
+        # register input_fn as layer's forward pre hook
+        if input_fn is not None:
+            layer.register_forward_pre_hook(
+                lambda _, inputs: input_fn(inputs, process_mesh)
+            )
+        # register output_fn as layer's forward post hook
+        if output_fn is not None:
+            layer.register_forward_post_hook(
+                lambda _, inputs, outputs: output_fn(outputs, process_mesh)
+            )
+
+        return layer
+    else:
+        # TODO(chenweihang): Support static mode branch later.
+        raise NotImplementedError(
+            "`paddle.distributed.shard_layer` only supports dynamic graph mode "
+            "now. It will be supported for static graph mode later."
+        )
diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt
index 873f9f057e9ab..cb1ad0e8bef02 100644
--- a/test/auto_parallel/CMakeLists.txt
+++ b/test/auto_parallel/CMakeLists.txt
@@ -194,6 +194,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_dist_tensor MODULES test_dist_tensor)
   py_test_modules(test_api_dist_branch MODULES test_api_dist_branch)
   py_test_modules(test_shard_tensor_api MODULES test_shard_tensor_api)
+  py_test_modules(test_shard_layer_api MODULES test_shard_layer_api)
   py_test_modules(test_cost_interface MODULES test_cost_interface)
 
   # End of unittests WITH single card WITHOUT timeout
diff --git a/test/auto_parallel/test_shard_layer_api.py b/test/auto_parallel/test_shard_layer_api.py
new file mode 100644
index 0000000000000..79ce4e95d37c0
--- /dev/null
+++ b/test/auto_parallel/test_shard_layer_api.py
@@ -0,0 +1,174 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+import paddle.distributed as dist
+from paddle import nn
+
+
+# TODO(chenweihang): test for paddle nn Layer API
+class DemoLayer(nn.Layer):
+    def __init__(self, num_features):
+        super().__init__()
+        self.w0 = self.create_parameter(shape=[num_features, num_features])
+        self.w1 = self.create_parameter(shape=[num_features, num_features])
+
+    def forward(self, x):
+        y = paddle.matmul(x, self.w0)
+        z = paddle.matmul(y, self.w1)
+        return z
+
+
+class MyLayer(nn.Layer):
+    def __init__(self, num_features, num_layers):
+        super().__init__()
+        self.seq = nn.Sequential(
+            *[DemoLayer(num_features) for _ in range(num_layers)]
+        )
+
+    def forward(self, x):
+        return self.seq(x)
+
+
+class TestShardLayer(unittest.TestCase):
+    def setUp(self):
+        self.mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+        self.num_features = 10
+        self.num_layers = 10
+
+    def test_shard_layer_base(self):
+        layer = MyLayer(self.num_features, self.num_layers)
+
+        def shard_fn(layer_name, layer, process_mesh):
+            if isinstance(layer, nn.Linear):
+                for name, param in layer.named_parameters():
+                    if 'weight' in name:
+                        dist_param = dist.shard_tensor(
+                            param,
+                            dist_attr=dist.DistAttr(
+                                mesh=process_mesh, sharding_specs=[None, None]
+                            ),
+                        )
+                    else:
+                        dist_param = dist.shard_tensor(
+                            param,
+                            dist_attr=dist.DistAttr(
+                                mesh=process_mesh, sharding_specs=[None]
+                            ),
+                        )
+                    layer.add_parameter(name, dist_param)
+
+        # test shard parameters
+        sharded_params_layer = dist.shard_layer(layer, self.mesh, shard_fn)
+
+        for param in sharded_params_layer.parameters():
+            self.assertTrue(param.is_dist())
+            for x in param.dist_attr.dims_mapping:
+                self.assertEqual(x, -1)
+
+        # test shard buffers
+        test_buffer = paddle.randn([10])
+        layer.register_buffer("test_buffer", test_buffer, persistable=True)
+        sharded_buffers_layer = dist.shard_layer(layer, self.mesh, shard_fn)
+        self.assertTrue(sharded_buffers_layer.test_buffer.is_dist())
+        self.assertEqual(
+            sharded_buffers_layer.test_buffer.dist_attr.dims_mapping, [-1]
+        )
+
+    def test_shard_layer_input_fn_and_output_fn(self):
+        layer = MyLayer(self.num_features, self.num_layers)
+
+        def input_fn(inputs, process_mesh):
+            return dist.shard_tensor(
+                inputs[0], dist_attr=dist.DistAttr(process_mesh, [None, None])
+            )
+
+        def output_fn(outputs, process_mesh):
+            assert outputs.is_dist()
+            # TODO(chenweihang): replace by dist.unshard_dtensor later
+            return paddle.to_tensor(outputs.numpy())
+
+        # test shard parameters
+        replicate_params_layer = dist.shard_layer(
+            layer, self.mesh, input_fn=input_fn, output_fn=output_fn
+        )
+
+        x = paddle.randn([5, self.num_features])
+        dense_out = replicate_params_layer(x)
+        self.assertTrue(dense_out.is_dense())
+
+        for param in replicate_params_layer.parameters():
+            self.assertTrue(param.is_dist())
+            for x in param.dist_attr.dims_mapping:
+                self.assertEqual(x, -1)
+
+        # test shard buffers
+        test_buffer = paddle.randn([10])
+        layer.register_buffer("test_buffer", test_buffer, persistable=True)
+        sharded_buffers_layer = dist.shard_layer(
+            layer, self.mesh, input_fn=input_fn, output_fn=output_fn
+        )
+        self.assertTrue(sharded_buffers_layer.test_buffer.is_dist())
+        self.assertEqual(
+            sharded_buffers_layer.test_buffer.dist_attr.dims_mapping, [-1]
+        )
+
+    def test_process_mesh_argument_error(self):
+        layer = MyLayer(self.num_features, self.num_layers)
+
+        exception = None
+        try:
+            dist.shard_layer(layer, None)
+        except ValueError as ex:
+            self.assertIn(
+                "The argument `process_mesh` cannot be empty",
+                str(ex),
+            )
+            exception = ex
+        self.assertIsNotNone(exception)
+
+        exception = None
+        try:
+            dist_attr = dist.DistAttr(
+                mesh=self.mesh, sharding_specs=[None, None]
+            )
+            dist.shard_layer(layer, dist_attr)
+        except ValueError as ex:
+            self.assertIn(
+                "The argument `process_mesh` is not `dist.ProcessMesh` type",
+                str(ex),
+            )
+            exception = ex
+        self.assertIsNotNone(exception)
+
+    def test_shard_layer_static_mode(self):
+        paddle.enable_static()
+        layer = MyLayer(self.num_features, self.num_layers)
+
+        exception = None
+        try:
+            dist.shard_layer(layer, self.mesh)
+        except NotImplementedError as ex:
+            self.assertIn(
+                "`paddle.distributed.shard_layer` only supports dynamic graph mode now",
+                str(ex),
+            )
+            exception = ex
+        self.assertIsNotNone(exception)
+
+
+if __name__ == '__main__':
+    unittest.main()
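
For reference, below is a minimal usage sketch of the `shard_layer` API added above, combining `shard_fn` with the `input_fn`/`output_fn` hooks. The `MLP` layer, the `["x", None]` sharding choice, and the script name `demo_shard_layer.py` are illustrative assumptions drawn from the docstring example and the unit tests; the script is intended to be launched on two devices, e.g. `export CUDA_VISIBLE_DEVICES=0,1` followed by `python -m paddle.distributed.launch demo_shard_layer.py`.

# demo_shard_layer.py -- illustrative sketch only; the layer definition and
# sharding choices below are assumptions based on the docstring example and
# the unit tests in this diff.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])


class MLP(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc1 = paddle.nn.Linear(8, 8)
        self.fc2 = paddle.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(self.fc1(x))


def shard_fn(layer_name, layer, process_mesh):
    # Shard fc1's weight along the "x" mesh dimension; parameters that
    # shard_fn does not touch are replicated by shard_layer itself.
    if layer_name == "fc1":
        dist_attr = dist.DistAttr(
            mesh=process_mesh, sharding_specs=["x", None]
        )
        layer.weight = dist.shard_tensor(layer.weight, dist_attr=dist_attr)


def input_fn(inputs, process_mesh):
    # Registered as a forward pre-hook: convert the dense input to a
    # replicated DistTensor before the forward pass.
    return dist.shard_tensor(
        inputs[0],
        dist_attr=dist.DistAttr(
            mesh=process_mesh, sharding_specs=[None, None]
        ),
    )


def output_fn(outputs, process_mesh):
    # Registered as a forward post-hook: convert the DistTensor output back
    # to a dense paddle.Tensor (see the TODO about unshard_dtensor above).
    return paddle.to_tensor(outputs.numpy())


layer = dist.shard_layer(
    MLP(), mesh, shard_fn, input_fn=input_fn, output_fn=output_fn
)
out = layer(paddle.randn([4, 8]))
print(out)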