Skip to content

Commit

Permalink
refactor dist tensor constructor (PaddlePaddle#58032)
Browse files Browse the repository at this point in the history
  • Loading branch information
LiYuRio authored Oct 13, 2023
1 parent 03c9f2f commit 5fcf600
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 59 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/eager/grad_tensor_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ void GradTensorHolder::CopyValueFromTensor(size_t slot_id,
auto init_grad =
paddle::experimental::full(t.shape(), 1, t.dtype(), t.place());
auto global_dense_t =
static_cast<phi::DenseTensor*>(init_grad.impl().get());
std::static_pointer_cast<phi::DenseTensor>(init_grad.impl());
auto dist_t =
static_cast<phi::distributed::DistTensor*>(t.impl().get());
init_grad.set_impl(std::make_shared<phi::distributed::DistTensor>(
*global_dense_t, dist_t->dist_attr()));
global_dense_t, dist_t->dist_attr()));
buffer_[slot_id][rank] = init_grad;
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/pybind/eager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ void CreateDistTensorWithNumpyValue(TensorObject* self,
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/CustomPlace"));
}

auto dist_tensor =
std::make_shared<phi::distributed::DistTensor>(dense_tensor, dist_attr);
auto dist_tensor = std::make_shared<phi::distributed::DistTensor>(
std::make_shared<phi::DenseTensor>(dense_tensor), dist_attr);
self->tensor.set_impl(dist_tensor);

if (!autograd_meta->GetMutableGradNode()) {
Expand Down Expand Up @@ -280,13 +280,13 @@ void InitDistTensorWithTensor(TensorObject* self,
if (place == src.place()) {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(src.impl());
self->tensor.set_impl(std::make_shared<DistTensor>(*tensor, dist_attr));
self->tensor.set_impl(std::make_shared<DistTensor>(tensor, dist_attr));
VLOG(4) << "Same place, do ShareDataWith for DistTensor.";
} else {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(
src.copy_to(place, true).impl());
self->tensor.set_impl(std::make_shared<DistTensor>(*tensor, dist_attr));
self->tensor.set_impl(std::make_shared<DistTensor>(tensor, dist_attr));
VLOG(4) << "Different place, do TensorCopy for DistTensor.";
}
if (src.get_autograd_meta()) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pybind/eager_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2079,9 +2079,9 @@ void DistTensorConverter::convert(Tensor* x) {
phi::distributed::TensorDistAttr dist_attr(
phi::vectorize(x->impl()->dims()));
dist_attr.set_process_mesh(*mesh);
auto dense_t = static_cast<phi::DenseTensor*>(x->impl().get());
auto dense_t = std::static_pointer_cast<phi::DenseTensor>(x->impl());
x->set_impl(
std::make_shared<phi::distributed::DistTensor>(*dense_t, dist_attr));
std::make_shared<phi::distributed::DistTensor>(dense_t, dist_attr));
}
}

Expand Down
12 changes: 2 additions & 10 deletions paddle/phi/api/lib/api_gen_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -566,11 +566,7 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
if (tmp) {
// TODO(GhostScreaming): now all dist case are nullptr
if (tmp->impl() == nullptr) {
phi::DenseTensor dense_t;
// TODO(GhostScreaming): polish code, dist_attr is null now
phi::distributed::TensorDistAttr dist_attr;
auto dist_t =
std::make_shared<phi::distributed::DistTensor>(dense_t, dist_attr);
auto dist_t = std::make_shared<phi::distributed::DistTensor>();
tmp->set_impl(dist_t);
}
result.emplace_back(
Expand All @@ -587,11 +583,7 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
out->reserve(out_size);
std::vector<phi::distributed::DistTensor*> results(out_size);
for (size_t i = 0; i < out_size; ++i) {
phi::DenseTensor dense_t;
// TODO(GhostScreaming): polish code, dist_attr is null now
phi::distributed::TensorDistAttr dist_attr;
auto dist_t =
std::make_shared<phi::distributed::DistTensor>(dense_t, dist_attr);
auto dist_t = std::make_shared<phi::distributed::DistTensor>();
results[i] = dist_t.get();
out->emplace_back();
out->back().set_impl(dist_t);
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/api/lib/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,8 @@ PrepareDataForDistTensor(const std::vector<Tensor>& input,
// change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
out.push_back(std::make_shared<phi::distributed::DistTensor>(
trans_in_tensor, dist_tensor->dist_attr()));
std::make_shared<phi::DenseTensor>(trans_in_tensor),
dist_tensor->dist_attr()));
}
} else {
out.push_back(nullptr);
Expand Down
73 changes: 43 additions & 30 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,35 +33,45 @@ inline void check_defined(const DistTensor& dist_tensor,
method_hint));
}

DistTensor::DistTensor(const phi::DenseTensor& global_value,
DistTensor::DistTensor() : value_(std::make_shared<DenseTensor>()) {}

DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
const TensorDistAttr& dist_attr)
: dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) {
// TODO(liyurui): This is a temporary solution. We need to support only infer
// meta when the input dense_tensor is empty.
// Support the value in DistTensor only has DenseTensor meta
// but without actual data. So we can visit its meta attr even if it is
// undefined.
: dims_(global_value->dims()),
dist_attr_(dist_attr),
value_(std::make_shared<DenseTensor>()) {
// If the current rank is not in the process_mesh, we should create an
// uninitialized tensor with only tensor_meta.
if (IsCurRankInMesh(dist_attr.process_mesh())) {
if (value_.initialized() && !dist_attr.is_replicated()) {
if (!dist_attr.is_replicated()) {
// 1. create replicated global tensor
int64_t dims_size = global_value.dims().size();
std::vector<int64_t> dims_mapping(dims_size, -1);
dist_attr_.set_dims_mapping(dims_mapping);
if (dist_attr_.is_partial()) {
dist_attr_.clean_partial_status();
}
dist_attr_.set_dims_mapping(dims_mapping);
TensorDistAttr replicated_dist_attr(vectorize(global_value->dims()));
replicated_dist_attr.set_process_mesh(dist_attr.process_mesh());
DistTensor replicated_tensor(global_value, replicated_dist_attr);

// 2. reshard from replicated to other state
auto* func = ChooseProperReshardFunction(*this, dist_attr);
auto* dev_ctx = DeviceContextPool::Instance().Get(global_value.place());
func->Eval(dev_ctx, *this, dist_attr, this);
auto* func = ChooseProperReshardFunction(replicated_tensor, dist_attr);
auto* dev_ctx = DeviceContextPool::Instance().Get(global_value->place());
func->Eval(dev_ctx, replicated_tensor, dist_attr, this);
} else {
value_ = global_value;
}
} else {
// TODO(liyurui): The following logic is illegal and should be removed
// later. It exists temporarily because the basic execution procedure is not
// ready; sometimes we even try to construct a DistTensor with an empty
// DistAttr. Here we log a warning when the DistAttr is empty, for debugging.
if (dist_attr.empty()) {
LOG(WARNING) << "Try to construct a dist tensor with empty dist attr.";
}
value_ = global_value;
}
}

DistTensor::DistTensor(const DDim& dims, const TensorDistAttr& dist_attr)
: dims_(dims), dist_attr_(dist_attr) {}
: dims_(dims),
dist_attr_(dist_attr),
value_(std::make_shared<DenseTensor>()) {}

void DistTensor::unsafe_set_dims(const DDim& dims) {
if (this->initialized()) {
Expand All @@ -80,39 +90,42 @@ void DistTensor::unsafe_set_dist_attr(const TensorDistAttr& dist_attr) {
}

int64_t DistTensor::numel() const {
check_defined(*this, "numel");
return value_.numel();
// DistTensor with uninitialized local tensor can
// also have numel.
return product(dims_);
}

const DDim& DistTensor::local_dims() const {
check_defined(*this, "local_dims");
return value_.dims();
return value_->dims();
}

bool DistTensor::valid() const {
check_defined(*this, "valid");
return value_.valid();
return value_->valid();
}

bool DistTensor::defined() const { return value_.holder_ != nullptr; }
bool DistTensor::defined() const { return value_->holder_ != nullptr; }

bool DistTensor::initialized() const {
return value_.holder_ != nullptr && value_.holder_->ptr();
return value_->holder_ != nullptr && value_->holder_->ptr();
}

DataType DistTensor::dtype() const {
check_defined(*this, "dtype");
return value_.dtype();
// DistTensor with uninitialized local tensor can
// also have dtype.
return value_->dtype();
}

DataLayout DistTensor::layout() const {
check_defined(*this, "layout");
return value_.layout();
// DistTensor with uninitialized local tensor can
// also have layout.
return value_->layout();
}

const Place& DistTensor::place() const {
check_defined(*this, "place");
return value_.holder_->place();
return value_->holder_->place();
}

void* DistTensor::AllocateFrom(Allocator* allocator,
Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ class DistTensor final
/// \brief Be careful when creating a dist tensor with the default constructor.
/// It should only be used in reshard for now, and the dist properties
/// will be set by reshard later.
DistTensor() = default;
DistTensor();

/// \brief Construct a dist tensor based on a dense tensor.
/// \param global_value The global dense tensor of the current tensor.
/// \param dist_attr The distributed attributes of the current tensor.
DistTensor(const phi::DenseTensor& global_value,
DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
const TensorDistAttr& dist_attr);

/// \brief Construct an empty dist tensor (for infer spmd)
Expand Down Expand Up @@ -68,7 +68,7 @@ class DistTensor final

/// \brief Returns the dense tensor value's const reference in dist tensor.
/// \return The DenseTensor value's const reference
const DenseTensor& value() const { return value_; }
const DenseTensor& value() const { return *value_; }

/// \brief Returns the mutable dense tensor value in dist tensor.
/// \note If DenseTensor value is modified externally, the corresponding
Expand All @@ -77,7 +77,7 @@ class DistTensor final
/// so you need to make sure to consider it thoroughly when using
/// this method.
/// \return The mutable pointer of DenseTensor value
DenseTensor* unsafe_mutable_value() { return &value_; }
DenseTensor* unsafe_mutable_value() { return value_.get(); }

/// \brief Returns the global dims of the dist tensor.
/// \return The global dims of the dist tensor.
Expand Down Expand Up @@ -126,7 +126,7 @@ class DistTensor final
// The distributed attributes
TensorDistAttr dist_attr_;
// The local DenseTensor value
DenseTensor value_;
std::shared_ptr<DenseTensor> value_;
};

} // namespace distributed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ ProcessMesh GetSubProcessMesh(const ProcessMesh& mesh, int64_t axis) {
for (int64_t j = static_cast<int64_t>(coord.size() - 2); j >= 0; --j) {
rank += coord[j] * mesh.dim_size(j + 1);
}
process_ids.emplace_back(rank);
process_ids.emplace_back(mesh.process_ids()[rank]);
}

ProcessMesh out_mesh(shape, process_ids, dim_names);
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/reshard_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ std::shared_ptr<DistTensor> ReshardFunction::Eval(
}

void ReshardFunction::SetValue(DistTensor* tensor, const DenseTensor& value) {
tensor->value_ = value;
tensor->value_ = std::make_shared<DenseTensor>(value);
}

void ReshardFunction::SetDistProps(DistTensor* tensor,
Expand All @@ -56,7 +56,7 @@ void ReshardFunction::SetDistProps(DistTensor* tensor,
}

DenseTensor* ReshardFunction::GetMutableTensor(DistTensor* tensor) {
return &tensor->value_;
return tensor->value_.get();
}

ReshardFunction* ChooseProperReshardFunction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx,
for (const auto& iter : p2p_pair) {
int64_t src = iter.first;
int64_t dst = iter.second;
VLOG(3) << "Send/Recv from src " << src << " to dst " << dst;
if (src == cur_global_rank) {
VLOG(3) << "Send from src " << src << " to dst " << dst;
int64_t dst_local_rank = GetLocalRankInParticipate(all_process_ids, dst);
// Since the send kernel only has an input, we don't actually need
// infermeta. For this reason, just use the kernel directly.
Expand All @@ -103,6 +103,7 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx,
dst_local_rank,
dynamic_shape);
} else if (dst == cur_global_rank) {
VLOG(3) << "Recv from src " << src << " to dst " << dst;
int64_t src_local_rank = GetLocalRankInParticipate(all_process_ids, src);
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
PRecv,
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/auto_parallel/dist_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ TEST(dist_tensor, constructor) {
dist_attr.set_process_mesh(mesh);

// copy construct
DenseTensor x1(alloc, meta);
std::shared_ptr<DenseTensor> x1 = std::make_shared<DenseTensor>(alloc, meta);
DistTensor dist_x1(x1, dist_attr);
EXPECT_TRUE(dist_x1.defined());
EXPECT_TRUE(dist_x1.initialized());
Expand Down

0 comments on commit 5fcf600

Please sign in to comment.