fix utest for optimizer
ForFishes committed Nov 6, 2023
1 parent 5b8a64a commit dbdd187
Showing 12 changed files with 383 additions and 192 deletions.
40 changes: 22 additions & 18 deletions python/paddle/optimizer/optimizer.py
@@ -747,25 +747,29 @@ def _create_master_weight(self, param):
             var = self._master_weights[param.name]
         else:
             assert isinstance(self.helper, LayerHelper)
 
             var_name = self._gen_master_weight_var_name(param)
-            var = paddle.static.create_global_var(
-                name=var_name,
-                shape=param.shape,
-                value=0,
-                dtype='float32',
-                persistable=True,
-            )
-            block = self.helper.startup_program.global_block()
-            block.append_op(
-                type="cast",
-                inputs={"X": [param]},
-                outputs={"Out": [var]},
-                attrs={
-                    "in_dtype": param.dtype,
-                    "out_dtype": core.VarDesc.VarType.FP32,
-                },
-            )
+            if framework.in_dygraph_mode():
+                var = paddle.cast(param, 'float32')
+                var.name = var_name
+            else:
+                var = paddle.static.create_global_var(
+                    name=var_name,
+                    shape=param.shape,
+                    value=0,
+                    dtype='float32',
+                    persistable=True,
+                )
+                block = self.helper.startup_program.global_block()
+                block.append_op(
+                    type="cast",
+                    inputs={"X": [param]},
+                    outputs={"Out": [var]},
+                    attrs={
+                        "in_dtype": param.dtype,
+                        "out_dtype": core.VarDesc.VarType.FP32,
+                    },
+                )
             self._master_weights[param.name] = var
         return var
 
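The hunk above makes master-weight creation branch on the execution mode: in dynamic graph (dygraph) mode there is no startup program to append a cast op to, so the FP32 copy is produced directly with paddle.cast and given the generated name, while the static-graph branch keeps the original create_global_var plus cast-op path. A condensed sketch of that dispatch (simplified; the real method also records the result in self._master_weights and, in the static branch, appends the cast op shown above):

    import paddle

    def create_master_weight_sketch(param, var_name):
        # `var_name` stands in for the value returned by
        # _gen_master_weight_var_name(param) in the real optimizer code.
        if paddle.in_dynamic_mode():  # the diff checks framework.in_dygraph_mode()
            # Dygraph: cast eagerly and rename the resulting FP32 tensor.
            master = paddle.cast(param, 'float32')
            master.name = var_name
        else:
            # Static graph: declare a persistable FP32 global variable; the
            # startup-program cast op is added separately (see the diff).
            master = paddle.static.create_global_var(
                name=var_name,
                shape=param.shape,
                value=0,
                dtype='float32',
                persistable=True,
            )
        return master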
2 changes: 1 addition & 1 deletion test/auto_parallel/CMakeLists.txt
@@ -122,7 +122,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_semi_auto_parallel_single_strategy MODULES
                   test_semi_auto_parallel_single_strategy)
   set_tests_properties(test_semi_auto_parallel_single_strategy
-                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
+                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 300)
   py_test_modules(test_semi_auto_parallel_hybrid_strategy MODULES
                   test_semi_auto_parallel_hybrid_strategy)
   set_tests_properties(test_semi_auto_parallel_hybrid_strategy
72 changes: 38 additions & 34 deletions test/auto_parallel/semi_auto_parallel_simple_net.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+import random
 
 import numpy as np
@@ -80,9 +81,6 @@ def __init__(self):
         )
 
         paddle.set_device(self._backend)
-
-        self.init_input_data()
-
         self.init_single_card_net_result()
 
     def shard_fn(self, layer_name, layer, process_mesh):
@@ -124,56 +122,58 @@ def pp_shard_fn(self, layer_name, layer, process_mesh):
             )
             layer.bias = dist.shard_tensor(layer.bias, dist_attr=bias_dist_attr)
 
-    def init_input_data(self):
-        paddle.seed(self._seed)
-        np.random.seed(self._seed)
+    def set_random_seed(self, seed):
+        random.seed(seed)
+        np.random.seed(seed)
+        paddle.seed(seed)
 
-        self.image = np.random.random([BATCH_SIZE, IMAGE_SIZE]).astype(
-            'float32'
-        )
-        self.label = np.random.random([BATCH_SIZE, CLASS_NUM]).astype('float32')
+    def init_input_data(self):
+        image = np.random.random([BATCH_SIZE, IMAGE_SIZE]).astype('float32')
+        label = np.random.random([BATCH_SIZE, CLASS_NUM]).astype('float32')
+        return paddle.to_tensor(image), paddle.to_tensor(label)
 
     def run_dynamic(self, layer, shard_input=False, is_pp=False):
-        paddle.seed(self._seed)
-        np.random.seed(self._seed)
-
         # create loss
         loss_fn = nn.MSELoss()
 
         # run forward and backward
-        image = paddle.to_tensor(self.image)
         input_mesh = self._pp_mesh0 if is_pp else self._mesh
-        if shard_input:
-            image = dist.shard_tensor(
-                image,
-                dist_attr=dist.DistAttr(
-                    mesh=input_mesh, sharding_specs=['x', None]
-                ),
-            )
-
-        out = layer(image)
-        label = paddle.to_tensor(self.label)
-
-        loss = loss_fn(out, label)
-
-        loss.backward()
         opt = paddle.optimizer.SGD(
             learning_rate=0.1, parameters=layer.parameters()
         )
-        opt.step()
+        for _ in range(5):
+            image, label = self.init_input_data()
+            if shard_input:
+                image = dist.shard_tensor(
+                    image,
+                    dist_attr=dist.DistAttr(
+                        mesh=input_mesh, sharding_specs=['x', None]
+                    ),
+                )
+
+            out = layer(image)
+            loss = loss_fn(out, label)
+            loss.backward()
+
+            opt.step()
+            opt.clear_grad()
         return loss, layer.parameters()
 
     def init_single_card_net_result(self):
+        self.set_random_seed(self._seed)
         self.base_loss, self.base_parameters = self.run_dynamic(
             DemoNet("demo_weight")
         )
 
-    def check_tensor_eq(self, a, b):
-        np1 = a.numpy()
-        np2 = b.numpy()
-        np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True)
+    def check_tensor_eq(self, a, b, rtol=1e-05, atol=0, verbose=True):
+        np1 = a.astype("float32").numpy()
+        np2 = b.astype("float32").numpy()
+        np.testing.assert_allclose(
+            np1, np2, rtol=rtol, atol=atol, verbose=verbose
+        )
 
     def test_dp_demo_net(self):
+        self.set_random_seed(self._seed)
+
         self.dp_loss, self.dp_parameters = self.run_dynamic(
             DemoNet("dp_demo_weight"),
             shard_input=True,
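Re-seeding is what keeps the comparison in these tests meaningful: every test entry point now calls set_random_seed before building its network, and run_dynamic draws a fresh batch from init_input_data on each step, so the single-card baseline and the sharded run consume identical initial weights and identical data. A small self-contained illustration of that property (the seed value 2023 is only an example, not taken from the test):

    import random

    import numpy as np
    import paddle

    def set_random_seed(seed):
        # Seed all three RNG sources, mirroring the helper added above.
        random.seed(seed)
        np.random.seed(seed)
        paddle.seed(seed)

    set_random_seed(2023)
    first = np.random.random([2, 4]).astype('float32')
    set_random_seed(2023)
    second = np.random.random([2, 4]).astype('float32')
    # Re-seeding reproduces the same "random" batch, so two training runs
    # started from the same seed see the same inputs step by step.
    assert np.array_equal(first, second)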
@@ -184,6 +184,8 @@ def test_dp_demo_net(self):
             self.check_tensor_eq(param.grad, param_base.grad)
 
     def test_mp_demo_net(self):
+        self.set_random_seed(self._seed)
+
         mp_layer = dist.shard_layer(
             DemoNet("mp_demo_weight"), self._mesh, self.shard_fn
         )
@@ -196,6 +198,8 @@ def test_mp_demo_net(self):
             self.check_tensor_eq(param.grad, param_base.grad)
 
     def test_pp_demo_net(self):
+        self.set_random_seed(self._seed)
+
         # Send/Recv operators don't support CPU now.
         if self._backend != "gpu":
             return
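Taken together, the hunks above replace the old single forward/backward pass with a short training run: five steps, a fresh batch per step, and an SGD update with gradient clearing after every step. A minimal single-card version of that loop, with a plain nn.Linear standing in for DemoNet and toy sizes in place of BATCH_SIZE, IMAGE_SIZE, and CLASS_NUM:

    import numpy as np
    import paddle
    from paddle import nn

    paddle.seed(2023)
    np.random.seed(2023)

    layer = nn.Linear(8, 4)  # stand-in for DemoNet
    loss_fn = nn.MSELoss()
    opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer.parameters())

    for _ in range(5):
        image = paddle.to_tensor(np.random.random([2, 8]).astype('float32'))
        label = paddle.to_tensor(np.random.random([2, 4]).astype('float32'))
        loss = loss_fn(layer(image), label)
        loss.backward()
        opt.step()
        opt.clear_grad()  # drop the gradients before the next step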
50 changes: 30 additions & 20 deletions test/auto_parallel/semi_auto_parallel_simple_net_amp.py
@@ -14,7 +14,6 @@
 
 import os
 
-import numpy as np
 from semi_auto_parallel_simple_net import (
     DemoNet,
     TestSimpleNetForSemiAutoParallel,
@@ -37,44 +36,52 @@ def __init__(self):
         self.init_single_card_net_result()
 
     def run_dynamic_amp(self, layer, level='O1', shard_input=False):
-        paddle.seed(self._seed)
-        np.random.seed(self._seed)
-
         if level == 'O2':
             layer = paddle.amp.decorate(models=layer, level='O2')
         # create loss
         loss_fn = nn.MSELoss()
         opt = paddle.optimizer.SGD(
             learning_rate=0.1, parameters=layer.parameters()
         )
         scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
         # run forward and backward
-        image = paddle.to_tensor(self.image)
-        if shard_input:
-            image = dist.shard_tensor(
-                image,
-                dist_attr=dist.DistAttr(
-                    mesh=self._mesh, sharding_specs=['x', None]
-                ),
-            )
-
-        with paddle.amp.auto_cast(level=level):
-            out = layer(image)
-            label = paddle.to_tensor(self.label)
-            loss = loss_fn(out, label)
-
-        scaled = scaler.scale(loss)
-        scaled.backward()
+        for _ in range(5):
+            image, label = self.init_input_data()
+            if shard_input:
+                image = dist.shard_tensor(
+                    image,
+                    dist_attr=dist.DistAttr(
+                        mesh=self._mesh, sharding_specs=['x', None]
+                    ),
+                )
+
+            with paddle.amp.auto_cast(level=level):
+                out = layer(image)
+                loss = loss_fn(out, label)
+
+            scaled = scaler.scale(loss)
+            scaled.backward()
+            opt.step()
+            opt.clear_grad()
 
         return loss, layer.parameters()
 
     def init_single_card_net_result(self):
+        self.set_random_seed(self._seed)
+
         (
             self.base_loss_o1,
             self.base_parameters_o1,
         ) = self.run_dynamic_amp(DemoNet('demo_weight_O1'), 'O1')
 
+        self.set_random_seed(self._seed)
         (
             self.base_loss_o2,
             self.base_parameters_o2,
         ) = self.run_dynamic_amp(DemoNet('demo_weight_O2'), 'O2')
 
     def test_dp_demo_net(self):
+        self.set_random_seed(self._seed)
         (
             self.dp_loss_o1,
             self.dp_parameters_o1,
@@ -88,6 +95,7 @@ def test_dp_demo_net(self):
             # self.check_tensor_eq(param, param_base)
             self.check_tensor_eq(param.grad, param_base.grad)
 
+        self.set_random_seed(self._seed)
         (
             self.dp_loss_o2,
             self.dp_parameters_o2,
@@ -100,6 +108,7 @@ def test_dp_demo_net(self):
             self.check_tensor_eq(param.grad, param_base.grad)
 
     def test_mp_demo_net(self):
+        self.set_random_seed(self._seed)
         mp_layer_o1 = dist.shard_layer(
             DemoNet("mp_demo_weight_O1"), self._mesh, self.shard_fn
         )
@@ -114,6 +123,7 @@ def test_mp_demo_net(self):
             self.check_tensor_eq(param, param_base)
             self.check_tensor_eq(param.grad, param_base.grad)
 
+        self.set_random_seed(self._seed)
         mp_layer_o2 = dist.shard_layer(
             DemoNet("mp_demo_weight_O2"), self._mesh, self.shard_fn
         )
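The AMP variant drives the same five-step loop, wrapping the forward pass in paddle.amp.auto_cast and scaling the loss with a GradScaler before backward. A minimal single-card sketch of that pattern (no shard_tensor; it finishes each step with scaler.minimize, which unscales the gradients and applies the update, whereas the test above calls opt.step() directly after scaled.backward()):

    import numpy as np
    import paddle
    from paddle import nn

    layer = nn.Linear(8, 4)
    loss_fn = nn.MSELoss()
    opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    for _ in range(5):
        image = paddle.to_tensor(np.random.random([2, 8]).astype('float32'))
        label = paddle.to_tensor(np.random.random([2, 4]).astype('float32'))
        with paddle.amp.auto_cast(level='O1'):
            loss = loss_fn(layer(image), label)
        scaled = scaler.scale(loss)    # scale the loss so low-precision grads stay finite
        scaled.backward()
        scaler.minimize(opt, scaled)   # unscale, apply the update, adjust the loss scaling
        opt.clear_grad()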
38 changes: 20 additions & 18 deletions test/auto_parallel/semi_auto_parallel_simple_net_clear_gradient.py
@@ -34,30 +34,32 @@ def __init__(self):
         self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
 
         paddle.set_device(self._backend)
-        self.init_input_data()
 
     def run_dynamic_clear_gradient(self, layer, shard_input=False):
         # create loss
         loss_fn = nn.MSELoss()
         opt = paddle.optimizer.SGD(
             learning_rate=0.1, parameters=layer.parameters()
         )
         # run forward and backward
-        image = paddle.to_tensor(self.image)
-        if shard_input:
-            image = dist.shard_tensor(
-                image,
-                dist_attr=dist.DistAttr(
-                    mesh=self._mesh, sharding_specs=['x', None]
-                ),
-            )
-        out = layer(image)
-
-        label = paddle.to_tensor(self.label)
-        loss = loss_fn(out, label)
-
-        loss.backward()
+        for _ in range(5):
+            image, label = self.init_input_data()
+            if shard_input:
+                image = dist.shard_tensor(
+                    image,
+                    dist_attr=dist.DistAttr(
+                        mesh=self._mesh, sharding_specs=['x', None]
+                    ),
+                )
+            out = layer(image)
+            loss = loss_fn(out, label)
 
             for param in layer.parameters():
                 param.clear_gradient()
                 param.clear_gradient(False)
             loss.backward()
             opt.step()
             opt.clear_grad()
             for param in layer.parameters():
                 param.clear_gradient()
                 param.clear_gradient(False)
 
     def test_demo_net(self):
         mp_layer = dist.shard_layer(
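The loop now exercises both forms of clear_gradient around the backward pass and optimizer step. As a rough sketch of the difference, assuming the boolean argument is set_to_zero (default True): clear_gradient() keeps the gradient tensor and zeroes its values, while clear_gradient(False) releases the gradient storage entirely:

    import paddle

    x = paddle.ones([2, 2])
    x.stop_gradient = False
    (x * 3.0).sum().backward()

    x.clear_gradient()       # set_to_zero=True: grad tensor kept, values zeroed
    x.clear_gradient(False)  # set_to_zero=False: grad storage is released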
4 changes: 1 addition & 3 deletions test/auto_parallel/semi_auto_parallel_simple_net_grad_api.py
@@ -34,13 +34,12 @@ def __init__(self):
         self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
 
         paddle.set_device(self._backend)
-        self.init_input_data()
 
     def run_dynamic_grad_api(self, layer, shard_input=False):
         # create loss
         loss_fn = nn.MSELoss()
         # run forward and backward
-        image = paddle.to_tensor(self.image)
+        image, label = self.init_input_data()
         if shard_input:
             image = dist.shard_tensor(
                 image,
@@ -50,7 +49,6 @@ def run_dynamic_grad_api(self, layer, shard_input=False):
             )
         out = layer(image)
 
-        label = paddle.to_tensor(self.label)
         loss = loss_fn(out, label)
 
         loss.backward()
@@ -47,14 +47,12 @@ def __init__(self):
         self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
 
         paddle.set_device(self._backend)
-        self.init_input_data()
 
     def run_dynamic(self, layer):
+        image, label = self.init_input_data()
         loss_fn = nn.MSELoss()
-        image = paddle.to_tensor(self.image)
 
         out = layer(image)
-        label = paddle.to_tensor(self.label)
         loss = loss_fn(out, label)
         loss.backward()
 