[Semi-auto] add elementwise spmd rule for auto parallel #54373

Merged
merged 14 commits into PaddlePaddle:develop from elementwise_rule on Jul 7, 2023

Conversation

pkuzyc
Contributor

@pkuzyc pkuzyc commented Jun 6, 2023

PR types

New features

PR changes

Others

Description

Pcard-70448
Add the elementwise ops' SPMD rule for inferring distributed attributes. This PR implements the InferForward function for elementwise ops, i.e. it infers the output tensor's distributed attributes from the input tensors' distributed attributes.
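
For context, a minimal plain-Python sketch of what the forward inference does for elementwise ops (illustrative only: `infer_elementwise_forward` is a made-up name and the conflict handling is simplified compared with the real C++ rule):

```python
def infer_elementwise_forward(input_dims_mappings):
    max_ndim = max(len(m) for m in input_dims_mappings)
    # Right-align every mapping to the broadcasted rank; padded axes are replicated (-1).
    aligned = [[-1] * (max_ndim - len(m)) + list(m) for m in input_dims_mappings]

    output_dims_mapping = []
    used_mesh_dims = set()
    for axis in range(max_ndim):
        # Merge the shardings proposed for this aligned axis; -1 means "replicated".
        candidates = {m[axis] for m in aligned if m[axis] != -1}
        merged = candidates.pop() if len(candidates) == 1 else -1
        # One mesh dimension may shard at most one tensor axis.
        if merged in used_mesh_dims:
            merged = -1
        if merged != -1:
            used_mesh_dims.add(merged)
        output_dims_mapping.append(merged)

    # Inputs whose mapping disagrees with the merged result would be resharded.
    inferred_input_dims_mappings = [
        output_dims_mapping[max_ndim - len(m):] for m in input_dims_mappings
    ]
    return inferred_input_dims_mappings, output_dims_mapping


# Same case as in the unit test below: [0, 1, -1] and [0] merge to output [0, 1, -1],
# while the second input is resharded to [-1] because mesh dim 0 is already used.
print(infer_elementwise_forward([[0, 1, -1], [0]]))
```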

@paddle-bot

paddle-bot bot commented Jun 6, 2023

Your PR has been submitted successfully. Thanks for your contribution to the open source project!
Please wait for the CI results first. See the Paddle CI Manual for details.

op_desc = dist_op.serial_op.desc
input_name_list = []
output_name_list = []
input_name_list.append(op_desc.input('X')[0]) # 'X' is the arg name for op
Contributor

It would be better if the wrapper could take the op as its only argument, so the user does not need to bother with constructing input_name_list/output_name_list.

To achieve that, the wrapper needs to maintain the order of the op's argument slots from the Phi API.
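
A possible shape for such a wrapper (an illustrative sketch, not existing code; it assumes `op_desc.input_names()`/`output_names()` return slot names in the Phi API order, which is exactly the ordering guarantee mentioned above):

```python
def build_io_arg_name_lists(dist_op):
    # Callers pass only the dist_op; the slot names are read from the op desc.
    # Only the first argument of each slot is taken, mirroring the snippet above.
    op_desc = dist_op.serial_op.desc
    input_name_list = [op_desc.input(name)[0] for name in op_desc.input_names()]
    output_name_list = [op_desc.output(name)[0] for name in op_desc.output_names()]
    return input_name_list, output_name_list
```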


# Construct each input tensor's DistTensorSpec with shape and dist_attr
for name in input_names:
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(name)
Contributor

In static mode, the dist attr of the op is the destination dist attr.
The source dist attr, which is held by the tensor, should be used here.

@paddle-ci-bot

paddle-ci-bot bot commented Jun 14, 2023

Sorry to inform you that the CIs for 325a998 passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

@pkuzyc pkuzyc force-pushed the elementwise_rule branch 2 times, most recently from 61cb076 to b97a6a1 Compare June 29, 2023 13:59
@@ -39,7 +39,7 @@ SPMDRuleBase::InferBackward(const std::vector<DistTensorSpec>& output_specs,
}

std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
Contributor

Why remove the "const" if you don't modify the input argument?

Contributor Author

The "const" decorating vector forbids the modification of the input argument, so no need to add "const" inside the "pair".

@@ -24,6 +25,7 @@ namespace auto_parallel {

// matmul rule
REGISTER_SPMD_RULE(matmul, MatmulSPMDRule);
REGISTER_SPMD_RULE(elementwise, ElementwiseSPMDRule);
Contributor

The registered name should be the op_name, e.g. elementwise_add, elementwise_div, elementwise_max, etc., to show the mapping from op_name to spmd_rule explicitly.

Contributor Author

Done


// Get dimsmapping for the given tensors. Return the pair of each
// tensor's einsum notation and the corresponding dimsmapping.
std::vector<std::pair<std::string, std::vector<int64_t>>> GetAxesShardingInfo(
Contributor

"GetAxesShardingInfo" is somewhat ambiguous; what about something like "GetAxesMappingsPair"?

Contributor Author

Modified to "GetAxesDimsMappingPair".


// step2.4: handle partial
// Step2.3.1 Output Partial
std::vector<int64_t> partial_on_dims =
Contributor

Elementwise logic would not generate partial, but it is OK here.

Contributor Author

Deleted.


// Step2.3.2 handle input tensor partial (TODO)
VLOG(4) << "ElementwiseSPMDRule InferForward: "
<< " Output dims_mapping: [" << str_join(output_dims_mapping)
Contributor

The input tensors might be resharded, so the src_dims_mapping and dst_dims_mapping of the input tensors should also be logged.

const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of MatmulSPMDRule is NOT implemented yet."));
Contributor

MatmulSPMDRule --> ElementwiseSPMDRule

Contributor Author

Done.

} else if (shape[idim - start_dim] == 1) {
broadcast_axis_count[idim] += 1;
// mark the broadcast axis to a special "1"
axes_notation[idim - start_dim] = '1';
Contributor

The "1" concept is not needed for the sharding merge.

Since we assume the incoming mapping is correct for the SPMD rule, any tensor axis whose size is 1 should correspondingly have a dim_mapping of -1; and if all inputs' dim_mappings for an axis are -1, they merge to -1 without doubt.

But it is OK here.

Contributor Author

Here "1" is a special label for broadcasting dim, with this label broadcasting case can handled with the same function as common cases.

cc_test_old(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rule)
cc_test_old(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rules)

cc_test_old(elementwise_spmd_rule_test SRCS ./elementwise_spmd_rule_test.cc
Contributor

Where is elementwise_spmd_rule_test?

Contributor Author

Added elementwise_spmd_rule_test.cc back; it has fewer test cases than test_elementwise_rule.py.

Contributor

Sorry, what I meant is to remove this line rather than bring elementwise_spmd_rule_test.cc back.

The Python unit test is preferred.

Contributor Author

Done

self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1])

# [0, 1, -1], [0] --> [0, 1, -1], [-1], [0, 1, -1]
Contributor

The conflict-fixing logic might change in the future, but it is OK here.

@pkuzyc pkuzyc force-pushed the elementwise_rule branch 2 times, most recently from 5aecc33 to ffec781 Compare July 4, 2023 12:02
@pkuzyc pkuzyc force-pushed the elementwise_rule branch from ffec781 to 2ff031c Compare July 6, 2023 06:39

// step2.4: handle partial
// Step2.3.2 handle input tensor partial (TODO)
std::string log_str =
Contributor

Add more detailed info to the log to help debugging.

Contributor Author

Done

Contributor

@JZ-LIANG JZ-LIANG left a comment

LGTM

@JZ-LIANG JZ-LIANG merged commit 8e5b0af into PaddlePaddle:develop Jul 7, 2023
@pkuzyc pkuzyc deleted the elementwise_rule branch July 12, 2023 03:55
cqulilujia pushed a commit to cqulilujia/Paddle that referenced this pull request Jul 24, 2023
…#54373)

* add some basic functions

* add elementwise rule for auto parallel

* add unit test for elementwise rule

* fix the lib name in spmd rule test cmake file

* fix some bugs

* add unit tests for elementwise spmd rule in python

* bug fix

* delete cpp unit test for elementwise spmd rule (use python ut now)

* add cpp unit test for elementwise rule

* use concrete op name in unit test

* fix typo

* fix code style

* delete cpp unit test

* add more details in log