[AutoParallel] Simplify semi auto parallel simple net test (PaddlePaddle#57998)

* simplify auto parallel simple net test

* remove useless argument
chenwhql authored and Frida-a committed Oct 14, 2023
1 parent ad10756 commit 70f1cc0
Showing 1 changed file with 26 additions and 50 deletions.
76 changes: 26 additions & 50 deletions test/auto_parallel/semi_auto_parallel_simple_net.py
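The diff below drops the pre-built DistAttr members and the per-parameter dist.shard_tensor wrapping: DPDemoNet now keeps plain parameters and shards only the input along the mesh's 'x' axis inside forward, MPDemoNet builds its DistAttr inline (w0 sharded with [None, 'x'], w1 with ['x', None]), and run_dynamic loses its parallel flag. A minimal sketch of the resulting DP pattern follows; it is an illustration only, not the test file itself: the IMAGE_SIZE/CLASS_NUM values and the two-card ProcessMesh are placeholder assumptions, and running it requires a two-card paddle.distributed.launch job.

# Condensed sketch of the simplified DP pattern; sizes and mesh are assumptions
# for illustration, not taken from the test itself.
import paddle
import paddle.distributed as dist
import paddle.nn as nn

IMAGE_SIZE = 784  # placeholder value; the test defines its own constant
CLASS_NUM = 10    # placeholder value
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])  # assumed 2-card mesh named 'x'


class DPDemoNetSketch(nn.Layer):
    def __init__(self):
        super().__init__()
        # Parameters stay plain (effectively replicated); no shard_tensor wrapping.
        self.w0 = self.create_parameter(shape=[IMAGE_SIZE, IMAGE_SIZE])
        self.w1 = self.create_parameter(shape=[IMAGE_SIZE, CLASS_NUM])

    def forward(self, x):
        # Data parallelism: shard only the input batch along the mesh's 'x' axis.
        x = dist.shard_tensor(
            x, dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=["x", None])
        )
        y = paddle.matmul(x, self.w0)
        return paddle.matmul(y, self.w1)

As in the hunks below, the test then compares the DP/MP losses and gradients against a single-process baseline (base_loss, base_w0_grad, base_w1_grad).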
@@ -54,36 +54,30 @@ def forward(self, x):
 class DPDemoNet(nn.Layer):
     def __init__(self, np_w0, np_w1, mesh):
         super().__init__()
-        self.replicate_dist_attr = dist.DistAttr(
-            mesh=mesh, sharding_specs=[None, None]
-        )
-        self.shard_axis0_dist_attr = dist.DistAttr(
-            mesh=mesh, sharding_specs=['x', None]
-        )
-        self.w0 = dist.shard_tensor(
-            self.create_parameter(
-                shape=[IMAGE_SIZE, IMAGE_SIZE],
-                attr=paddle.framework.ParamAttr(
-                    name="dp_demo_weight_1",
-                    initializer=paddle.nn.initializer.Assign(np_w0),
-                ),
+        self.mesh = mesh
+        self.w0 = self.create_parameter(
+            shape=[IMAGE_SIZE, IMAGE_SIZE],
+            attr=paddle.framework.ParamAttr(
+                name="dp_demo_weight_1",
+                initializer=paddle.nn.initializer.Assign(np_w0),
             ),
-            dist_attr=self.replicate_dist_attr,
         )
-        self.w1 = dist.shard_tensor(
-            self.create_parameter(
-                shape=[IMAGE_SIZE, CLASS_NUM],
-                attr=paddle.framework.ParamAttr(
-                    name="dp_nemo_weight_2",
-                    initializer=paddle.nn.initializer.Assign(np_w1),
-                ),
+        self.w1 = self.create_parameter(
+            shape=[IMAGE_SIZE, CLASS_NUM],
+            attr=paddle.framework.ParamAttr(
+                name="dp_nemo_weight_2",
+                initializer=paddle.nn.initializer.Assign(np_w1),
             ),
-            dist_attr=self.replicate_dist_attr,
         )
 
     def forward(self, x):
         y = paddle.matmul(
-            dist.shard_tensor(x, dist_attr=self.shard_axis0_dist_attr),
+            dist.shard_tensor(
+                x,
+                dist_attr=dist.DistAttr(
+                    mesh=self.mesh, sharding_specs=['x', None]
+                ),
+            ),
             self.w0,
         )
         z = paddle.matmul(y, self.w1)
@@ -93,15 +87,6 @@ def forward(self, x):
 class MPDemoNet(nn.Layer):
     def __init__(self, np_w0, np_w1, mesh):
         super().__init__()
-        self.replicate_dist_attr = dist.DistAttr(
-            mesh=mesh, sharding_specs=[None, None]
-        )
-        self.shard_axis0_dist_attr = dist.DistAttr(
-            mesh=mesh, sharding_specs=['x', None]
-        )
-        self.shard_axis1_dist_attr = dist.DistAttr(
-            mesh=mesh, sharding_specs=['x', None]
-        )
         self.w0 = dist.shard_tensor(
             self.create_parameter(
                 shape=[IMAGE_SIZE, IMAGE_SIZE],
@@ -110,7 +95,7 @@ def __init__(self, np_w0, np_w1, mesh):
                     initializer=paddle.nn.initializer.Assign(np_w0),
                 ),
             ),
-            dist_attr=self.shard_axis1_dist_attr,
+            dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, 'x']),
         )
         self.w1 = dist.shard_tensor(
             self.create_parameter(
@@ -120,13 +105,11 @@ def __init__(self, np_w0, np_w1, mesh):
                     initializer=paddle.nn.initializer.Assign(np_w1),
                 ),
             ),
-            dist_attr=self.shard_axis0_dist_attr,
+            dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=['x', None]),
         )
 
     def forward(self, x):
-        y = paddle.matmul(
-            dist.shard_tensor(x, dist_attr=self.replicate_dist_attr), self.w0
-        )
+        y = paddle.matmul(x, self.w0)
         z = paddle.matmul(y, self.w1)
         return z
 
@@ -156,23 +139,16 @@ def init_input_data(self):
         self.w1 = np.random.random([IMAGE_SIZE, CLASS_NUM]).astype('float32')
 
     # TODO(chenweihang): optimizer cannot run auto-parallel now
-    def run_dynamic(self, layer, parallel=False):
+    def run_dynamic(self, layer):
         # create loss
         loss_fn = nn.MSELoss()
         # run forward and backward
         image = paddle.to_tensor(self.image)
         out = layer(image)
-        label = (
-            dist.shard_tensor(
-                self.label,
-                dist_attr=dist.DistAttr(
-                    mesh=self._mesh, sharding_specs=[None, None]
-                ),
-            )
-            if parallel is True
-            else paddle.to_tensor(self.label)
-        )
+
+        label = paddle.to_tensor(self.label)
         loss = loss_fn(out, label)
+
         loss.backward()
         return loss, layer.w0.grad, layer.w1.grad
 
@@ -188,15 +164,15 @@ def check_tensor_eq(self, a, b):
 
     def test_dp_demo_net(self):
         self.dp_loss, self.dp_w0_grad, self.dp_w1_grad = self.run_dynamic(
-            DPDemoNet(self.w0, self.w1, self._mesh), parallel=True
+            DPDemoNet(self.w0, self.w1, self._mesh)
         )
         self.check_tensor_eq(self.dp_loss, self.base_loss)
         self.check_tensor_eq(self.dp_w0_grad, self.base_w0_grad)
         self.check_tensor_eq(self.dp_w1_grad, self.base_w1_grad)
 
     def test_mp_demo_net(self):
         self.mp_loss, self.mp_w0_grad, self.mp_w1_grad = self.run_dynamic(
-            MPDemoNet(self.w0, self.w1, self._mesh), parallel=True
+            MPDemoNet(self.w0, self.w1, self._mesh)
         )
         self.check_tensor_eq(self.mp_loss, self.base_loss)
         self.check_tensor_eq(self.mp_w0_grad, self.base_w0_grad)
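For orientation on the MP variant: sharding_specs=[None, 'x'] splits w0 by columns across the mesh axis and ['x', None] splits w1 by rows, so the second matmul contracts the sharded dimension. A small local-shape sketch follows; the sizes and card count are assumptions for illustration, not the test's actual constants.

# Local-shape arithmetic for the MP sharding above; all sizes are assumed.
IMAGE_SIZE, CLASS_NUM, CARDS = 784, 10, 2
w0_local = (IMAGE_SIZE, IMAGE_SIZE // CARDS)  # [None, 'x'] -> columns split
w1_local = (IMAGE_SIZE // CARDS, CLASS_NUM)   # ['x', None] -> rows split
# y = x @ w0 is column-sharded per card; y @ w1 contracts that sharded axis,
# so each card holds a partial [batch, CLASS_NUM] result that the framework
# combines (e.g. via all-reduce) into the full output.
print(w0_local, w1_local)  # (784, 392) (392, 10)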
