Skip to content

Commit

Permalink
follow comment
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzhenhai93 committed Nov 13, 2023
1 parent 0297525 commit e13ea0b
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 10 deletions.
3 changes: 2 additions & 1 deletion paddle/phi/api/lib/api_gen_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,8 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
phi::errors::PreconditionNotMet(
"Arg must be a vector of TensorDistAttr"));
const std::vector<phi::distributed::TensorDistAttr>& dist_attrs =
paddle::get<1>(dist_attr);
PADDLE_GET_CONST(std::vector<phi::distributed::TensorDistAttr>,
dist_attr);
auto out_size = dist_attrs.size();
out->reserve(out_size);
std::vector<phi::distributed::DistTensor*> results(out_size);
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/api/lib/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,8 @@ ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
true,
phi::errors::PreconditionNotMet(
"Arg must be a vector of TensorDistAttr"));
const auto& tensor_dist_attrs = paddle::get<1>(dist_attrs);
const auto& tensor_dist_attrs = PADDLE_GET_CONST(
std::vector<phi::distributed::TensorDistAttr>, dist_attrs);

PADDLE_ENFORCE_EQ(tensors.size(),
tensor_dist_attrs.size(),
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/generator/dist_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@
"""

MULTI_VECTOR_OUT_CREATION_TEMPLATE = """
auto dist_out_{out_name} = SetKernelDistOutput({size}, {in_name});
auto dist_out_{out_name} = SetKernelDistOutput({dist_output_arg}, {in_name});
std::vector<phi::DenseTensor*> dense_out_{out_name}(dist_out_{out_name}.size());
for (size_t i = 0; i < dist_out_{out_name}.size(); ++i) {{
dense_out_{out_name}[i] = const_cast<phi::DenseTensor*>(&dist_out_{out_name}[i]->value());
Expand Down Expand Up @@ -905,7 +905,7 @@ def generate_output_creation_code(self) -> str:
output_creation_code += (
MULTI_VECTOR_OUT_CREATION_TEMPLATE.format(
out_name=i,
size=dist_output_arg,
dist_output_arg=dist_output_arg,
in_name=get_out_code,
)
)
Expand Down
13 changes: 7 additions & 6 deletions paddle/phi/infermeta/spmd_rules/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ using phi::distributed::auto_parallel::str_join;

std::tuple<std::string, std::string> FillConcatNotation(int64_t n_axis,
int64_t concat_axis) {
PADDLE_ENFORCE_EQ(
n_axis > concat_axis,
true,
PADDLE_ENFORCE_GT(
n_axis,
concat_axis,
phi::errors::InvalidArgument(
"n_axis [%d] and concat_axis[%d]", n_axis, concat_axis));
static const std::string alphabet = "abcdefghijlopqrstuvwxyz";
PADDLE_ENFORCE_EQ(alphabet.size() > static_cast<size_t>(n_axis),
true,
phi::errors::InvalidArgument("n_axis %d", n_axis));
PADDLE_ENFORCE_GT(
alphabet.size(),
static_cast<size_t>(n_axis),
phi::errors::InvalidArgument("n_axis [%d] is too large", n_axis));
std::string all_axis = alphabet.substr(0, n_axis);
std::string align_axis =
std::string(all_axis.begin(), all_axis.begin() + concat_axis) +
Expand Down

0 comments on commit e13ea0b

Please sign in to comment.