[Semi-auto] add elementwise spmd rule for auto parallel #54373
Conversation
Your PR has been submitted successfully. Thank you for your contribution to this open source project!
op_desc = dist_op.serial_op.desc
input_name_list = []
output_name_list = []
input_name_list.append(op_desc.input('X')[0])  # 'X' is the arg name for op
It would be better if the wrapper took the op as its only argument, so that the user does not need to bother with constructing input_name_list/output_name_list. To achieve that, the wrapper needs to maintain the order of the op's argument slots from the Phi API.
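A minimal sketch of what such a wrapper could look like. The accessors `input_names()`/`output_names()` and their ordering relative to the Phi API are assumptions here, not a confirmed Paddle contract:

```python
def build_arg_name_lists(dist_op):
    """Hypothetical wrapper: take the op as the only argument and derive
    the input/output name lists from its desc, so callers do not build
    input_name_list/output_name_list by hand."""
    op_desc = dist_op.serial_op.desc
    # Assumption: these accessors return the argument slots in the same
    # order as the Phi API signature.
    input_name_list = [op_desc.input(slot)[0] for slot in op_desc.input_names()]
    output_name_list = [op_desc.output(slot)[0] for slot in op_desc.output_names()]
    return input_name_list, output_name_list
```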
# Construct each input tensor's DistTensorSpec with shape and dist_attr
for name in input_names:
    tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(name)
In static mode, the dist attr of the op is the destination dist attr; here the source dist attr, which is held by the tensor, should be used instead.
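A self-contained toy illustrating the distinction, with made-up types standing in for Paddle's (the real accessors differ):

```python
from dataclasses import dataclass

@dataclass
class ToyDistAttr:
    dims_mapping: list  # mesh dim per tensor axis, -1 = not sharded

@dataclass
class ToyTensor:
    shape: tuple
    dist_attr: ToyDistAttr  # source dist attr, held by the tensor

@dataclass
class ToyDistOp:
    inputs: dict             # name -> ToyTensor
    input_dist_attrs: dict   # name -> ToyDistAttr (destination, held by the op)

def build_input_specs(dist_op):
    specs = []
    for name, tensor in dist_op.inputs.items():
        # Read the *source* dist attr from the tensor itself, not the
        # op-held destination attr in dist_op.input_dist_attrs[name].
        specs.append((tensor.shape, tensor.dist_attr.dims_mapping))
    return specs

op = ToyDistOp(
    inputs={"X": ToyTensor((8, 4), ToyDistAttr([0, -1]))},
    input_dist_attrs={"X": ToyDistAttr([-1, -1])},  # destination may differ
)
print(build_input_specs(op))  # [((8, 4), [0, -1])]
```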
Sorry to inform you that 325a998's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Force-pushed from 61cb076 to b97a6a1.
@@ -39,7 +39,7 @@ SPMDRuleBase::InferBackward(const std::vector<DistTensorSpec>& output_specs,
}

std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
    const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
Why remove the "const" if you aren't modifying the input argument?
The "const" qualifying the vector already forbids modification of the input argument, so there is no need to add "const" inside the "pair".
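For context, a runnable toy of what a merge like ShardingMergeForTensors computes; the conflict policy here (first non-`-1` sharding wins) is a simplification of whatever Paddle actually does:

```python
def sharding_merge_for_tensors(pairs):
    """pairs: list of (axes_notation, dims_mapping), e.g.
    [("ijk", [0, -1, -1]), ("jk", [1, -1])].  Returns one merged mesh dim
    per einsum axis; -1 means "not sharded"."""
    merged = {}
    for axes, dims_mapping in pairs:
        for axis, mesh_dim in zip(axes, dims_mapping):
            # Toy policy: the first non-(-1) sharding seen for an axis wins.
            if merged.get(axis, -1) == -1:
                merged[axis] = mesh_dim
    return merged

print(sharding_merge_for_tensors([("ijk", [0, -1, -1]), ("jk", [1, -1])]))
# {'i': 0, 'j': 1, 'k': -1}
```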
@@ -24,6 +25,7 @@ namespace auto_parallel {

// matmul rule
REGISTER_SPMD_RULE(matmul, MatmulSPMDRule);
REGISTER_SPMD_RULE(elementwise, ElementwiseSPMDRule);
The register name should be the op name (e.g. elementwise_add, elementwise_div, elementwise_max) to show the mapping from op_name to spmd_rule explicitly.
Done
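A toy Python analogue of registering one rule per concrete op name (REGISTER_SPMD_RULE itself is a C++ macro; this dict-based registry is only an illustration):

```python
# op_name -> rule instance, so the mapping is explicit and queryable.
SPMD_RULE_REGISTRY = {}

def register_spmd_rule(op_name, rule):
    SPMD_RULE_REGISTRY[op_name] = rule

class ElementwiseSPMDRule:
    """Stand-in for the real C++ rule class."""

# One rule class shared by many concrete elementwise ops.
for op_name in ["elementwise_add", "elementwise_div", "elementwise_max"]:
    register_spmd_rule(op_name, ElementwiseSPMDRule())

print(sorted(SPMD_RULE_REGISTRY))
# ['elementwise_add', 'elementwise_div', 'elementwise_max']
```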
// Get the dims mapping for the given tensors. Return the pair of each
// tensor's einsum notation and the corresponding dims mapping.
std::vector<std::pair<std::string, std::vector<int64_t>>> GetAxesShardingInfo(
"GetAxesShardingInfo" is kind of Ambiguous, what about something like "GetAxesMappingsPair"
Modified to "GetAxesDimsMappingPair".
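A runnable toy of what a helper like GetAxesDimsMappingPair returns, assuming numpy-style right-aligned broadcasting (the real C++ helper operates on DistTensorSpec):

```python
def get_axes_dims_mapping_pair(specs, alphabet="abcdefghijklmnopqrstuvwxyz"):
    """specs: list of (shape, dims_mapping).  For each tensor, return the
    pair (einsum_notation, dims_mapping), with axes right-aligned against
    the highest-rank input."""
    max_ndim = max(len(shape) for shape, _ in specs)
    pairs = []
    for shape, dims_mapping in specs:
        start = max_ndim - len(shape)
        pairs.append((alphabet[start:max_ndim], list(dims_mapping)))
    return pairs

# A rank-3 and a rank-1 input yield notations "abc" and "c".
print(get_axes_dims_mapping_pair([((8, 4, 2), [0, -1, 1]), ((2,), [-1])]))
# [('abc', [0, -1, 1]), ('c', [-1])]
```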
// step2.4: handle partial
// Step2.3.1 Output Partial
std::vector<int64_t> partial_on_dims =
Elementwise logic would not generate partial, but it is ok here.
Deleted.
// Step2.3.2 handle input tensor partial (TODO)
VLOG(4) << "ElementwiseSPMDRule InferForward: "
        << " Output dims_mapping: [" << str_join(output_dims_mapping)
The input tensor might be resharded, so the src_dims_mapping and dst_dims_mapping of each input tensor should also be logged.
    const std::vector<DistTensorSpec>& output_specs,
    const paddle::framework::AttributeMap& attrs) {
  PADDLE_THROW(phi::errors::Unimplemented(
      "InferBackward of MatmulSPMDRule is NOT implemented yet."));
MatmulSPMDRule --> ElementwiseSPMDRule
Done.
} else if (shape[idim - start_dim] == 1) {
  broadcast_axis_count[idim] += 1;
  // mark the broadcast axis to a special "1"
  axes_notation[idim - start_dim] = '1';
"1" concept is not need for sharding merge.
since we assume that the income mapping is correct for spmd rule, if any tensor axis's size is "1", and the dim_mapping for this axis should be "-1" correspondingly. and if all inputs' dim_mappings are "-1", it merge to "-1" with no doubt.
but it is ok here.
Here "1" is a special label for a broadcast dim; with this label the broadcasting case can be handled by the same function as the common case.
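A runnable toy of the "1" trick described above; the names and right-alignment convention are assumptions drawn from the snippet:

```python
def build_axes_notation(shape, max_ndim, alphabet="abcdefghijklmnopqrstuvwxyz"):
    """Right-align the tensor's axes against the broadcast rank and relabel
    size-1 (broadcast) axes with the special '1', so the broadcasting case
    can flow through the same sharding-merge path as the common case."""
    start_dim = max_ndim - len(shape)
    axes = list(alphabet[start_dim:max_ndim])
    for idim in range(start_dim, max_ndim):
        if shape[idim - start_dim] == 1:
            axes[idim - start_dim] = "1"  # broadcast axis: never sharded
    return "".join(axes)

# Shapes (8, 1, 4) and (4,) broadcasting at rank 3:
print(build_axes_notation((8, 1, 4), 3))  # a1c
print(build_axes_notation((4,), 3))       # c
```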
cc_test_old(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rule)
cc_test_old(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rules)

cc_test_old(elementwise_spmd_rule_test SRCS ./elementwise_spmd_rule_test.cc
Where is elementwise_spmd_rule_test?
Added elementwise_spmd_rule_test.cc back; the test cases in elementwise_spmd_rule_test.cc are fewer than those in test_elementwise_rule.py.
Sorry, what I meant is: remove this line rather than adding elementwise_spmd_rule_test.cc back. The Python unit test is preferred.
Done
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1])

# [0, 1, -1], [0] --> [0, 1, -1], [-1], [0, 1, -1]
The conflict-fixing logic might change in the future, but it is ok here.
Force-pushed from 5aecc33 to ffec781.
// step2.4: handle partial
// Step2.3.2 handle input tensor partial (TODO)
std::string log_str =
Please add more detailed info to the log, to help debugging.
Done
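A toy sketch of the kind of detail the log now carries: each input's src and dst dims_mapping (inputs may be resharded) plus the output's. The exact format in the PR may differ:

```python
def format_infer_log(op_type, input_mappings, output_dims_mapping):
    """input_mappings: list of (src_dims_mapping, dst_dims_mapping) per input."""
    parts = [f"{op_type}SPMDRule InferForward:"]
    for i, (src, dst) in enumerate(input_mappings):
        parts.append(f" input{i} src: {src} dst: {dst};")
    parts.append(f" output dims_mapping: {output_dims_mapping}")
    return "".join(parts)

print(format_infer_log("Elementwise",
                       [([0, 1, -1], [0, 1, -1]), ([0], [-1])],
                       [0, 1, -1]))
```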
LGTM
…#54373)
* add some basic functions
* add elementwise rule for auto parallel
* add unit test for elementwise rule
* fix the lib name in spmd rule test cmake file
* fix some bugs
* add unit tests for elementwise spmd rule in python
* bug fix
* delete cpp unit test for elementwise spmd rule (use python ut now)
* add cpp unit test for elementwise rule
* use concrete op name in unit test
* fix typo
* fix code style
* delete cpp unit test
* add more details in log
PR types
New features

PR changes
Others

Description
Pcard-70448
Add the elementwise ops' spmd rule for inferring distributed attributes. Implement the InferForward function for elementwise ops, i.e. infer the output tensor's distributed attributes from the input tensors'.
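To make the inference concrete, here is a self-contained toy of an elementwise InferForward. The conflict policy (a mesh dim shards at most one output axis; first claim wins) is a simplification, not Paddle's exact algorithm, though it does reproduce the unit-test case quoted above:

```python
def elementwise_infer_forward(input_specs,
                              alphabet="abcdefghijklmnopqrstuvwxyz"):
    """input_specs: list of (shape, dims_mapping).  Returns the inferred
    dims_mapping for each input (after any reshard) and for the output."""
    max_ndim = max(len(shape) for shape, _ in input_specs)

    # Step 1: einsum-like notation per input, right-aligned for numpy-style
    # broadcasting; size-1 axes get the special '1' label.
    notations = []
    for shape, _ in input_specs:
        start = max_ndim - len(shape)
        notations.append("".join(
            "1" if d == 1 else alphabet[start + i]
            for i, d in enumerate(shape)))

    # Step 2: merge shardings per axis; '1' axes are never sharded, and a
    # mesh dim may shard at most one axis (toy conflict policy: first wins).
    merged, used = {}, set()
    for notation, (_, dims_mapping) in zip(notations, input_specs):
        for axis, mesh_dim in zip(notation, dims_mapping):
            if axis == "1" or mesh_dim == -1:
                continue
            if merged.get(axis, -1) == -1 and mesh_dim not in used:
                merged[axis] = mesh_dim
                used.add(mesh_dim)

    # Step 3: the output covers all broadcast axes; inputs whose mapping
    # disagrees with the merged result would be resharded to it.
    output = [merged.get(alphabet[i], -1) for i in range(max_ndim)]
    inferred_inputs = [
        [-1 if a == "1" else merged.get(a, -1) for a in notation]
        for notation in notations]
    return inferred_inputs, output

# The test case above: [0, 1, -1], [0] --> [0, 1, -1], [-1], [0, 1, -1]
ins, out = elementwise_infer_forward([((8, 4, 2), [0, 1, -1]), ((2,), [0])])
print(ins, out)  # [[0, 1, -1], [-1]] [0, 1, -1]
```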