[CodeStyle][Typos][C-52] Fix typo (context) (PaddlePaddle#69839)
* context
Neo-WY authored and zhiqiu committed Dec 3, 2024
1 parent 6043c7f commit 131ea5e
Showing 3 changed files with 13 additions and 2 deletions.
7 changes: 7 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.cc
@@ -805,6 +805,13 @@ void SetReplicatedDistAttrForOutput(
     phi::distributed::DistTensor* out,
     const phi::distributed::ProcessMesh& process_mesh) {
   if (out) {
+    if (out->dims().size() == -1 || out->dims().size() == 0) {
+      if (out->local_dims().size() != -1 && out->local_dims().size() != 0) {
+        out->unsafe_set_dims(out->local_dims());
+        VLOG(3)
+            << "DistTensor out has empty shape, use its local value's shape";
+      }
+    }
     // For inplace output, we also need to set replicated dist attr
     auto dist_attr =
         phi::distributed::TensorDistAttr(common::vectorize(out->dims()));
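The guard matters because the replicated `TensorDistAttr` below is built from `out->dims()`: if the global shape is still uninitialized (rank `-1`) or empty while the local tensor already holds a real shape, the dist attr would be derived from an empty shape. A minimal standalone sketch of the fallback pattern — `Shape` and `EnsureGlobalDims` are illustrative stand-ins, not phi's `DDim`/`DistTensor` API:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative stand-in for a tensor shape; not a phi type. In phi,
// dims().size() can be -1 for an uninitialized shape, which this struct
// models with an explicit rank field.
struct Shape {
  int64_t rank = -1;  // -1: not initialized yet
  std::vector<int64_t> d;
  int64_t size() const { return rank; }
};

// Mirrors the added guard: adopt the local shape only when the global
// shape is uninitialized/empty and the local shape is not.
void EnsureGlobalDims(Shape& global, const Shape& local) {
  if (global.size() == -1 || global.size() == 0) {
    if (local.size() != -1 && local.size() != 0) {
      global = local;  // analogous to out->unsafe_set_dims(out->local_dims())
    }
  }
}

int main() {
  Shape global;            // empty: rank -1
  Shape local{2, {4, 8}};  // local value already shaped [4, 8]
  EnsureGlobalDims(global, local);
  std::cout << "global rank: " << global.size() << "\n";  // global rank: 2
}
```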
2 changes: 1 addition & 1 deletion paddle/phi/api/lib/data_transform.cc
@@ -779,7 +779,7 @@ ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
   if (tensor_in) {
     phi::distributed::DistTensor* dist_tensor =
         static_cast<phi::distributed::DistTensor*>(tensor_in.get());
-    VLOG(4) << "ReshardIsNeededWithPartial"
+    VLOG(4) << "ReshardIsNeededWithPartial "
             << ReshardIsNeededWithPartial(dist_tensor->dist_attr(),
                                           dist_attr);
     if (ReshardIsNeededWithPartial(dist_tensor->dist_attr(), dist_attr)) {
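The entire change here is one trailing space: without it, the boolean that follows fuses with the label and the log line reads `ReshardIsNeededWithPartial1`. A tiny sketch of the effect, with `std::cout` standing in for `VLOG(4)`:

```cpp
#include <iostream>

int main() {
  bool reshard_needed = true;  // stands in for ReshardIsNeededWithPartial(...)

  // Before the fix: label and value fuse into "ReshardIsNeededWithPartial1".
  std::cout << "ReshardIsNeededWithPartial" << reshard_needed << "\n";

  // After the fix: "ReshardIsNeededWithPartial 1" stays readable.
  std::cout << "ReshardIsNeededWithPartial " << reshard_needed << "\n";
}
```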
6 changes: 5 additions & 1 deletion paddle/phi/kernels/impl/einsum_impl.h
@@ -519,8 +519,12 @@ DenseTensor PerformContraction(
                                   label2type);
       trans_t = PerformTranspose<T, Context>(
           dev_ctx, reduct_t, perm, reordered_all_labels, label2type);
-      if (cache[operand_idx] != nullptr)
+      if (cache[operand_idx] != nullptr) {
         cache[operand_idx]->ShareBufferWith(trans_t);
+        cache[operand_idx]->Resize(trans_t.dims());
+        VLOG(5) << "Set dims of cache[" << operand_idx
+                << "]: " << trans_t.dims();
+      }
     }
     auto mul_dims = GetShapeByType<int>(
         all_labels, label2type, perm, label2shape, {LabelType::Batch});
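Two things happen in this hunk: the cache entry's dims are resized to match the buffer it now shares (presumably because `ShareBufferWith` alone does not update metadata, hence the `VLOG(5)` noting the new dims), and the brace-less `if` gains braces so the appended statements stay under the null check. The braces are the subtle part; a sketch of the pitfall they avoid, with a plain pointer standing in for the cache tensor:

```cpp
#include <iostream>

int main() {
  int* cache = nullptr;

  // Brace-less if: only the first statement is guarded. Appending a second
  // statement here (as the commit appends Resize and VLOG) would run it
  // unconditionally and dereference a null pointer.
  if (cache != nullptr)
    *cache = 1;

  // Braced form, matching the commit: everything added later stays guarded.
  if (cache != nullptr) {
    *cache = 1;
    std::cout << "cache updated\n";
  }
  std::cout << "done\n";  // reached safely either way
}
```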