[AutoParallel] Generate replicated spmd for PHI API and verify DP MP strategy #57505

Merged · 13 commits · Sep 22, 2023
13 changes: 7 additions & 6 deletions paddle/fluid/eager/tensor_wrapper.h
@@ -69,15 +69,16 @@ class TensorWrapper {
std::make_shared<phi::Allocation>(nullptr, 0, tensor.place()),
dense_tensor->meta()));
} else if (phi::distributed::DistTensor::classof(tensor.impl().get())) {
// Only Copy Meta
// Copy Global dims, DistAttr and DenseTensorMeta
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor.impl().get());
intermidiate_tensor_.set_impl(
auto no_buffer_dist_tensor =
std::make_shared<phi::distributed::DistTensor>(
phi::DenseTensor(std::make_shared<phi::Allocation>(
nullptr, 0, tensor.place()),
dist_tensor->value().meta()),
dist_tensor->dist_attr()));
dist_tensor->dims(), dist_tensor->dist_attr());
*no_buffer_dist_tensor->unsafe_mutable_value() = phi::DenseTensor(
std::make_shared<phi::Allocation>(nullptr, 0, tensor.place()),
dist_tensor->value().meta());
intermidiate_tensor_.set_impl(no_buffer_dist_tensor);
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Unrecognized tensor type for no_need_buffer feature"));
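The rewritten branch builds the no-buffer wrapper in two steps: construct the DistTensor from its global dims and dist attr, then install a local DenseTensor that keeps the meta but points at a zero-size allocation, presumably so global dims no longer have to be inferred from the local meta. A minimal sketch of the pattern, with simplified stand-in types rather than the real phi classes:

```cpp
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Simplified stand-ins for the phi types; not the real classes.
struct Meta {
  std::vector<int64_t> dims;
};

struct LocalTensor {
  std::shared_ptr<char> buffer;  // null => meta only, no allocation
  Meta meta;
};

struct DistTensorLike {
  std::vector<int64_t> global_dims;
  std::string dist_attr;  // placeholder for TensorDistAttr
  LocalTensor value;
};

int main() {
  DistTensorLike src{{8, 16}, "replicated",
                     {std::make_shared<char>('x'), {{8, 16}}}};

  // Step 1, as in the patch: construct from global dims + dist attr...
  auto no_buffer = std::make_shared<DistTensorLike>();
  no_buffer->global_dims = src.global_dims;
  no_buffer->dist_attr = src.dist_attr;
  // Step 2: ...then install a local value that keeps the meta but has no
  // backing buffer (the patch uses a phi::Allocation of size 0).
  no_buffer->value = LocalTensor{nullptr, src.value.meta};

  std::cout << "buffer kept: " << (no_buffer->value.buffer != nullptr) << "\n";
}
```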
32 changes: 22 additions & 10 deletions paddle/fluid/eager/utils.h
@@ -308,10 +308,10 @@ class EagerUtils {
"Type: %s, Dtype: %s, Place: %s, Shape: %s, DistAttr: %s";
std::string tensor_info_str = "";
if (t.defined()) {
if (t.initialized()) {
if (t.is_dist_tensor()) {
auto dist_t =
std::static_pointer_cast<phi::distributed::DistTensor>(t.impl());
if (t.is_dist_tensor()) {
auto dist_t =
std::static_pointer_cast<phi::distributed::DistTensor>(t.impl());
if (t.initialized()) {
tensor_info_str += paddle::string::Sprintf(
TENSOR_INFO_TEMPLATE,
t.impl()->type_info().name(),
@@ -321,20 +321,32 @@
"%s, Local Shape: %s", t.dims(), dist_t->local_dims()),
dist_t->dist_attr());
} else {
tensor_info_str +=
paddle::string::Sprintf(TENSOR_INFO_TEMPLATE,
t.impl()->type_info().name(),
"Unknown",
"Unknown",
t.dims(),
dist_t->dist_attr());
}
} else {
if (t.initialized()) {
tensor_info_str +=
paddle::string::Sprintf(TENSOR_INFO_TEMPLATE,
t.impl()->type_info().name(),
t.dtype(),
t.place().DebugString(),
t.dims(),
"Unknown");
} else {
tensor_info_str +=
paddle::string::Sprintf(TENSOR_INFO_TEMPLATE,
t.impl()->type_info().name(),
"Unknown",
"Unknown",
"Unknown",
"Unknown");
}
} else {
tensor_info_str += paddle::string::Sprintf(TENSOR_INFO_TEMPLATE,
t.impl()->type_info().name(),
"Unknown",
"Unknown",
"Unknown");
}
} else {
tensor_info_str += "Unknown";
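The reshuffled branches form a 2x2 case split: dist vs. dense first, initialized vs. not second. The payoff is that a dist tensor's dims and DistAttr are known from construction, so they can be printed even before the tensor holds data. A stand-in sketch of the four cases (not the real EagerUtils code):

```cpp
#include <iostream>
#include <string>

// Stand-in flags; the real code inspects a paddle::Tensor.
struct TensorLike {
  bool is_dist;
  bool initialized;
};

std::string TensorInfo(const TensorLike& t) {
  if (t.is_dist) {
    // Dims and DistAttr are fixed at construction, so the dist branch can
    // report them whether or not the tensor holds data yet.
    return t.initialized ? "dist: dtype/place, dims + local dims, DistAttr"
                         : "dist: Unknown dtype/place, dims, DistAttr";
  }
  return t.initialized ? "dense: dtype/place/dims, DistAttr Unknown"
                       : "dense: all Unknown";
}

int main() {
  for (bool d : {true, false})
    for (bool i : {true, false})
      std::cout << TensorInfo({d, i}) << "\n";  // prints all four cases
}
```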
18 changes: 16 additions & 2 deletions paddle/phi/api/lib/api_gen_utils.cc
@@ -551,8 +551,12 @@ phi::distributed::DistTensor* SetKernelDistOutput(
}

std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
const phi::distributed::TensorDistAttr& dist_attr) {
return std::make_shared<phi::distributed::DistTensor>(phi::DDim(), dist_attr);
Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) {
if (out) {
return std::make_shared<phi::distributed::DistTensor>(phi::DDim(),
dist_attr);
}
return nullptr;
}

std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
@@ -617,6 +621,16 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
}
return results;
}
void SetReplicatedDistAttrForOutput(
phi::distributed::DistTensor* out,
const phi::distributed::ProcessMesh& process_mesh) {
if (out) {
auto dist_attr =
phi::distributed::TensorDistAttr(phi::vectorize(out->dims()));
dist_attr.set_process_mesh(process_mesh);
out->unsafe_set_dist_attr(dist_attr);
}
}

} // namespace experimental
} // namespace paddle
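With the new Tensor* parameter, CreateKernelDistOutput can be fed the API-level output directly: an absent optional output yields nullptr instead of a freshly allocated DistTensor. A placeholder sketch of the guard (types are stand-ins, not the phi classes):

```cpp
#include <iostream>
#include <memory>

// Placeholder types; not the real paddle::Tensor / phi DistTensor.
struct TensorLike {};
struct DistTensorLike {};

// Mirrors the patched helper: only materialize a dist output when the
// API-level output actually exists (optional outputs may be absent).
std::shared_ptr<DistTensorLike> CreateKernelDistOutputLike(TensorLike* out) {
  if (out) {
    return std::make_shared<DistTensorLike>();
  }
  return nullptr;
}

int main() {
  TensorLike present;
  auto a = CreateKernelDistOutputLike(&present);  // allocated
  auto b = CreateKernelDistOutputLike(nullptr);   // skipped
  std::cout << (a != nullptr) << " " << (b == nullptr) << "\n";  // 1 1
}
```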
11 changes: 10 additions & 1 deletion paddle/phi/api/lib/api_gen_utils.h
@@ -145,7 +145,9 @@ phi::distributed::DistTensor* SetKernelDistOutput(
phi::distributed::TensorDistAttr());

std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
const phi::distributed::TensorDistAttr& dist_attr);
Tensor* out,
const phi::distributed::TensorDistAttr& dist_attr =
phi::distributed::TensorDistAttr());

std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
std::vector<Tensor*> out);
@@ -159,5 +161,12 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOutput(
std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
size_t out_size, paddle::optional<std::vector<Tensor>> out);

// A DistTensor's dist attr must be set after its dims, since the attr is
// constructed from the dims and the current process mesh. Before calling
// this function, out must already hold the correct dims.
void SetReplicatedDistAttrForOutput(
phi::distributed::DistTensor* out,
const phi::distributed::ProcessMesh& process_mesh);

} // namespace experimental
} // namespace paddle
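As the (now tidied) comment says, SetReplicatedDistAttrForOutput builds the attr from the output's dims, so the dims must be correct first. Constructing a TensorDistAttr from a shape alone gives, as far as I can tell, a dims_mapping of all -1, i.e. fully replicated. A stand-in illustration of that convention:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for phi::distributed::TensorDistAttr. A dims_mapping entry of
// -1 means "this tensor axis is not sharded across any mesh dimension".
struct DistAttrLike {
  std::vector<int64_t> dims_mapping;
  explicit DistAttrLike(const std::vector<int64_t>& shape)
      : dims_mapping(shape.size(), -1) {}  // shape-only ctor: replicated
  bool replicated() const {
    for (int64_t m : dims_mapping) {
      if (m != -1) return false;
    }
    return true;
  }
};

int main() {
  // Built from the output's dims, as SetReplicatedDistAttrForOutput does
  // (the real helper also attaches the current process mesh).
  DistAttrLike attr({4, 8, 16});
  assert(attr.replicated());  // every axis unsharded
  (void)attr;
}
```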
83 changes: 63 additions & 20 deletions paddle/phi/api/lib/data_transform.cc
@@ -601,6 +601,14 @@ void TransDataBackend(const phi::SelectedRows* tensor,

/* ------------------ for auto parallel ----------------------- */

static bool ReshardIsNeeded(
const phi::distributed::TensorDistAttr& in_dist_attr,
const phi::distributed::TensorDistAttr& out_dist_attr) {
return (in_dist_attr.process_mesh() != out_dist_attr.process_mesh() ||
in_dist_attr.dims_mapping() != out_dist_attr.dims_mapping() ||
in_dist_attr.partial_status() != out_dist_attr.partial_status());
}
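ReshardIsNeeded compares only the three fields that determine data placement (process mesh, dims mapping, partial status) instead of full dist-attr equality, so attrs that presumably differ only in bookkeeping fields no longer trigger a reshard. A stand-in example of the check:

```cpp
#include <iostream>
#include <vector>

// Stand-in for the three placement-relevant fields of TensorDistAttr.
struct Placement {
  int mesh_id;                    // process mesh identity
  std::vector<int> dims_mapping;  // -1 = replicated axis
  bool partial;                   // pending cross-rank reduction
};

bool ReshardIsNeededLike(const Placement& in, const Placement& out) {
  return in.mesh_id != out.mesh_id || in.dims_mapping != out.dims_mapping ||
         in.partial != out.partial;
}

int main() {
  Placement sharded{0, {0, -1}, false};      // axis 0 split over mesh dim 0
  Placement replicated{0, {-1, -1}, false};  // fully replicated
  std::cout << ReshardIsNeededLike(sharded, replicated) << "\n";  // 1
  std::cout << ReshardIsNeededLike(sharded, sharded) << "\n";     // 0
}
```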

std::string ReshardDebugInfo(
const phi::distributed::DistTensor& src_tensor,
const phi::distributed::TensorDistAttr& dist_attr) {
@@ -620,8 +628,8 @@ std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
if (dist_tensor->dist_attr() != dist_attr) {
VLOG(6) << "FwdAPI ApiIn to KernelIn - "
if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) {
VLOG(6) << "ApiIn to KernelIn - "
<< ReshardDebugInfo(*dist_tensor, dist_attr);
auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
dist_attr);
@@ -632,6 +640,36 @@ std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
return nullptr;
}

std::shared_ptr<phi::distributed::DistTensor>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::TensorDistAttr& dist_attr) {
auto tensor_in = tensor.impl();
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) {
VLOG(6) << "ApiIn to Replicated KernelIn - "
<< ReshardDebugInfo(*dist_tensor, dist_attr);
if (dist_tensor->initialized()) {
auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
dist_attr);
return func->Eval(dev_ctx, *dist_tensor, dist_attr);
} else {
// When no tensor data needs to be resharded, we still need to set the
// correct replicated dist attr and local dims on the output.
dist_tensor->unsafe_set_dist_attr(dist_attr);
auto dense_tensor_meta = dist_tensor->value().meta();
dense_tensor_meta.dims = dist_tensor->dims();
dist_tensor->unsafe_mutable_value()->set_meta(dense_tensor_meta);
}
}
return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
}
return nullptr;
}
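The uninitialized branch above is the subtle one: there is no data to move, but the metadata must still come out looking post-reshard, so the dist attr is overwritten and the local meta dims are set to the global dims (a replicated tensor's local shard is the whole tensor). A stand-in sketch:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in: a dist tensor as (global dims, local meta dims, flags).
struct DistLike {
  std::vector<int64_t> global_dims;
  std::vector<int64_t> local_dims;
  bool replicated;
  bool initialized;
};

// Mirrors the uninitialized branch of the patch: skip communication, but
// fix up the attr and local meta so downstream code sees a consistent
// replicated tensor.
void ToReplicatedNoData(DistLike* t) {
  if (!t->initialized) {
    t->replicated = true;
    // Replicated => every rank holds the full tensor, so local == global.
    t->local_dims = t->global_dims;
  }
}

int main() {
  DistLike t{{8, 32}, {4, 32}, false, false};  // was sharded on axis 0
  ToReplicatedNoData(&t);
  assert(t.local_dims == t.global_dims && t.replicated);
  (void)t;
}
```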

void ReshardOutputPartialAxisToReplicated(
phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) {
if (out_tensor->dist_attr().is_partial()) {
@@ -649,25 +687,30 @@ void ReshardKernelOutputToApiOutput(
phi::DeviceContext* dev_ctx,
const std::shared_ptr<phi::distributed::DistTensor>& src_tensor,
Tensor* dst_tensor) {
auto tensor_out = dst_tensor->impl();
PADDLE_ENFORCE_NE(
tensor_out,
nullptr,
phi::errors::InvalidArgument("The output tensor is nullptr."));
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_out.get());
dist_tensor->unsafe_set_dims(src_tensor->dims());
if (src_tensor->dist_attr() != dist_tensor->dist_attr()) {
VLOG(6) << "BwdAPI KernelOut to ApiOut - "
<< ReshardDebugInfo(*src_tensor, dist_tensor->dist_attr());
auto* func = phi::distributed::ChooseProperReshardFunction(
*src_tensor, dist_tensor->dist_attr());
func->Eval(dev_ctx, *src_tensor, dist_tensor->dist_attr(), dist_tensor);
if (dst_tensor) {
auto tensor_out = dst_tensor->impl();
PADDLE_ENFORCE_NE(
tensor_out,
nullptr,
phi::errors::InvalidArgument("The output tensor is nullptr."));
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_out.get());
dist_tensor->unsafe_set_dims(src_tensor->dims());
if (ReshardIsNeeded(src_tensor->dist_attr(), dist_tensor->dist_attr())) {
VLOG(6) << "BwdAPI KernelOut to ApiOut - "
<< ReshardDebugInfo(*src_tensor, dist_tensor->dist_attr());
auto* func = phi::distributed::ChooseProperReshardFunction(
*src_tensor, dist_tensor->dist_attr());
func->Eval(dev_ctx, *src_tensor, dist_tensor->dist_attr(), dist_tensor);
} else {
// TODO(chenweihang): add dist attr compare and default copy rule to
// avoid adding a branch here
// shallow copy dense tensor
*dist_tensor->unsafe_mutable_value() = src_tensor->value();
}
} else {
// TODO(chenweihang): add dist attr compare and default copy rule to
// avoid adding a branch here
// shallow copy dense tensor
*dist_tensor->unsafe_mutable_value() = src_tensor->value();
VLOG(3) << "The output tensor is nullptr when call "
"ReshardKernelOutputToApiOutput.";
}
}

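On the output side, the same placement check picks between a real reshard and a shallow copy of the local DenseTensor, and a null dst_tensor (e.g. an optional output that was never requested) is now tolerated rather than tripping the enforce. A stand-in sketch of the control flow, where ReshardInto is a placeholder for ChooseProperReshardFunction(...)->Eval(...):

```cpp
#include <iostream>
#include <memory>

// Stand-ins: "value" is the local DenseTensor a rank actually holds.
struct ValueLike {
  std::shared_ptr<int> data;
};
struct DistLike {
  ValueLike value;
  int placement;  // stand-in for (mesh, dims_mapping, partial)
};

// Placeholder for a real reshard: produces a new local buffer at the
// destination placement.
void ReshardInto(const DistLike& src, DistLike* dst) {
  dst->value.data = std::make_shared<int>(*src.value.data);
}

// Mirrors the patched output path: tolerate a null destination, reshard
// only when placements differ, otherwise shallow-copy the local value.
void KernelOutToApiOut(const DistLike& src, DistLike* dst) {
  if (!dst) return;  // optional output absent: the real code logs and skips
  if (src.placement != dst->placement) {
    ReshardInto(src, dst);
  } else {
    dst->value = src.value;  // shares the buffer, no data movement
  }
}

int main() {
  DistLike src{ValueLike{std::make_shared<int>(7)}, 0};
  DistLike same{ValueLike{}, 0};  // placement already matches
  KernelOutToApiOut(src, &same);
  std::cout << (same.value.data == src.value.data) << "\n";  // 1: shared
  KernelOutToApiOut(src, nullptr);                           // safe no-op
}
```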
6 changes: 6 additions & 0 deletions paddle/phi/api/lib/data_transform.h
@@ -180,6 +180,12 @@ std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
const Tensor& tensor,
const phi::distributed::TensorDistAttr& dist_attr);

std::shared_ptr<phi::distributed::DistTensor>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::TensorDistAttr& dist_attr);

void ReshardOutputPartialAxisToReplicated(
phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor);

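For the DP/MP verification in the title, the useful mental model is how a dims_mapping turns global dims into local shard dims: -1 leaves an axis whole, while mapping an axis to a mesh dimension divides it by that dimension's size. A self-contained illustration; the helper and the 1-D mesh of size 4 are hypothetical, and a plain 2-D tensor stands in for whatever the op actually shards:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: the -1 convention follows phi, the helper does not.
std::vector<int64_t> LocalDims(const std::vector<int64_t>& global,
                               const std::vector<int>& dims_mapping,
                               int64_t mesh_size) {
  std::vector<int64_t> local = global;
  for (size_t i = 0; i < global.size(); ++i) {
    if (dims_mapping[i] != -1) local[i] = global[i] / mesh_size;  // sharded
  }
  return local;
}

int main() {
  std::vector<int64_t> w = {1024, 4096};  // a tensor's global shape
  auto print = [](const char* tag, const std::vector<int64_t>& d) {
    std::cout << tag << ": [" << d[0] << ", " << d[1] << "]\n";
  };
  print("replicated", LocalDims(w, {-1, -1}, 4));  // [1024, 4096] everywhere
  print("DP-style  ", LocalDims(w, {0, -1}, 4));   // rows split: [256, 4096]
  print("MP-style  ", LocalDims(w, {-1, 0}, 4));   // cols split: [1024, 1024]
}
```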