[sharding optim] Optim check finite #55766

Merged 1 commit on Aug 1, 2023
@@ -96,20 +96,17 @@ def __init__(self, optimizer, hcg):
self._rank2params = self._partition_parameters()
self._param2rank = self._map_param_to_rank()

if not self.tensor_fusion:
self._set_inner_opt_attr(
'_parameter_list', self._rank2params[self._sharding_rank]
)
self._set_inner_opt_attr(
'_param_groups', self._rank2params[self._sharding_rank]
)
if not self.tensor_fusion and not self.comm_overlap:
local_params = self._rank2params[self._sharding_rank]
self._set_inner_opt_attr('_parameter_list', local_params)
self._set_inner_opt_attr('_param_groups', local_params)
else:
self._tensor_fusion()

decay_params = [
p.name for p in self._rank2decay[self._sharding_rank]
]
fused_params = self._rank2fused[self._sharding_rank]
local_fused_params = self._rank2fused[self._sharding_rank]
apply_decay_param_fun = lambda x: x in decay_params

all_fused_params = []
@@ -118,15 +115,30 @@ def __init__(self, optimizer, hcg):
self._parameter_list = all_fused_params
self._param_groups = all_fused_params

self._set_inner_opt_attr('_parameter_list', fused_params)
self._set_inner_opt_attr('_param_groups', fused_params)
self._set_inner_opt_attr('_parameter_list', local_fused_params)
self._set_inner_opt_attr('_param_groups', local_fused_params)
if self.comm_overlap:
# Only set the local params for check_finite when comm overlap is enabled.
# With comm overlap, all grads are communicated before check_finite,
# so each sharding rank already has every grad's info at check_finite.
# Without comm overlap, grads are communicated after check_finite,
# so each sharding rank must run check_finite over all grads.
self._local_parameter_list = local_fused_params
origin_decay_param_fun = getattr(
self._inner_opt, '_apply_decay_param_fun', None
)
if origin_decay_param_fun is not None:
self._set_inner_opt_attr(
'_apply_decay_param_fun', apply_decay_param_fun
)
# Note: during tensor fusion of the parameters, the allocator requests some
# extra GPU memory for the fused big parameters. This extra GPU memory becomes
# useless once the fusion is done, but Paddle's allocator won't release it;
# it keeps that part in its memory pool. So after tensor fusion, the 'reserved'
# memory increases while the 'allocated' memory doesn't change. To avoid
# failures in some other applications (such as some nvtx operations), we
# manually let the allocator release the cached memory here.
paddle.device.cuda.empty_cache()

def clear_grad(self, set_to_zero=True):
"""
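To make the new parameter wiring easier to follow, here is a minimal standalone sketch of the selection logic above. The class and dictionary names (ShardingOptimizerSketch, rank2params, rank2fused) are illustrative stand-ins for the PR's internals, not Paddle's actual API.

class ShardingOptimizerSketch:
    def __init__(self, rank2params, rank2fused, sharding_rank,
                 tensor_fusion=False, comm_overlap=False):
        self._sharding_rank = sharding_rank
        if not tensor_fusion and not comm_overlap:
            # Without fusion/overlap, the inner optimizer only updates this
            # rank's shard of the parameters.
            self._parameter_list = rank2params[sharding_rank]
            self._inner_parameter_list = self._parameter_list
        else:
            # With tensor fusion, every rank keeps the full fused list, but
            # the inner optimizer is still pointed at the locally owned
            # fused parameters.
            all_fused = [p for fused in rank2fused.values() for p in fused]
            local_fused = rank2fused[sharding_rank]
            self._parameter_list = all_fused
            self._inner_parameter_list = local_fused
            if comm_overlap:
                # Grads are already communicated before check_finite, so the
                # finite check can be restricted to the locally owned params.
                self._local_parameter_list = local_fused


# Example: two sharding ranks; rank 0 owns "w0", rank 1 owns "w1".
opt = ShardingOptimizerSketch(
    rank2params={0: ["w0"], 1: ["w1"]},
    rank2fused={0: ["w0_fused"], 1: ["w1_fused"]},
    sharding_rank=0,
    tensor_fusion=True,
    comm_overlap=True,
)
print(opt._parameter_list)        # ['w0_fused', 'w1_fused']
print(opt._local_parameter_list)  # ['w0_fused']

The memory note above can also be observed directly: after tensor fusion, 'reserved' memory stays high until the cached blocks are released. A small probe, assuming paddle.device.cuda.memory_allocated and memory_reserved are available in your Paddle build:

import paddle

if paddle.device.is_compiled_with_cuda():
    print("allocated:", paddle.device.cuda.memory_allocated())
    print("reserved: ", paddle.device.cuda.memory_reserved())
    # Release cached blocks held by the allocator, as the optimizer does
    # right after tensor fusion.
    paddle.device.cuda.empty_cache()
    print("reserved after empty_cache:", paddle.device.cuda.memory_reserved())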
29 changes: 22 additions & 7 deletions python/paddle/distributed/fleet/scaler.py
@@ -47,20 +47,35 @@ def unscale_method(self, optimizer):
else:
param_grads_fp32.append(param._grad_ivar())
else:
param_grads = [
param._grad_ivar()
for param in optimizer._parameter_list
if param._grad_ivar() is not None
]
strategy = fleet.fleet._user_defined_strategy
sharding_stage_1_overlap = strategy.hybrid_configs[
'sharding_configs'
].comm_overlap
if sharding_stage_1_overlap:
# If sharding stage 1 has comm overlap enabled and loss scaling is needed, we have to wait for all comm tasks here.
# If loss scaling is not needed, the wait for all comm tasks is done in the optimizer step.
assert hasattr(optimizer, "_comm_buffers")
assert hasattr(optimizer, "_sharding_enable")
if optimizer._sharding_enable:
# disable the original grad reduce in the hybrid optimizer step
optimizer._sharding_enable = False
for buffer in optimizer._comm_buffers:
buffer.scale_grads()
# For sharding stage 1 under comm overlap, each rank only has to check finite for the params it is responsible for.
# For now, only sharding stage 1 has this attr; it can be extended to stage 2 and stage 3 later.
assert hasattr(optimizer, "_local_parameter_list")
parameters = optimizer._local_parameter_list
else:
parameters = optimizer._parameter_list
param_grads_fp16 = [
param._grad_ivar()
for param in optimizer._parameter_list
for param in parameters
if (param._grad_ivar() is not None)
and (param._grad_ivar().dtype == core.VarDesc.VarType.FP16)
]
param_grads_fp32 = [
param._grad_ivar()
for param in optimizer._parameter_list
for param in parameters
if (param._grad_ivar() is not None)
and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32)
]
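The new branch in unscale_method can be summarized as follows. This is a condensed, standalone restatement with a tiny mock optimizer so the flow can be exercised; the mock names are illustrative stand-ins, not Paddle's internals.

class _MockBuffer:
    def scale_grads(self):
        print("scale_grads: finish grad communication for this buffer")


class _MockOptimizer:
    def __init__(self):
        self._sharding_enable = True
        self._comm_buffers = [_MockBuffer()]
        self._local_parameter_list = ["local_fused_param"]
        self._parameter_list = ["fused_param_rank0", "fused_param_rank1"]


def params_for_finite_check(optimizer, sharding_stage_1_overlap):
    if sharding_stage_1_overlap:
        if optimizer._sharding_enable:
            # Grad reduce is owned by the comm buffers, so skip the original
            # reduce in the hybrid optimizer step and flush the buffers here.
            optimizer._sharding_enable = False
            for buffer in optimizer._comm_buffers:
                buffer.scale_grads()
        # Under comm overlap, each rank only checks the params it owns.
        return optimizer._local_parameter_list
    # Without comm overlap, every rank checks all params.
    return optimizer._parameter_list


print(params_for_finite_check(_MockOptimizer(), sharding_stage_1_overlap=True))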
@@ -0,0 +1,117 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.distributed import fleet

vocab_size = 20
hidden_size = 10
inner_size = 8
output_size = 10
seq_length = 2
batch_size = 4
STEPS = 10


class SimpleDPNet(paddle.nn.Layer):
def __init__(self, vocab_size, hidden_size, inner_size, output_size):
super().__init__()
self.linear1 = paddle.nn.Linear(hidden_size, inner_size)

self.linear2 = paddle.nn.Linear(inner_size, hidden_size)

self.linear3 = paddle.nn.Linear(hidden_size, output_size)

self.embedding = paddle.nn.Embedding(vocab_size, hidden_size)

def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
return x


class TestDistSharding(unittest.TestCase):
def setUp(self):
self.strategy = fleet.DistributedStrategy()
self.strategy.hybrid_configs = {
"sharding_degree": 2,
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": 1,
}
self.strategy.hybrid_configs["sharding_configs"].tensor_fusion = True
self.strategy.hybrid_configs["sharding_configs"].comm_overlap = True
self.strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1
fleet.init(is_collective=True, strategy=self.strategy)
self.data = np.random.randint(
0,
vocab_size,
(
batch_size,
seq_length,
),
)

if paddle.distributed.get_rank() == 0:
self.batch_sharding = paddle.to_tensor(self.data[:2])
else:
self.batch_sharding = paddle.to_tensor(self.data[2:])

def build_optimizer(self, model):
clip = paddle.nn.ClipGradByGlobalNorm(0.5)
optimizer = paddle.optimizer.AdamW(
parameters=model.parameters(),
learning_rate=0.001,
weight_decay=0.001,
grad_clip=clip,
)
return optimizer

def build_model_optimizer(self):
model = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size)
optimizer = self.build_optimizer(model)
model, optimizer = paddle.amp.decorate(
model, optimizers=optimizer, level="O2", dtype="float16"
)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
scaler = fleet.distributed_scaler(scaler)
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
return model, optimizer, scaler

def sharding_model(self):
model, optimizer, scaler = self.build_model_optimizer()

for idx in range(STEPS):
with paddle.amp.auto_cast(enable=True, level='O2'):
output = model(self.batch_sharding)
loss = output.mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()

def test_sharding_adam(self):
self.sharding_model()


if __name__ == "__main__":
unittest.main()
@@ -33,6 +33,9 @@ def test_hybrid_parallel_sharding_logic(self):
def test_hybrid_parallel_sharding_tensor_fusion(self):
self.run_mnist_2gpu('hybrid_parallel_sharding_model_with_fusion.py')

def test_hybrid_parallel_sharding_tensor_fusion_amp(self):
self.run_mnist_2gpu('hybrid_parallel_sharding_model_with_fusion_amp.py')

def test_hybrid_parallel_sharding_state_dict(self):
self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py')
