[AutoParallel] Generate replicated spmd for PHI API and verify DP MP strategy #57505

Merged (13 commits into PaddlePaddle:develop, Sep 22, 2023)

Conversation

@chenwhql (Contributor) commented Sep 19, 2023

PR types

New features

PR changes

Others

Description

Pcard-73145

[AutoParallel] Generate replicated spmd for PHI API and verify DP MP strategy

This PR generates the general sharding inference (spmd) rule and the corresponding transformation logic for APIs whose inputs and outputs contain only Tensors. The general rule converts all of an API's inputs to the Replicate state before running the kernel, which is equivalent to every rank independently performing the complete computation.

Even though its performance is poor, once this rule is generated a sizable portion of the APIs can exercise the basic execution flow of dynamic-graph semi-auto parallel. At present matmul is the only operator with a dedicated sharding inference rule (and its backward rule is still incomplete); adding dedicated forward and backward rules for the remaining operators will be a relatively long, operator-by-operator effort. The general rule ensures that execution under the semi-auto architecture does not fail outright simply because a dedicated sharding inference strategy is missing.
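
A minimal sketch of what this fallback means at the Python level (illustrative only; it assumes a two-card launch via paddle.distributed.launch and uses relu, which has no dedicated rule):

# Illustrative sketch, not part of this PR's tests. Run with two cards, e.g.
#   python -m paddle.distributed.launch --devices=0,1 this_script.py
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
x = dist.shard_tensor(
    paddle.rand([8, 4]),
    dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=["x", None]),
)
# relu has no dedicated spmd rule, so the generated replicated rule reshards
# x to the Replicate state on every rank, runs the dense kernel, and the
# output matches the single-card computation.
out = paddle.nn.functional.relu(x)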

Building on the above, this PR uses a simple demo network to verify the correctness of the single DP and MP strategies under the dynamic-graph semi-auto parallel architecture.

Hybrid strategies require an environment with more than two cards, while the distributed unit tests in CI can currently only run on two cards.

Demo network

It contains only two matmul ops plus one mse loss (made up of the subtract, square, and mean operators), and only the forward and backward passes are executed.

  1. Why not use CrossEntropyLoss? (CrossEntropyLoss involves 12 operators, some of which do not yet support dynamic-graph semi-auto parallel.)
  2. Why not add an Optimizer? (Optimizers all take optional inputs, which semi-auto parallel does not support yet, and the learning_rate tensor cannot be sharded by the user.)
  3. The network parameters are fixed to numpy arrays to keep the test data consistent.
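
# Note: the snippets below omit test-file boilerplate; they assume the usual
# imports (numpy as np, paddle, paddle.nn as nn, paddle.distributed as dist)
# and module-level constants such as IMAGE_SIZE and CLASS_NUM.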
class DemoNet(nn.Layer):
    def __init__(self, np_w0, np_w1):
        super().__init__()
        self.w0 = self.create_parameter(
            shape=[IMAGE_SIZE, IMAGE_SIZE],
            attr=paddle.framework.ParamAttr(
                name="demo_weight_1",
                initializer=paddle.nn.initializer.Assign(np_w0),
            ),
        )
        self.w1 = self.create_parameter(
            shape=[IMAGE_SIZE, CLASS_NUM],
            attr=paddle.framework.ParamAttr(
                name="nemo_weight_2",
                initializer=paddle.nn.initializer.Assign(np_w1),
            ),
        )

    def forward(self, x):
        y = paddle.matmul(x, self.w0)
        z = paddle.matmul(y, self.w1)
        return z

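# run_dynamic below is a method of the test case class; self.image, self.label
# and self._mesh are assumed to be prepared elsewhere in that class (fixed
# numpy data and the process mesh), so the snippet is not standalone.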
def run_dynamic(self, layer, parallel=False):
    # create loss
    loss_fn = nn.MSELoss()
    # run forward and backward
    image = paddle.to_tensor(self.image)
    out = layer(image)
    label = (
        dist.shard_tensor(
            self.label,
            dist_attr=dist.DistAttr(
                mesh=self._mesh, sharding_specs=[None, None]
            ),
        )
        if parallel is True
        else paddle.to_tensor(self.label)
    )
    loss = loss_fn(out, label)
    loss.backward()

DP demo rewrite

class DPDemoNet(nn.Layer):
    def __init__(self, np_w0, np_w1, mesh):
        super().__init__()
        self.replicate_dist_attr = dist.DistAttr(
            mesh=mesh, sharding_specs=[None, None]
        )
        self.shard_axis0_dist_attr = dist.DistAttr(
            mesh=mesh, sharding_specs=['x', None]
        )
        self.w0 = dist.shard_tensor(
            self.create_parameter(
                shape=[IMAGE_SIZE, IMAGE_SIZE],
                attr=paddle.framework.ParamAttr(
                    name="dp_demo_weight_1",
                    initializer=paddle.nn.initializer.Assign(np_w0),
                ),
            ),
            dist_attr=self.replicate_dist_attr,
        )
        self.w1 = dist.shard_tensor(
            self.create_parameter(
                shape=[IMAGE_SIZE, CLASS_NUM],
                attr=paddle.framework.ParamAttr(
                    name="dp_nemo_weight_2",
                    initializer=paddle.nn.initializer.Assign(np_w1),
                ),
            ),
            dist_attr=self.replicate_dist_attr,
        )

    def forward(self, x):
        # shard the input x
        y = paddle.matmul(
            dist.shard_tensor(x, dist_attr=self.shard_axis0_dist_attr),
            self.w0,
        )
        z = paddle.matmul(y, self.w1)
        return z

DP sharding execution diagram: (image)

MP demo rewrite

class MPDemoNet(nn.Layer):
    def __init__(self, np_w0, np_w1, mesh):
        super().__init__()
        self.replicate_dist_attr = dist.DistAttr(
            mesh=mesh, sharding_specs=[None, None]
        )
        self.shard_axis0_dist_attr = dist.DistAttr(
            mesh=mesh, sharding_specs=['x', None]
        )
        self.shard_axis1_dist_attr = dist.DistAttr(
            mesh=mesh, sharding_specs=[None, 'x']
        )
        # shard parameter w0 along axis 1 (column parallel)
        self.w0 = dist.shard_tensor(
            self.create_parameter(
                shape=[IMAGE_SIZE, IMAGE_SIZE],
                attr=paddle.framework.ParamAttr(
                    name="mp_demo_weight_1",
                    initializer=paddle.nn.initializer.Assign(np_w0),
                ),
            ),
            dist_attr=self.shard_axis1_dist_attr,
        )
        # shard parameter w1 along axis 0 (row parallel)
        self.w1 = dist.shard_tensor(
            self.create_parameter(
                shape=[IMAGE_SIZE, CLASS_NUM],
                attr=paddle.framework.ParamAttr(
                    name="mp_nemo_weight_2",
                    initializer=paddle.nn.initializer.Assign(np_w1),
                ),
            ),
            dist_attr=self.shard_axis0_dist_attr,
        )

    def forward(self, x):
        y = paddle.matmul(
            dist.shard_tensor(x, dist_attr=self.replicate_dist_attr), self.w0
        )
        z = paddle.matmul(y, self.w1)
        return z

MP sharding execution diagram: (image)

Testing principle

Under dynamic-graph semi-auto parallel mode a Tensor has a global view: printing the value of any tensor should return the same result as on a single card. If the tensor is in the Shard or Replicate state when its value is fetched, communication is triggered automatically to restore the complete data.
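
A minimal sketch of how this global view can be checked against the single-card run (the names below are illustrative, not the exact unit test code; it assumes the single-card and parallel networks were built from the same numpy weights):

import numpy as np

def check_against_single_card(single_loss, dist_loss, single_w, dist_w):
    # Fetching a DistTensor value (e.g. via .numpy()) triggers the
    # communication needed to restore the full data, so it should match
    # the single-card result within numerical tolerance.
    np.testing.assert_allclose(single_loss.numpy(), dist_loss.numpy(), rtol=1e-5)
    np.testing.assert_allclose(single_w.grad.numpy(), dist_w.grad.numpy(), rtol=1e-5)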

TODO: adjusting the original design

Rewriting the demo is currently rather involved and not very user-friendly. Our original design constraint was that, under semi-auto parallel mode, every input of an API must be a DistTensor. As a result, users not only have to shard the key parameters but also have to convert every other, unsharded input from DenseTensor to a replicated DistTensor via the shard_tensor API, for example the label data, or an Optimizer's learning_rate (which the user passes as a float and cannot shard explicitly). This constraint now looks like it needs to be relaxed: inputs should be allowed to be DenseTensors, which are then automatically converted to replicated DistTensors; otherwise the usage is too cumbersome. This feature is already under development.
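
For illustration, once DenseTensor inputs are auto-promoted, the label handling in run_dynamic could shrink to something like the following (planned behavior, not what this PR implements):

# Planned-behavior sketch only: a plain DenseTensor would be accepted and
# converted to a replicated DistTensor inside the API, so no explicit
# dist.shard_tensor call is needed for the label.
label = paddle.to_tensor(self.label)
loss = loss_fn(out, label)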

Other changes

  1. TCPStore logs were emitted at a high severity level, so the repeated communication-setup logs flooded the output while debugging APIs; the log level has been lowered.
  2. Logging was added to the Reshard function to help analyze the sharding conversion behavior of APIs.

@paddle-bot bot commented Sep 19, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@chenwhql chenwhql changed the title [AutoParallel] Generate replicated spmd for PHI API [AutoParallel] Generate replicated spmd for PHI API and verify DP MP strategy Sep 21, 2023
@LiYuRio (Contributor) left a comment: LGTM

@GhostScreaming (Contributor) left a comment: LGTM

@raindrops2sea (Collaborator) left a comment: LGTM

@XieYunshen (Contributor) left a comment: LGTM for the unit test timeout setting

@chenwhql chenwhql merged commit 9796bb8 into PaddlePaddle:develop Sep 22, 2023
Frida-a pushed a commit to Frida-a/Paddle that referenced this pull request Oct 14, 2023
…strategy (PaddlePaddle#57505)

* generate forward defalut spmd

* generate bwd default spmd rule

* test relu and mse forward success

* test mse loss fwd and bwd

* updarte replicated rule name

* update single strategy test

* add unittests

* polish details

* remove useless seed

* fix dist branch test error
jiahy0825 pushed a commit to jiahy0825/Paddle that referenced this pull request Oct 16, 2023
…strategy (PaddlePaddle#57505)
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
…strategy (PaddlePaddle#57505)