Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Auto Parallel]: Support optional, inplace input and output for DistTensor. #57092

Merged
merged 19 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
4edd727
[WIP] Support std::vector<phi::Tensor> input and output for DistTensor.
GhostScreaming Aug 23, 2023
e8f28a8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
GhostScreaming Aug 24, 2023
fed53d3
Polish code for new dist tensor implementation.
GhostScreaming Aug 28, 2023
651c205
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
GhostScreaming Aug 31, 2023
d4a1653
Fix bug of DistTensor upgrade. Add support functions for std::vector<…
GhostScreaming Aug 31, 2023
60b0d50
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
GhostScreaming Sep 4, 2023
3a6be45
Add support for DistTensor type of std::vector<phi::Tensor> as input …
GhostScreaming Sep 4, 2023
cd716d7
Polish code. Remove useless comments.
GhostScreaming Sep 4, 2023
14b7ebe
Add update_loss_scaling in skip_op_lists.
GhostScreaming Sep 4, 2023
26c149e
Polish code.
GhostScreaming Sep 4, 2023
711472c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
GhostScreaming Sep 7, 2023
ac60c39
[Auto Parallel]: Support paddle::optional<Tensor> and
GhostScreaming Sep 8, 2023
ee384ac
Polish code.
GhostScreaming Sep 8, 2023
162cab9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
GhostScreaming Sep 8, 2023
5110b0d
Polish code. And support inplace Tensor, std::vector<Tensor>, paddle:…
GhostScreaming Sep 12, 2023
2679575
Polish testcase code. Add testcase for inplace paddle::optional<phi::…
GhostScreaming Sep 12, 2023
f387180
Remove useless codes in testcase code.
GhostScreaming Sep 12, 2023
dff5bba
Polish code style.
GhostScreaming Sep 12, 2023
c0bd9b7
Polish code style. And fix problems of testcases.
GhostScreaming Sep 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -590,5 +590,28 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
return results;
}

std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOutput(
    size_t out_size, std::vector<Tensor>* out) {
  // Collects the raw DistTensor pointers backing an inplace output vector so
  // the kernel can write through them. Inplace outputs are pre-populated by
  // the caller, so the element count is taken from `out` itself.
  //
  // NOTE(review): `out_size` is presumably kept only for signature uniformity
  // with the sibling SetKernelDistOutput overloads used by generated API
  // code — confirm; it is intentionally not used here.
  (void)out_size;  // fix: make the non-use explicit (silences -Wunused-parameter)
  std::vector<phi::distributed::DistTensor*> results(out->size(), nullptr);
  for (size_t i = 0; i < out->size(); ++i) {
    // Each inplace output Tensor already holds a DistTensor impl; the
    // static_cast exposes it without taking ownership.
    results[i] =
        static_cast<phi::distributed::DistTensor*>(out->at(i).impl().get());
  }
  return results;
}

std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
    size_t out_size, paddle::optional<std::vector<Tensor>> out) {
  // Collects the raw DistTensor pointers backing an optional inplace output
  // vector; returns an empty vector when no output is present.
  // NOTE(review): `out` is taken by value, which copies the whole
  // std::vector<Tensor>; a const& would avoid that, but the change must be
  // mirrored in the header declaration — flagging rather than changing here.
  if (!out) {
    return {};
  }
  std::vector<phi::distributed::DistTensor*> dist_ptrs;
  dist_ptrs.reserve(out->size());
  for (size_t idx = 0; idx < out->size(); ++idx) {
    dist_ptrs.push_back(
        static_cast<phi::distributed::DistTensor*>(out->at(idx).impl().get()));
  }
  return dist_ptrs;
}

} // namespace experimental
} // namespace paddle
6 changes: 6 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,5 +150,11 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
size_t out_size, std::vector<Tensor>* out);

std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOutput(
size_t out_size, std::vector<Tensor>* out);

std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
size_t out_size, paddle::optional<std::vector<Tensor>> out);

} // namespace experimental
} // namespace paddle
44 changes: 34 additions & 10 deletions paddle/phi/api/lib/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ PrepareDataForDistTensor(const std::vector<Tensor>& input,
const TransformFlag& transform_flag,
bool is_stride_kernel) {
std::vector<std::shared_ptr<phi::distributed::DistTensor>> out;
for (auto x : input) {
for (auto& x : input) {
const auto& tensor_in = x.impl();
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
Expand All @@ -691,22 +691,46 @@ PrepareDataForDistTensor(const std::vector<Tensor>& input,
dense_tensor.meta().is_contiguous()))) {
out.push_back(
std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in));
continue;
} else {
phi::DenseTensor trans_in_tensor = TransformData(
dense_tensor, target_args_def, transform_flag, is_stride_kernel);
// TODO(GhostScreaming): The global meta in DistTensor is not changed,
// but the local meta in DenseTensor maybe changed, such as layout
// change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
out.push_back(std::make_shared<phi::distributed::DistTensor>(
trans_in_tensor, dist_tensor->dist_attr()));
}
phi::DenseTensor trans_in_tensor = TransformData(
dense_tensor, target_args_def, transform_flag, is_stride_kernel);
// TODO(GhostScreaming): The global meta in DistTensor is not changed,
// but the local meta in DenseTensor maybe changed, such as layout
// change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
out.push_back(std::make_shared<phi::distributed::DistTensor>(
trans_in_tensor, dist_tensor->dist_attr()));
} else {
out.push_back(nullptr);
}
}
return out;
}

paddle::optional<phi::distributed::DistTensor> PrepareDataForDistTensor(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  // Optional-input adapter: forwards to the single-Tensor overload when a
  // value is present, otherwise propagates the empty optional.
  if (!input) {
    return paddle::none;
  }
  auto prepared = PrepareDataForDistTensor(
      *input, target_args_def, transform_flag, is_stride_kernel);
  return {*prepared};
}

paddle::optional<std::vector<std::shared_ptr<phi::distributed::DistTensor>>>
PrepareDataForDistTensor(const paddle::optional<std::vector<Tensor>>& input,
                         const phi::TensorArgDef& target_args_def,
                         const TransformFlag& transform_flag,
                         bool is_stride_kernel) {
  // Optional-vector adapter: forwards to the std::vector<Tensor> overload
  // when a value is present, otherwise propagates the empty optional.
  if (!input) {
    return paddle::none;
  }
  return PrepareDataForDistTensor(
      *input, target_args_def, transform_flag, is_stride_kernel);
}

} // namespace experimental
} // namespace paddle
12 changes: 12 additions & 0 deletions paddle/phi/api/lib/data_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,5 +198,17 @@ PrepareDataForDistTensor(const std::vector<Tensor>& input,
const TransformFlag& transform_flag,
bool is_stride_kernel);

paddle::optional<phi::distributed::DistTensor> PrepareDataForDistTensor(
const paddle::optional<Tensor>& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel);

paddle::optional<std::vector<std::shared_ptr<phi::distributed::DistTensor>>>
PrepareDataForDistTensor(const paddle::optional<std::vector<Tensor>>& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel);

} // namespace experimental
} // namespace paddle
Loading