[AutoParallel] Generate replicated spmd for PHI API and verify DP MP strategy #57505
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open source project!
… ap/generate_replicated_spmd
LGTM
LGTM
LGTM
LGTM
Unit test timeout setting
…strategy (PaddlePaddle#57505) * generate forward default spmd * generate bwd default spmd rule * test relu and mse forward success * test mse loss fwd and bwd * update replicated rule name * update single strategy test * add unittests * polish details * remove useless seed * fix dist branch test error
PR types
New features
PR changes
Others
Description
Pcard-73145
[AutoParallel] Generate replicated spmd for PHI API and verify DP MP strategy
This PR generates the generic SPMD (sharding) inference rule and its conversion logic for every PHI API whose inputs and outputs are all Tensors. The generic rule converts all of an API's inputs to the Replicate state before the kernel runs, which is equivalent to every rank independently performing the complete computation.
Although the performance of this rule is poor, once it is generated a substantial portion of APIs can exercise the basic execution flow of dynamic-graph semi-auto parallel. At present matmul is the only operator with a dedicated SPMD rule (and its backward rule is still incomplete); adding dedicated forward/backward rules for the remaining operators will be a relatively long, operator-by-operator effort. The generic rule ensures that dynamic-graph semi-auto parallel execution does not fail outright simply because an SPMD rule is missing.
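As a rough intuition for what the generic rule does, here is a minimal NumPy sketch (illustrative only; the real logic is generated into the PHI C++ API, and the helper names below are made up): each rank first restores the full input, then runs the complete kernel locally.

```python
import numpy as np

def to_replicate(row_shards):
    """Rebuild the full tensor from per-rank row shards (conceptually an all-gather)."""
    return np.concatenate(row_shards, axis=0)

def replicated_rule_relu(x_row_shards):
    x_full = to_replicate(x_row_shards)   # every rank first restores the full input
    return np.maximum(x_full, 0.0)        # then runs the complete kernel locally

x = np.random.randn(4, 3)
shards = [x[:2], x[2:]]                   # two "ranks", each holding half of the rows
out_on_rank0 = replicated_rule_relu(shards)
out_on_rank1 = replicated_rule_relu(shards)
assert np.allclose(out_on_rank0, np.maximum(x, 0.0))   # matches the single-card result
```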
Given the above, this PR uses a simple demo network to verify the correctness of the standalone DP and MP strategies under the dynamic-graph semi-auto parallel architecture.
Demo network
It contains only two matmuls plus one mse loss (composed of the subtract, square, and mean operators), and runs only the forward and backward passes.
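A minimal single-card sketch of this demo network, with illustrative shapes and names (not taken from the actual unit test):

```python
import paddle
import paddle.nn.functional as F

batch, d_in, d_hidden, d_out = 8, 16, 32, 4

x = paddle.randn([batch, d_in])
w0 = paddle.randn([d_in, d_hidden])
w1 = paddle.randn([d_hidden, d_out])
label = paddle.randn([batch, d_out])
w0.stop_gradient = False
w1.stop_gradient = False

h = paddle.matmul(x, w0)        # first matmul
out = paddle.matmul(h, w1)      # second matmul
loss = F.mse_loss(out, label)   # subtract -> square -> mean
loss.backward()                 # backward pass only, no optimizer step
```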
DP demo rewrite
Diagram of DP-sharded execution:
![image](https://private-user-images.githubusercontent.com/22561442/269464817-9536a422-3c18-47a6-bd97-9e18d96f951a.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk1ODQxMDMsIm5iZiI6MTczOTU4MzgwMywicGF0aCI6Ii8yMjU2MTQ0Mi8yNjk0NjQ4MTctOTUzNmE0MjItM2MxOC00N2E2LWJkOTctOWUxOGQ5NmY5NTFhLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTUlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE1VDAxNDMyM1omWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTY0NzVhNTUwNDQ3MzhmYTk5OTNiMWI5NzM5ZTJiMzhjYjdjZWI1NjQzNGRiYmJmMjhiOTUyOTRmY2Y1MmNlNmYmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.Zbq2kxSwnwaKtfPVPA32dTvpIf8ekZ8sLBVHTh_r2sw)
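A hedged sketch of what the DP rewrite looks like in user code, assuming the placements-style `shard_tensor` API (`dist.Shard` / `dist.Replicate`); the exact API spelling may differ between Paddle versions and from the unit test added in this PR:

```python
import paddle
import paddle.distributed as dist
import paddle.nn.functional as F

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

# DP: shard the input and label along the batch dimension; keep the weights replicated.
x = dist.shard_tensor(paddle.randn([8, 16]), mesh, [dist.Shard(0)])
label = dist.shard_tensor(paddle.randn([8, 4]), mesh, [dist.Shard(0)])
w0 = dist.shard_tensor(paddle.randn([16, 32]), mesh, [dist.Replicate()])
w1 = dist.shard_tensor(paddle.randn([32, 4]), mesh, [dist.Replicate()])
w0.stop_gradient = False
w1.stop_gradient = False

loss = F.mse_loss(paddle.matmul(paddle.matmul(x, w0), w1), label)
loss.backward()
```

Each rank holds a different batch slice, so the two matmuls run data-parallel while the ops without dedicated rules fall back to the generic replicated rule; the script would be launched on two devices via `paddle.distributed.launch`.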
MP demo rewrite
Diagram of MP-sharded execution:
![image](https://private-user-images.githubusercontent.com/22561442/269466449-43697039-3f0d-4a1d-89b6-45dee2256ac5.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk1ODQxMDMsIm5iZiI6MTczOTU4MzgwMywicGF0aCI6Ii8yMjU2MTQ0Mi8yNjk0NjY0NDktNDM2OTcwMzktM2YwZC00YTFkLTg5YjYtNDVkZWUyMjU2YWM1LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTUlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE1VDAxNDMyM1omWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTFhMTEwMjE1YjQ0NzI4ZGMxOWE2MzRkYzRjMjA1OTIwMjg3NzZjMWY5YTE1NzAzZDk5N2FiMWIzMzliMzZhM2UmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.vo88-jSRBSgzPPIOicvmPklXQZykEga5yySHYByc-mI)
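The corresponding MP rewrite, with the same caveats as the DP sketch: the first matmul is column-parallel, the second is row-parallel, and the input and label stay replicated.

```python
import paddle
import paddle.distributed as dist
import paddle.nn.functional as F

mesh = dist.ProcessMesh([0, 1], dim_names=["mp"])

x = dist.shard_tensor(paddle.randn([8, 16]), mesh, [dist.Replicate()])
label = dist.shard_tensor(paddle.randn([8, 4]), mesh, [dist.Replicate()])
w0 = dist.shard_tensor(paddle.randn([16, 32]), mesh, [dist.Shard(1)])  # split columns
w1 = dist.shard_tensor(paddle.randn([32, 4]), mesh, [dist.Shard(0)])   # split rows
w0.stop_gradient = False
w1.stop_gradient = False

loss = F.mse_loss(paddle.matmul(paddle.matmul(x, w0), w1), label)
loss.backward()
```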
Testing principle
In dynamic-graph semi-auto parallel mode a tensor has a global view: printing any tensor's value should yield the same result as a single-card run. If the tensor is in the Shard or Replicate state when its value is read, communication is triggered automatically to reconstruct the full data.
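In test terms this means the distributed run can be compared value-for-value against a plain single-card run; `dist_loss` and `single_card_loss` below are hypothetical names for the two results:

```python
import numpy as np

# Reading a DistTensor's value triggers whatever communication is needed to
# rebuild the full (global-view) data, so it should match the single-card run.
np.testing.assert_allclose(
    dist_loss.numpy(),          # result from the distributed (DP or MP) run
    single_card_loss.numpy(),   # reference result from the single-card run
    rtol=1e-5, atol=1e-6,
)
```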
TODO: adjustment to the original design
The demo rewrite is currently rather complex and not user-friendly. Our original design constraint was that in dynamic-graph semi-auto parallel mode every API input must be a DistTensor. As a result, users not only have to shard the key parameters but also have to convert every input they do not intend to shard from DenseTensor to a Replicate DistTensor via the shard_tensor API, e.g. the label data, or the optimizer's learning_rate (which the user passes as a float and therefore cannot shard explicitly). This constraint now appears to need adjustment: inputs should be allowed to be DenseTensors and be converted to Replicate DistTensors automatically, otherwise the usage is too cumbersome. This feature is already under development.
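For illustration, a hypothetical DP example once that relaxation lands (not yet supported by this PR): only the sharded input needs `shard_tensor`, while the label and weights stay plain DenseTensors and would be promoted to Replicate DistTensors automatically.

```python
import paddle
import paddle.distributed as dist
import paddle.nn.functional as F

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

x = dist.shard_tensor(paddle.randn([8, 16]), mesh, [dist.Shard(0)])
w0, w1 = paddle.randn([16, 32]), paddle.randn([32, 4])  # stay DenseTensors
label = paddle.randn([8, 4])                             # stays a DenseTensor

# Hypothetical behavior: DenseTensor inputs are auto-converted to Replicate DistTensors.
loss = F.mse_loss(paddle.matmul(paddle.matmul(x, w0), w1), label)
```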
Notes on other changes