Commit

[AutoParallel] Generate replicated spmd for PHI API and verify DP MP strategy (PaddlePaddle#57505)

* generate forward default spmd

* generate bwd default spmd rule

* test relu and mse forward success

* test mse loss fwd and bwd

* update replicated rule name

* update single strategy test

* add unittests

* polish details

* remove useless seed

* fix dist branch test error
chenwhql authored and Frida-a committed Oct 14, 2023
1 parent 40f61de commit 0dcb84d
Showing 25 changed files with 799 additions and 218 deletions.
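
At its core, the replicated fallback generated by this PR maps every axis of a tensor to no mesh axis, so each rank holds a full copy. A minimal sketch of that idea (hypothetical helper; include paths assume the usual phi tree layout):

    #include <vector>

    #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
    #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"

    // Hypothetical helper: a TensorDistAttr built from a shape defaults its
    // dims_mapping to all -1, i.e. every axis is replicated rather than
    // sharded along any mesh axis.
    phi::distributed::TensorDistAttr ReplicatedAttrFor(
        const std::vector<int64_t>& global_shape,
        const phi::distributed::ProcessMesh& mesh) {
      phi::distributed::TensorDistAttr attr(global_shape);
      attr.set_process_mesh(mesh);
      return attr;
    }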
13 changes: 7 additions & 6 deletions paddle/fluid/eager/tensor_wrapper.h
@@ -69,15 +69,16 @@ class TensorWrapper {
std::make_shared<phi::Allocation>(nullptr, 0, tensor.place()),
dense_tensor->meta()));
} else if (phi::distributed::DistTensor::classof(tensor.impl().get())) {
// Only Copy Meta
// Copy Global dims, DistAttr and DenseTensorMeta
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor.impl().get());
intermidiate_tensor_.set_impl(
auto no_buffer_dist_tensor =
std::make_shared<phi::distributed::DistTensor>(
phi::DenseTensor(std::make_shared<phi::Allocation>(
nullptr, 0, tensor.place()),
dist_tensor->value().meta()),
dist_tensor->dist_attr()));
dist_tensor->dims(), dist_tensor->dist_attr());
*no_buffer_dist_tensor->unsafe_mutable_value() = phi::DenseTensor(
std::make_shared<phi::Allocation>(nullptr, 0, tensor.place()),
dist_tensor->value().meta());
intermidiate_tensor_.set_impl(no_buffer_dist_tensor);
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Unrecognized tensor type for no_need_buffer feature"));
32 changes: 22 additions & 10 deletions paddle/fluid/eager/utils.h
@@ -308,10 +308,10 @@ class EagerUtils {
"Type: %s, Dtype: %s, Place: %s, Shape: %s, DistAttr: %s";
std::string tensor_info_str = "";
if (t.defined()) {
if (t.initialized()) {
if (t.is_dist_tensor()) {
auto dist_t =
std::static_pointer_cast<phi::distributed::DistTensor>(t.impl());
if (t.is_dist_tensor()) {
auto dist_t =
std::static_pointer_cast<phi::distributed::DistTensor>(t.impl());
if (t.initialized()) {
tensor_info_str += paddle::string::Sprintf(
TENSOR_INFO_TEMPLATE,
t.impl()->type_info().name(),
@@ -321,20 +321,32 @@
"%s, Local Shape: %s", t.dims(), dist_t->local_dims()),
dist_t->dist_attr());
} else {
tensor_info_str +=
paddle::string::Sprintf(TENSOR_INFO_TEMPLATE,
t.impl()->type_info().name(),
"Unknown",
"Unknown",
t.dims(),
dist_t->dist_attr());
}
} else {
if (t.initialized()) {
tensor_info_str +=
paddle::string::Sprintf(TENSOR_INFO_TEMPLATE,
t.impl()->type_info().name(),
t.dtype(),
t.place().DebugString(),
t.dims(),
"Unknown");
} else {
tensor_info_str +=
paddle::string::Sprintf(TENSOR_INFO_TEMPLATE,
t.impl()->type_info().name(),
"Unknown",
"Unknown",
"Unknown",
"Unknown");
}
} else {
tensor_info_str += paddle::string::Sprintf(TENSOR_INFO_TEMPLATE,
t.impl()->type_info().name(),
"Unknown",
"Unknown",
"Unknown");
}
} else {
tensor_info_str += "Unknown";
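
With the branches reordered as above, a defined-but-uninitialized DistTensor still reports its type, global shape, and dist attr; only dtype and place degrade to "Unknown". An illustrative rendering of TENSOR_INFO_TEMPLATE for that case (values made up; the exact DistAttr text depends on its printer):

    // Illustrative only:
    // Type: DistTensor, Dtype: Unknown, Place: Unknown, Shape: [4, 8], DistAttr: ...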
18 changes: 16 additions & 2 deletions paddle/phi/api/lib/api_gen_utils.cc
@@ -551,8 +551,12 @@ phi::distributed::DistTensor* SetKernelDistOutput(
}

std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
const phi::distributed::TensorDistAttr& dist_attr) {
return std::make_shared<phi::distributed::DistTensor>(phi::DDim(), dist_attr);
Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) {
if (out) {
return std::make_shared<phi::distributed::DistTensor>(phi::DDim(),
dist_attr);
}
return nullptr;
}

std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
@@ -617,6 +621,16 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
}
return results;
}
void SetReplicatedDistAttrForOutput(
phi::distributed::DistTensor* out,
const phi::distributed::ProcessMesh& process_mesh) {
if (out) {
auto dist_attr =
phi::distributed::TensorDistAttr(phi::vectorize(out->dims()));
dist_attr.set_process_mesh(process_mesh);
out->unsafe_set_dist_attr(dist_attr);
}
}

} // namespace experimental
} // namespace paddle
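
A hedged sketch of the caller-side contract the new null check serves: optional API outputs arrive as a null Tensor*, and the helper now mirrors that with a null result instead of allocating (illustrative snippet, not the generated code):

    #include "paddle/phi/api/lib/api_gen_utils.h"

    // Illustrative: `out` may be null when the caller skipped an optional output.
    void ConfigureOptionalOutput(paddle::Tensor* out) {
      auto dist_out = paddle::experimental::CreateKernelDistOutput(out);
      if (dist_out != nullptr) {
        // Only configure the DistTensor when the output actually exists.
      }
    }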
11 changes: 10 additions & 1 deletion paddle/phi/api/lib/api_gen_utils.h
@@ -145,7 +145,9 @@ phi::distributed::DistTensor* SetKernelDistOutput(
phi::distributed::TensorDistAttr());

std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
const phi::distributed::TensorDistAttr& dist_attr);
Tensor* out,
const phi::distributed::TensorDistAttr& dist_attr =
phi::distributed::TensorDistAttr());

std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
std::vector<Tensor*> out);
@@ -159,5 +161,12 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOutput(
std::vector<phi::distributed::DistTensor*> SetKernelDistInplaceOptionalOutput(
size_t out_size, paddle::optional<std::vector<Tensor>> out);

// A DistTensor needs its initial dist attr set after its dims are set, since
// the attr is constructed from the dims and the current process mesh; before
// calling this function, `out` should already hold the correct dims.
void SetReplicatedDistAttrForOutput(
phi::distributed::DistTensor* out,
const phi::distributed::ProcessMesh& process_mesh);

} // namespace experimental
} // namespace paddle
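
Given the precondition spelled out in the comment above, a usage sketch (hypothetical wrapper; the real call sites live in generated API code):

    // Hypothetical wrapper: stamp a replicated dist attr onto an output whose
    // global dims were already filled in by InferMeta.
    void FinalizeReplicatedOutput(phi::distributed::DistTensor* out,
                                  const phi::distributed::ProcessMesh& mesh) {
      // Safe on null `out`; the helper no-ops in that case.
      paddle::experimental::SetReplicatedDistAttrForOutput(out, mesh);
      // out->dist_attr() now maps every axis to -1 (replicated) on `mesh`.
    }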
83 changes: 63 additions & 20 deletions paddle/phi/api/lib/data_transform.cc
@@ -601,6 +601,14 @@ void TransDataBackend(const phi::SelectedRows* tensor,

/* ------------------ for auto parallel ----------------------- */

static bool ReshardIsNeeded(
const phi::distributed::TensorDistAttr& in_dist_attr,
const phi::distributed::TensorDistAttr& out_dist_attr) {
return (in_dist_attr.process_mesh() != out_dist_attr.process_mesh() ||
in_dist_attr.dims_mapping() != out_dist_attr.dims_mapping() ||
in_dist_attr.partial_status() != out_dist_attr.partial_status());
}

std::string ReshardDebugInfo(
const phi::distributed::DistTensor& src_tensor,
const phi::distributed::TensorDistAttr& dist_attr) {
@@ -620,8 +628,8 @@ std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
if (dist_tensor->dist_attr() != dist_attr) {
VLOG(6) << "FwdAPI ApiIn to KernelIn - "
if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) {
VLOG(6) << "ApiIn to KernelIn - "
<< ReshardDebugInfo(*dist_tensor, dist_attr);
auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
dist_attr);
@@ -632,6 +640,36 @@
return nullptr;
}

std::shared_ptr<phi::distributed::DistTensor>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::TensorDistAttr& dist_attr) {
auto tensor_in = tensor.impl();
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) {
VLOG(6) << "ApiIn to Replicated KernelIn - "
<< ReshardDebugInfo(*dist_tensor, dist_attr);
if (dist_tensor->initialized()) {
auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
dist_attr);
return func->Eval(dev_ctx, *dist_tensor, dist_attr);
} else {
// when no tensor data needs to be resharded, we still need to set the
// correct replicated dist attr and local dims for the output
dist_tensor->unsafe_set_dist_attr(dist_attr);
auto dense_tensor_meta = dist_tensor->value().meta();
dense_tensor_meta.dims = dist_tensor->dims();
dist_tensor->unsafe_mutable_value()->set_meta(dense_tensor_meta);
}
}
return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
}
return nullptr;
}

void ReshardOutputPartialAxisToReplicated(
phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) {
if (out_tensor->dist_attr().is_partial()) {
@@ -649,25 +687,30 @@ void ReshardKernelOutputToApiOutput(
phi::DeviceContext* dev_ctx,
const std::shared_ptr<phi::distributed::DistTensor>& src_tensor,
Tensor* dst_tensor) {
auto tensor_out = dst_tensor->impl();
PADDLE_ENFORCE_NE(
tensor_out,
nullptr,
phi::errors::InvalidArgument("The output tensor is nullptr."));
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_out.get());
dist_tensor->unsafe_set_dims(src_tensor->dims());
if (src_tensor->dist_attr() != dist_tensor->dist_attr()) {
VLOG(6) << "BwdAPI KernelOut to ApiOut - "
<< ReshardDebugInfo(*src_tensor, dist_tensor->dist_attr());
auto* func = phi::distributed::ChooseProperReshardFunction(
*src_tensor, dist_tensor->dist_attr());
func->Eval(dev_ctx, *src_tensor, dist_tensor->dist_attr(), dist_tensor);
if (dst_tensor) {
auto tensor_out = dst_tensor->impl();
PADDLE_ENFORCE_NE(
tensor_out,
nullptr,
phi::errors::InvalidArgument("The output tensor is nullptr."));
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_out.get());
dist_tensor->unsafe_set_dims(src_tensor->dims());
if (ReshardIsNeeded(src_tensor->dist_attr(), dist_tensor->dist_attr())) {
VLOG(6) << "BwdAPI KernelOut to ApiOut - "
<< ReshardDebugInfo(*src_tensor, dist_tensor->dist_attr());
auto* func = phi::distributed::ChooseProperReshardFunction(
*src_tensor, dist_tensor->dist_attr());
func->Eval(dev_ctx, *src_tensor, dist_tensor->dist_attr(), dist_tensor);
} else {
// TODO(chenweihang): add dist attr compare and default copy rule to
// avoid add branch here
// shallow copy dense tensor
*dist_tensor->unsafe_mutable_value() = src_tensor->value();
}
} else {
// TODO(chenweihang): add dist attr compare and default copy rule to
// avoid add branch here
// shallow copy dense tensor
*dist_tensor->unsafe_mutable_value() = src_tensor->value();
VLOG(3) << "The output tensor is nullptr when call "
"ReshardKernelOutputToApiOutput.";
}
}

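
The output path above reduces to "reshard when placements differ, otherwise shallow-copy". A condensed sketch of that core (hypothetical free function; ReshardIsNeeded is file-local in data_transform.cc):

    // Reshard only when the kernel output's placement differs from the API
    // output's; otherwise share the underlying DenseTensor buffer.
    void CopyOrReshard(phi::DeviceContext* dev_ctx,
                       const std::shared_ptr<phi::distributed::DistTensor>& src,
                       phi::distributed::DistTensor* dst) {
      if (ReshardIsNeeded(src->dist_attr(), dst->dist_attr())) {
        auto* func = phi::distributed::ChooseProperReshardFunction(
            *src, dst->dist_attr());
        func->Eval(dev_ctx, *src, dst->dist_attr(), dst);
      } else {
        *dst->unsafe_mutable_value() = src->value();  // shallow copy
      }
    }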
6 changes: 6 additions & 0 deletions paddle/phi/api/lib/data_transform.h
@@ -180,6 +180,12 @@ std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
const Tensor& tensor,
const phi::distributed::TensorDistAttr& dist_attr);

std::shared_ptr<phi::distributed::DistTensor>
ReshardApiInputToReplicatedKernelInput(
phi::DeviceContext* dev_ctx,
const Tensor& tensor,
const phi::distributed::TensorDistAttr& dist_attr);

void ReshardOutputPartialAxisToReplicated(
phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor);

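
Putting the helpers together, a hedged sketch of the flow generated PHI API code follows for an op without a dedicated SPMD rule (function and variable names here are placeholders, not the real generated identifiers; the dense kernel launch is elided):

    // Placeholder flow; the real generated code differs in structure.
    paddle::Tensor ReplicatedFallbackForward(phi::DeviceContext* dev_ctx,
                                             const paddle::Tensor& x) {
      // 1. Build a replicated dist attr on x's mesh and reshard x to it.
      auto x_impl =
          std::static_pointer_cast<phi::distributed::DistTensor>(x.impl());
      phi::distributed::TensorDistAttr repl_attr(phi::vectorize(x.dims()));
      repl_attr.set_process_mesh(x_impl->dist_attr().process_mesh());
      auto dist_x = paddle::experimental::ReshardApiInputToReplicatedKernelInput(
          dev_ctx, x, repl_attr);

      // 2. Run the dense kernel on dist_x->value() ... elided.

      // 3. Create the output holder and stamp a replicated dist attr on it
      //    once InferMeta has set its dims.
      paddle::Tensor out;
      auto dist_out = paddle::experimental::CreateKernelDistOutput(&out);
      out.set_impl(dist_out);
      paddle::experimental::SetReplicatedDistAttrForOutput(
          dist_out.get(), x_impl->dist_attr().process_mesh());
      return out;
    }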