Commit

[Optimizer] Add master weight for opt state_dict (#39121)
* add master weight for opt state_dict

* check empty of master weight

* strict gpu test

* refine unittest
zhangbo9674 authored Jan 27, 2022
1 parent 80dfa01 commit 3e6950d
Showing 2 changed files with 95 additions and 0 deletions.
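The effect of the change is easiest to see from user code. Below is a minimal sketch (assuming a CUDA-enabled Paddle build and dynamic-graph mode; the variable names are illustrative, not part of this commit): after at least one optimizer step under AMP-O2 with multi_precision=True, optimizer.state_dict() now carries a "master_weights" entry alongside the regular accumulators, and the whole dict round-trips through paddle.save / set_state_dict.

import numpy
import paddle

# a tiny FP16 workload so the Momentum optimizer creates FP32 master weights
model = paddle.nn.Linear(2, 2)
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.0001,
    parameters=model.parameters(),
    multi_precision=True)
model = paddle.amp.decorate(models=model, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

data = paddle.to_tensor(numpy.random.random([4, 2]).astype('float16'))
with paddle.amp.auto_cast(level='O2'):
    loss = model(data).mean()
scaler.scale(loss).backward()
scaler.step(optimizer)   # master weight accumulators exist after this step
scaler.update()

opt_state = optimizer.state_dict()
print("master_weights" in opt_state)   # True once master weights are non-empty

paddle.save(opt_state, "opt.pdopt")
optimizer.set_state_dict(paddle.load("opt.pdopt"))   # master weights are restored too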
87 changes: 87 additions & 0 deletions python/paddle/fluid/tests/unittests/test_optimizer.py
@@ -25,6 +25,8 @@
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard, convert_np_dtype_to_dtype_
import paddle
from paddle.io import Dataset
import numpy
paddle.enable_static()


@@ -1113,5 +1115,90 @@ def test_float32(self):
        self.check_with_dtype('float32')


class TestMasterWeightSaveForFP16(unittest.TestCase):
    '''
    For AMP-O2, some optimizers (Momentum, Adam, ...) create master weights for their parameters to improve accuracy.
    Master weights are saved by Optimizer.state_dict().
    '''

    def check_with_opt_state_dict(self, use_save_load=True):
        paddle.seed(100)
        numpy.random.seed(100)

        class SimpleNet(paddle.nn.Layer):
            def __init__(self, input_size, output_size):
                super(SimpleNet, self).__init__()
                self.linears = paddle.nn.LayerList([
                    paddle.nn.Linear(input_size, output_size) for i in range(1)
                ])

            def forward(self, x):
                for i, l in enumerate(self.linears):
                    x = self.linears[i](x)
                return x

        input_size = 2  # set to a larger value for real use
        output_size = 2  # set to a larger value for real use
        batch_size = 2  # batch_size should be a multiple of 8
        nums_batch = 10

        class RandomDataset(Dataset):
            def __init__(self, num_samples):
                self.num_samples = num_samples

            def __getitem__(self, idx):
                data = numpy.random.random([input_size]).astype('float16')
                label = numpy.random.random([output_size]).astype('float16')
                return data, label

            def __len__(self):
                return self.num_samples

        dataset = RandomDataset(nums_batch * batch_size)
        loader = paddle.io.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=True,
            num_workers=0)

        mse = paddle.nn.MSELoss()
        model = SimpleNet(input_size, output_size)  # define the model
        optimizer = paddle.optimizer.Momentum(
            learning_rate=0.0001,
            parameters=model.parameters(),
            multi_precision=True)  # define the optimizer
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
        model = paddle.amp.decorate(models=model, level='O2')

        for i, (data, label) in enumerate(loader):
            with paddle.amp.auto_cast(level='O2'):
                output = model(data)
                loss = mse(output, label)
            scaled = scaler.scale(loss)
            scaled.backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.clear_grad(set_to_zero=False)

            if use_save_load and i == 5:
                paddle.save(model.state_dict(), "model.pdparams")
                paddle.save(optimizer.state_dict(), "opt.pdopt")
                model.set_state_dict(paddle.load("model.pdparams"))
                optimizer.set_state_dict(paddle.load("opt.pdopt"))

        return loss.numpy()

    def test_with_state_dict(self):
        if core.is_compiled_with_cuda():
            with fluid.dygraph.guard():
                out_use_state_dict = self.check_with_opt_state_dict(
                    use_save_load=True)
                out_no_state_dict = self.check_with_opt_state_dict(
                    use_save_load=False)
                self.assertTrue(
                    np.array_equal(out_use_state_dict, out_no_state_dict))


if __name__ == '__main__':
    unittest.main()
8 changes: 8 additions & 0 deletions python/paddle/optimizer/optimizer.py
@@ -256,6 +256,10 @@ def state_dict(self):
        for k, v in self._accumulators.items():
            for para_name, var_tmp in v.items():
                state_dict[var_tmp.name] = var_tmp
        # if master weights exist, save them in the state dict as well
        if hasattr(self, "_master_weights"):
            if len(self._master_weights) != 0:
                state_dict["master_weights"] = self._master_weights
        # save the LR scheduler state when learning rate decay is used
        if isinstance(self._learning_rate, LRScheduler):
            state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
@@ -304,6 +308,10 @@ def set_state_dict(self, state_dict):
        state_dict = state_dict.copy()
        if "LR_Scheduler" in state_dict:
            state_dict.pop("LR_Scheduler")
        if "master_weights" in state_dict:
            if hasattr(self, "_master_weights"):
                self._master_weights = state_dict["master_weights"]
            state_dict.pop("master_weights")
        self._accumulators_holder = state_dict
        for k, v in self._accumulators.items():
            for para_name, var_tmp in v.items():

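For contrast with the sketch after the commit header, an optimizer that never created master weights does not get the extra key: the hasattr / non-empty checks in state_dict() mean the saved dict stays exactly as before. A short illustration (names model_fp32 and opt_fp32 are illustrative, not from this commit):

import paddle

# a plain FP32 Momentum optimizer keeps _master_weights empty,
# so state_dict() does not add a "master_weights" entry
model_fp32 = paddle.nn.Linear(2, 2)
opt_fp32 = paddle.optimizer.Momentum(
    learning_rate=0.0001, parameters=model_fp32.parameters())
print("master_weights" in opt_fp32.state_dict())   # False: nothing extra is saved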