[Auto Parallel] Add Strategy API for configuring distributed training with static graph #59862

Merged: 7 commits, Dec 11, 2023
Changes from 5 commits
2 changes: 2 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -85,6 +85,7 @@
shard_layer,
shard_optimizer,
to_static,
Strategy,
)

from .fleet import BoxPSDataset # noqa: F401
@@ -165,4 +166,5 @@
"load_state_dict",
"shard_optimizer",
"to_static",
"Strategy",
]
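
For context, this export makes the new strategy class reachable directly from the ``paddle.distributed`` namespace. A minimal sketch (assuming a Paddle build that includes this patch):

import paddle.distributed as dist

# Construct the newly exported Strategy and toggle one option.
strategy = dist.Strategy()
strategy.sharding.enable = True
print(strategy.sharding.enable)  # True
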
269 changes: 249 additions & 20 deletions python/paddle/distributed/auto_parallel/api.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from collections import defaultdict
from typing import Callable

@@ -23,7 +24,7 @@
Variable,
default_main_program,
)
from paddle.distributed.auto_parallel import Engine
from paddle.distributed.auto_parallel import Engine, strategy as auto_strategy
from paddle.distributed.auto_parallel.interface import (
shard_tensor as shard_tensor_static,
)
@@ -108,9 +109,11 @@ def sharding_specs(self):

class DistModel:
"""
DistModel is a wrapper of the network model for the use of static mode
auto parallel. DistModel contains the distributed Graph of the model and
offers the APIs for training, evaluation and prediction.
DistModel is generated by ``paddle.distributed.to_static``. It contains the
static graph converted from a ``paddle.nn.Layer`` whose parameters are
distributed tensors (constructed from ``paddle.distributed.shard_tensor``),
and provides the APIs for training, evaluation and prediction with the
static graph.

Please first set the DistModel to "train", "eval" or "predict" mode and
then use the __call__ method for training, evaluation and prediction
@@ -127,8 +130,8 @@ class DistModel:
to "eval" mode in default. If loss and optimizer are both None, DistModel
will be set to "predict" mode in default.

DistModel is generated by ``paddle.distributed.to_static``, for more details
of the usage, please refer to the sample code in ``paddle.distributed.to_static``.
For more details of the usage, please refer to the sample code in
``paddle.distributed.to_static``.
"""

def __init__(
@@ -141,8 +144,9 @@ def __init__(
metrics=None,
):
self._feed_name_list = []
self._inner_strategy = self.__convert_strategy(strategy)
self._engine = Engine(
layer, loss, optimizer, metrics, strategy=strategy
layer, loss, optimizer, metrics, strategy=self._inner_strategy
)
self._mode = None
self._feed_name_list = {}
@@ -271,6 +275,27 @@ def _make_feeds(self, data_list):
)
return dict(zip(feed_name_list, data_list))

def __convert_strategy(self, strategy):
if strategy is None:
return None
inner_strategy = auto_strategy.Strategy()
inner_strategy.fused_passes.enable = strategy.fused_passes.enable
if strategy.fused_passes.gemm_epilogue is True:
inner_strategy.fused_passes.fused_passes_list.append(
"fused_gemm_epilogue_pass"
)
if strategy.fused_passes.dropout_add is True:
inner_strategy.fused_passes.fused_passes_list.append(
"fused_dropout_add_pass"
)

inner_strategy.sharding = copy.deepcopy(strategy.sharding)
inner_strategy.gradient_merge = copy.deepcopy(strategy.gradient_merge)
inner_strategy.pipeline = copy.deepcopy(strategy.pipeline)
return inner_strategy
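
As a standalone illustration of the fused-pass translation above (plain Python, no Paddle dependency; the pass-name strings are the ones appended in the method):

def convert_fused_passes(enable, gemm_epilogue, dropout_add):
    # Mirrors __convert_strategy's handling of the fused_passes section (sketch).
    passes = []
    if gemm_epilogue:
        passes.append("fused_gemm_epilogue_pass")
    if dropout_add:
        passes.append("fused_dropout_add_pass")
    return {"enable": enable, "fused_passes_list": passes}

print(convert_fused_passes(True, True, False))
# {'enable': True, 'fused_passes_list': ['fused_gemm_epilogue_pass']}
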

def __call__(self, *args):
if self._mode is None:
raise ValueError("Please call train()/eval()/predict() first.")
@@ -298,6 +323,209 @@ def __call__(self, *args):


# Part2: DistTensor construction related APIs


class FusePasses:
"""
A helper class for configuring the fused passes.
"""

def __init__(self, config_dict=None):
self.enable = False
self.gemm_epilogue = False
self.dropout_add = False
if config_dict is not None:
for key, value in config_dict.items():
if hasattr(self, key):
setattr(self, key, value)
else:
raise ValueError(f"Unknown fuse pass {key}")
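
A short sketch of how ``FusePasses`` behaves with a config dict, following the constructor above (assuming it is imported from this module):

fp = FusePasses({"enable": True, "gemm_epilogue": True})
assert fp.enable and fp.gemm_epilogue and not fp.dropout_add

try:
    FusePasses({"not_a_pass": True})
except ValueError as err:
    print(err)  # Unknown fuse pass not_a_pass
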


class Strategy(auto_strategy.BaseConfig):
"""
The `Strategy` object is used to configure the parallelization
and optimization strategies for the static graph. It currently supports
configuring ``sharding``, ``fused_passes``, ``gradient_merge``
and ``pipeline``. More strategies will be supported in the future.

``sharding`` is used to configure the sharding of the optimizer states,
which saves GPU memory.

``fused_passes`` is used to configure the fusion of the computation in
the model.

``gradient_merge`` is used to configure the gradient merge strategy in
training.

``pipeline`` is used to configure the pipeline parallelism strategy.

Args:
config (dict|None, optional): If ``config`` is None, the default
configurations will be set. If it is a dict, the items inside
the dict are used to set the corresponding configurations, and the
others keep their default values.

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.distributed as dist

>>> strategy = dist.Strategy()

>>> strategy.sharding.enable = True
>>> strategy.sharding.stage = 2
>>> strategy.sharding.degree = 2

>>> strategy.gradient_merge.enable = True
>>> strategy.gradient_merge.k_steps = 2
>>> strategy.gradient_merge.avg = False

>>> strategy.pipeline.enable = True
>>> strategy.pipeline.schedule_mode = "1F1B" # default is "1F1B"
>>> strategy.pipeline.micro_batch_size = 2
"""

def __init__(self, config=None):
if config is not None:
if isinstance(config, dict):
self._config_dict = copy.deepcopy(config)
else:
raise ValueError(
f"Expected a dictionary. But received: {config}"
)
else:
self._config_dict = {}

category = auto_strategy.constants.BASE
super().__init__(category, self._config_dict)

config_dict = self._config_dict.get(
auto_strategy.constants.SHARDING, None
)
self._sharding = auto_strategy.ShardingConfig(config_dict)

config_dict = self._config_dict.get(
auto_strategy.constants.GRADIENT_MERGE, None
)
self._gradient_merge = auto_strategy.GradientMergeConfig(config_dict)

config_dict = self._config_dict.get(
auto_strategy.constants.PIPELINE, None
)
self._pipeline = auto_strategy.PipelineConfig(config_dict)

config_dict = self._config_dict.get(
auto_strategy.constants.FUSED_PASSES, None
)
self._fused_passes = FusePasses(config_dict)
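
A sketch of dict-based construction; the section keys are assumed to resolve to the plain strings "sharding", "gradient_merge", "pipeline" and "fused_passes" (an assumption about the constants referenced above), and any omitted section keeps its defaults:

import paddle.distributed as dist

# Assumed key names; each nested dict configures the matching section.
strategy = dist.Strategy(
    {
        "sharding": {"enable": True, "stage": 2, "degree": 2},
        "fused_passes": {"enable": True, "gemm_epilogue": True},
    }
)
print(strategy.sharding.stage)         # 2
print(strategy.gradient_merge.enable)  # False (default kept)
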

@property
def sharding(self):
"""
``sharding`` is used to configure the sharding of the optimizer states,
containing the following configs:

``enable`` (bool): whether to enable sharding. Default: False.

``stage`` (int): can be set to 1, 2 or 3. Stage 1 shards the optimizer states,
stage 2 shards the optimizer states and gradients, and stage 3 shards the
optimizer states, gradients and parameters. Default: 1.

``degree`` (int): the number of shards the states are partitioned into. Default: 8.

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.distributed as dist

>>> strategy = dist.Strategy()

>>> strategy.sharding.enable = True
>>> strategy.sharding.stage = 2
>>> strategy.sharding.degree = 2
"""
return self._sharding

@property
def gradient_merge(self):
"""
``gradient_merge`` is used to configure the gradient merge strategy in
training, containing following configs:

``enable`` (bool): whether to enable gradient merge. Default: False.

``k_steps`` (int): the number of steps for merging gradients. Default: 1.

``avg`` (bool): whether to average the gradients of each step. Default: True.

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.distributed as dist

>>> strategy = dist.Strategy()

>>> strategy.gradient_merge.enable = True
>>> strategy.gradient_merge.k_steps = 2
>>> strategy.gradient_merge.avg = True
"""
return self._gradient_merge

@property
def fused_passes(self):
"""
``fused_passes`` is used to configure the fusion of the computation in
the model, containing following configs:

``enable`` (bool): whether to enable fused passes. Default: False.

``gemm_epilogue`` (bool): whether to fuse ``matmul`` and ``add`` computation
in the ``Linear`` layer. Default: False.

``dropout_add`` (bool): whether to fuse ``dropout`` and ``add`` computation. Default: False.

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.distributed as dist

>>> strategy = dist.Strategy()

>>> strategy.fused_passes.enable = True
>>> strategy.fused_passes.gemm_epilogue = True
>>> strategy.fused_passes.dropout_add = True
"""
return self._fused_passes

@property
def pipeline(self):
"""
``pipeline`` is used to configure the pipeline parallelism in training,
containing following configs:

``enable`` (bool): whether to enable pipeline parallelism. Default: False.

``schedule_mode`` (str): the scheduling mode of pipeline parallelism. Default: "1F1B".

``micro_batch_size`` (int): the size of each micro-batch inside a mini-batch. Default: 1.

``accumulate_steps`` (int): the number of steps over which gradients are accumulated. Default: 1.

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.distributed as dist

>>> strategy = dist.Strategy()

>>> strategy.pipeline.enable = True
>>> strategy.pipeline.micro_batch_size = 2
"""
return self._pipeline


def to_static(
layer: paddle.nn.Layer,
loader=None,
@@ -306,29 +534,30 @@ def to_static(
strategy=None,
):
"""
Converts the model and data loader used in dygraph auto-parallelism to
that in static mode auto-parallelism. to_static returns a DistModel
instance that provides APIs and a DistributedDataLoader to generate data
for static mode auto-parallel training, evaluation and prediction.
Converts the ``layer`` with distributed tensors (constructed from
``paddle.distributed.shard_tensor``) to a static graph. ``to_static``
returns a ``DistModel`` instance that contains the static graph for
distributed training, evaluation and prediction, and a
``DistributedDataLoader`` that generates data for the static graph.

Args:
layer(paddle.nn.Layer): The layer in dygraph model, the parameters
or its inputs can be sharded.
loader(paddle.io.DataLoader): The data loader used in dygraph model,
used to generate Distributed Dataloader for static auto parallel.
layer(paddle.nn.Layer): The layer in dygraph mode; its parameters
or inputs can be distributed tensors.
loader(paddle.io.DataLoader): The data loader used in dygraph mode,
from which the DistributedDataLoader is generated.
loss(Loss|Callable|None, optional): The loss function for training
or evaluating the model. Can be a `paddle.nn.Layer` instance or
any callable function. Default: None.
optimizer(paddle.optimizer.Optimizer|None, optional): The optimizer
for training. Default: None.
strategy(Strategy|None, optional): Configs for parallel strategies
(e.g. data parallel, hybrid parallel etc.) and optimization
settings (e.g. mixed-precision). Default: None.
strategy(paddle.distributed.Strategy|None, optional): Configs for
parallel strategies and optimization settings (e.g. sharding,
pipeline parallelism). Default: None.

Returns:
DistModel: A DistModel that contains the corresponding computational graph
for the input layer and provides APIs for training, evaluation and
prediction.
for the input ``layer`` and provides APIs for training, evaluation
and prediction.
DistributedDataLoader: An optimized data loader that can be used
to generate data.

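Putting the pieces together, a hedged end-to-end sketch consistent with the test changes below; ``layer``, ``loader``, ``loss_fn`` and ``opt`` are placeholders for a user-defined dygraph model, data loader, loss function and optimizer:

import paddle.distributed as dist

strategy = dist.Strategy()
strategy.sharding.enable = True
strategy.sharding.stage = 2
strategy.sharding.degree = 2

dist_model, dist_loader = dist.to_static(
    layer, loader, loss_fn, opt, strategy=strategy
)
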
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/strategy.py
@@ -156,7 +156,7 @@ def __init__(self, config_dict=None):

class Strategy(BaseConfig):
"""
The `Strategy` object is used to configure the parallelization and optimization behaviors.
The `Strategy` object is used to configure the parallelization and optimization for the static graph.

Args:
config (dict|string, optional): If this is None, the default configurations will be used.
9 changes: 8 additions & 1 deletion test/auto_parallel/hybrid_strategy/semi_auto_llama.py
@@ -188,8 +188,15 @@ def run_dy2static(self):
else:
opt = optimizer

strategy = None
if self.gradient_accumulation_steps > 1:
strategy = dist.Strategy()
strategy.pipeline.accumulate_steps = (
self.gradient_accumulation_steps
)

dist_model, dist_loader = dist.to_static(
model, train_dataloader, criterion, opt
model, train_dataloader, criterion, opt, strategy=strategy
)

dist_model.train()
@@ -164,7 +164,7 @@ def test_simple_net_hybrid_strategy(self):
class TestSemiAutoParallelLlama3D(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=8, timeout=200, nnode=1)
self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "1"}
self._default_envs = {"dp": "2", "mp": "2", "pp": "2", "acc_step": "2"}
self._changeable_envs = {
"backend": ["gpu"],
"use_sp": ["true", "false"],
3 changes: 2 additions & 1 deletion test/auto_parallel/semi_auto_parallel_dist_to_static_api.py
@@ -150,8 +150,9 @@ def run_test(self):
loss_fn = nn.MSELoss()

# static training
strategy = dist.Strategy()
dist_model, dist_loader = dist.to_static(
layer, self.data_loader, loss_fn, opt
layer, self.data_loader, loss_fn, opt, strategy=strategy
)

dist_model._mode = None