
Commit

Merge branch 'infer_split_with_num' of https://github.com/liuzhenhai93/Paddle into infer_split_with_num
liuzhenhai93 committed Nov 13, 2023
2 parents 45bdb49 + 33bd39a commit 0297525
Showing 3 changed files with 20 additions and 20 deletions.
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -846,7 +846,7 @@ def generate_output_creation_code(self) -> str:
             # SetKernelDistOutput arg
             dist_output_arg = (
                 "spmd_info.second[0]"
-                if self.generate_infer_spmd
+                if self.infer_meta['spmd_rule'] is not None
                 else self.outputs['out_size_expr'][0]
             )
             output_creation_code += VECTOR_OUT_CREATION_TEMPLATE.format(
@@ -899,7 +899,7 @@ def generate_output_creation_code(self) -> str:
             else:
                 dist_output_arg = (
                     f"spmd_info.second[{i}]"
-                    if self.generate_infer_spmd
+                    if self.infer_meta['spmd_rule'] is not None
                     else self.outputs['out_size_expr'][i]
                 )
                 output_creation_code += (
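The two Python hunks change only the guard that picks the SetKernelDistOutput argument in the generated C++ API: it now keys off whether the API's infer_meta entry declares an spmd_rule, instead of the broader generate_infer_spmd flag. A standalone C++ sketch of that dispatch (mock names and inputs, not the generator's real output):

// Sketch only: mirrors the generator's ternary, with mock inputs standing in
// for self.infer_meta['spmd_rule'] and self.outputs['out_size_expr'][0].
#include <iostream>
#include <optional>
#include <string>

std::string PickDistOutputArg(const std::optional<std::string>& spmd_rule,
                              const std::string& out_size_expr) {
  // "spmd_info.second[0]" if an SPMD rule is declared, else the YAML size expr.
  return spmd_rule.has_value() ? "spmd_info.second[0]" : out_size_expr;
}

int main() {
  std::cout << PickDistOutputArg(std::string("ConcatInferSpmd"), "x.size()")
            << "\n";  // spmd_info.second[0]
  std::cout << PickDistOutputArg(std::nullopt, "x.size()") << "\n";  // x.size()
}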
8 changes: 7 additions & 1 deletion paddle/phi/infermeta/spmd_rules/concat.cc
@@ -17,6 +17,8 @@ limitations under the License. */
 #include <limits>
 #include <set>
 
+#include "glog/logging.h"
+
 #include "paddle/phi/infermeta/spmd_rules/elementwise.h"
 #include "paddle/phi/infermeta/spmd_rules/utils.h"

Expand Down Expand Up @@ -86,7 +88,11 @@ SpmdInfo ConcatInferSpmd(const std::vector<DistMetaTensor>& x, int axis) {
AlignDimsSharding(
&input_attrs, tensor_shapes, axis_names, {}, align_axis, true);

return {{input_attrs}, {input_attrs[non_empty_index]}};
auto out_dist_attr =
CopyTensorDistAttrForOutput(input_attrs[non_empty_index]);
out_dist_attr.set_dims_mapping(input_attrs[non_empty_index].dims_mapping());
VLOG(4) << "concat out " << out_dist_attr.to_string();
return {{input_attrs}, {out_dist_attr}};
}

SpmdInfo ConcatInferSpmdReverse(const std::vector<DistMetaTensor>& x,
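Note on the concat.cc hunk: the rule now builds the output attr with CopyTensorDistAttrForOutput and copies the dims mapping over explicitly, rather than returning input_attrs[non_empty_index] as-is. A minimal mock of that pattern, under the assumption (not verified here) that the helper keeps the process mesh and resets per-tensor state such as partial status:

// Mock types, not phi: shows why aliasing the input attr as the output attr
// could leak input-only state, and why dims_mapping is re-set explicitly.
#include <cstdint>
#include <iostream>
#include <vector>

struct DistAttr {
  int mesh_id = 0;                    // stands in for the process mesh
  std::vector<int64_t> dims_mapping;  // tensor dim -> mesh dim, -1 = replicated
  bool partial = false;               // input-only state in this mock
};

// Assumed contract of CopyTensorDistAttrForOutput: keep the mesh, drop the rest.
DistAttr CopyForOutput(const DistAttr& in) {
  DistAttr out;
  out.mesh_id = in.mesh_id;
  return out;
}

int main() {
  DistAttr in{/*mesh_id=*/7, /*dims_mapping=*/{0, -1}, /*partial=*/true};
  DistAttr out = CopyForOutput(in);
  out.dims_mapping = in.dims_mapping;  // set explicitly, as in the commit
  std::cout << out.mesh_id << " " << out.partial << "\n";  // prints: 7 0
}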
28 changes: 11 additions & 17 deletions paddle/phi/infermeta/spmd_rules/utils.cc
@@ -274,33 +274,27 @@ void AlignDimsSharding(std::vector<TensorDistAttr>* input_attrs_ptr,
 
   const auto& process_mess = input_attrs[non_empty_index].process_mesh();
   auto has_mismatch = [&](int32_t mesh_dim) {
-    bool mismatch = false;
     for (size_t i = 0; i < n_inputs; i++) {
       if (IsEmpty(tensor_shapes[i])) {
         continue;
       }
       auto& p_a = inputs_placements[non_empty_index][mesh_dim];
       auto& p_b = inputs_placements[i][mesh_dim];
-      if (!p_a->is_shard()) {
-        if (!PlacementEqual(p_a, p_b)) {
-          mismatch = true;
-          break;
+      if (p_a->is_shard() && p_b->is_shard()) {
+        auto a_shard = std::dynamic_pointer_cast<ShardStatus>(p_a);
+        auto b_shard = std::dynamic_pointer_cast<ShardStatus>(p_b);
+        auto a_axis = axis_names[non_empty_index][a_shard->get_axis()];
+        auto b_axis = axis_names[i][b_shard->get_axis()];
+        if (a_axis != b_axis) {
+          return true;
         }
       }
-      if (!p_b->is_shard()) {
-        mismatch = true;
-        break;
-      }
-      auto a_shard = std::dynamic_pointer_cast<ShardStatus>(p_a);
-      auto b_shard = std::dynamic_pointer_cast<ShardStatus>(p_b);
-      auto a_axis = axis_names[non_empty_index][a_shard->get_axis()];
-      auto b_axis = axis_names[i][b_shard->get_axis()];
-      if (a_axis != b_axis) {
-        mismatch = true;
-        break;
+
+      if (!PlacementEqual(p_a, p_b)) {
+        return true;
       }
     }
-    return mismatch;
+    return false;
   };
 
   // a dim can not be sharded twice along diffrent mesh_dim
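The utils.cc rewrite keeps the intent of has_mismatch but replaces the mismatch flag and break with early returns, and checks the shard/shard case by mesh-axis name before falling through to the generic PlacementEqual comparison. A standalone mock of the new control flow (plain structs, not phi's placement types):

// Mock of the rewritten lambda: return true on the first disagreement.
#include <iostream>
#include <string>
#include <vector>

enum class Kind { kReplicate, kShard };

struct Placement {
  Kind kind = Kind::kReplicate;
  int axis = -1;  // sharded tensor axis when kind == kShard
};

bool PlacementEqual(const Placement& a, const Placement& b) {
  return a.kind == b.kind && a.axis == b.axis;
}

// names[i][axis] plays the role of axis_names in AlignDimsSharding.
bool HasMismatch(const std::vector<Placement>& ps,
                 const std::vector<std::string>& names,
                 size_t ref) {
  for (size_t i = 0; i < ps.size(); i++) {
    const Placement& p_a = ps[ref];
    const Placement& p_b = ps[i];
    if (p_a.kind == Kind::kShard && p_b.kind == Kind::kShard) {
      // Shard/shard: compare the *named* mesh axis first, as in the commit.
      if (names[ref][p_a.axis] != names[i][p_b.axis]) {
        return true;  // early exit replaces mismatch = true; break;
      }
    }
    if (!PlacementEqual(p_a, p_b)) {
      return true;
    }
  }
  return false;
}

int main() {
  // Reference input shards axis 0 while input 1 replicates: mismatch.
  std::vector<Placement> ps{{Kind::kShard, 0}, {Kind::kReplicate, -1}};
  std::vector<std::string> names{"ij", "ij"};
  std::cout << std::boolalpha << HasMismatch(ps, names, 0) << "\n";  // true
}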
