Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoParallel] Support the backward rule of elementwise binary #57813

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions paddle/phi/api/lib/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -652,18 +652,9 @@ ReshardApiInputToReplicatedKernelInput(
if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) {
VLOG(6) << "ApiIn to Replicated KernelIn - "
<< ReshardDebugInfo(*dist_tensor, dist_attr);
if (dist_tensor->initialized()) {
auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
dist_attr);
return func->Eval(dev_ctx, *dist_tensor, dist_attr);
} else {
// when no tensor data need to be reshard, we still need to set correct
// replicated dist attr and local dims for output
dist_tensor->unsafe_set_dist_attr(dist_attr);
auto dense_tensor_meta = dist_tensor->value().meta();
dense_tensor_meta.dims = dist_tensor->dims();
dist_tensor->unsafe_mutable_value()->set_meta(dense_tensor_meta);
}
auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
dist_attr);
return func->Eval(dev_ctx, *dist_tensor, dist_attr);
}
return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
spmd_rule : ElementwiseBinaryGradInferSpmd
kernel :
func : add_grad
no_need_buffer : x, y
Expand Down Expand Up @@ -680,6 +681,7 @@
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
spmd_rule : ElementwiseBinaryGradInferSpmd
kernel :
func : subtract_grad
no_need_buffer : x, y
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
spmd_rule : ElementwiseBinaryInferSpmd
kernel :
func : add
inplace : (x -> out)
Expand Down Expand Up @@ -1003,6 +1004,7 @@
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
spmd_rule : ElementwiseBinaryInferSpmd
kernel :
func : subtract
inplace : (x -> out)
Expand Down
67 changes: 67 additions & 0 deletions paddle/phi/infermeta/spmd_rules/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,5 +309,72 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x,
return {{x_dist_attr_dst, y_dist_attr_dst}, {out_dist_attr}};
}

SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x,
                                        const DistMetaTensor& y,
                                        const DistMetaTensor& out_grad,
                                        int64_t axis) {
  // Backward SPMD rule of elementwise binary ops: every input dist_attr and
  // grad dist_attr starts out as a copy of out_grad's dist_attr.
  // NOTE: `axis` is not used here; broadcast alignment is derived purely from
  // the rank difference between out_grad and each input.
  TensorDistAttr x_dist_attr = out_grad.dist_attr();
  TensorDistAttr y_dist_attr = out_grad.dist_attr();
  TensorDistAttr x_grad_dist_attr = out_grad.dist_attr();
  TensorDistAttr y_grad_dist_attr = out_grad.dist_attr();

  PADDLE_ENFORCE_GE(
      out_grad.dims().size(),
      x.dims().size(),
      phi::errors::InvalidArgument("If being broadcast, the rank of out_grad "
                                   "must be larger than or equal to that of "
                                   "the inputs. But we get the rank of output "
                                   "as [%d] and the rank of input as [%d].",
                                   out_grad.dims().size(),
                                   x.dims().size()));

  PADDLE_ENFORCE_GE(
      out_grad.dims().size(),
      y.dims().size(),
      phi::errors::InvalidArgument("If being broadcast, the rank of out_grad "
                                   "must be larger than or equal to that of "
                                   "the inputs. But we get the rank of output "
                                   "as [%d] and the rank of input as [%d].",
                                   out_grad.dims().size(),
                                   y.dims().size()));

  // The backward rule of elementwise follows the principle: the dist_attr
  // of an input should equal that of out_grad.
  // Caution the special case when the inputs calculate together with
  // different shapes: one input is broadcast to the same shape as the other
  // first. When doing backward, the grad of the broadcast input is in
  // partial status (the broadcast axes were reduced), which needs a
  // communication to get the right result.
  // NOTE(review): only rank-expanding broadcast is handled here; same-rank
  // broadcasting over size-1 axes is not special-cased — confirm intended.
  const auto& out_dims_mapping = out_grad.dist_attr().dims_mapping();

  // Shared handling for one (possibly broadcast) input: drop the leading
  // mesh-dim mappings that correspond to the broadcast axes, and mark the
  // grad partial on every mesh dim a broadcast axis was sharded along.
  auto align_broadcast_input = [&](const DistMetaTensor& input,
                                   TensorDistAttr* in_attr,
                                   TensorDistAttr* grad_attr) {
    if (input.dims() == out_grad.dims()) {
      return;
    }
    int64_t diff = out_grad.dims().size() - input.dims().size();
    auto dims_mapping = in_attr->dims_mapping();
    dims_mapping.erase(dims_mapping.begin(), dims_mapping.begin() + diff);
    in_attr->set_dims_mapping(dims_mapping);
    grad_attr->set_dims_mapping(dims_mapping);
    for (int64_t i = 0; i < diff; ++i) {
      if (out_dims_mapping[i] != -1) {
        grad_attr->set_partial_status(
            std::vector<int64_t>{out_dims_mapping[i]});
      }
    }
  };

  align_broadcast_input(x, &x_dist_attr, &x_grad_dist_attr);
  align_broadcast_input(y, &y_dist_attr, &y_grad_dist_attr);

  return {{x_dist_attr, y_dist_attr, out_grad.dist_attr()},
          {x_grad_dist_attr, y_grad_dist_attr}};
}

} // namespace distributed
} // namespace phi
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/spmd_rules/elementwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,10 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& y,
const DistMetaTensor& out);

SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& y,
const DistMetaTensor& out_grad,
int64_t axis);

} // namespace distributed
} // namespace phi
4 changes: 3 additions & 1 deletion paddle/phi/kernels/all_gather_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ void AllGather(const Context& dev_ctx,
MetaTensor* out_meta_ptr = &out_meta;

AllGatherInferMeta(phi::MetaTensor(x), nranks, out_meta_ptr);
AllGatherKernel<T, Context>(dev_ctx, x, nranks, out);
if (x.initialized()) {
AllGatherKernel<T, Context>(dev_ctx, x, nranks, out);
}
}

} // namespace phi
4 changes: 3 additions & 1 deletion paddle/phi/kernels/all_reduce_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ void AllReduce(const Context& dev_ctx,
MetaTensor* out_meta_ptr = &out_meta;

AllReduceInferMeta(phi::MetaTensor(x), out_meta_ptr);
AllReduceKernel<T, Context>(dev_ctx, x, reduce_type, out);
if (x.initialized()) {
AllReduceKernel<T, Context>(dev_ctx, x, reduce_type, out);
}
}

} // namespace phi
4 changes: 3 additions & 1 deletion paddle/phi/kernels/all_to_all_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ void AllToAll(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
MetaTensor* out_meta_ptr = &out_meta;

AllToAllInferMeta(phi::MetaTensor(x), out_meta_ptr);
AllToAllKernel<T, Context>(dev_ctx, x, out);
if (x.initialized()) {
AllToAllKernel<T, Context>(dev_ctx, x, out);
}
}

} // namespace phi
5 changes: 4 additions & 1 deletion paddle/phi/kernels/concat_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ void Concat(const Context& dev_ctx,

MetaTensor meta_out(dense_out);
ConcatInferMeta(meta_x_ptr, axis.to<int>(), &meta_out);
ConcatKernel<T, Context>(dev_ctx, x, axis, dense_out);

if (x[0]->initialized()) {
ConcatKernel<T, Context>(dev_ctx, x, axis, dense_out);
}
}

template <typename T, typename Context>
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/svd_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ void SvdKernel(const Context& dev_ctx,
/*Create Tensors and output, set the dim ...*/
auto numel = X.numel();
DenseTensor trans_x = ::phi::TransposeLast2Dim<T>(dev_ctx, X);
auto* x_data = trans_x.data<T>();
auto x_dims = X.dims();
int rows = static_cast<int>(x_dims[x_dims.size() - 2]);
int cols = static_cast<int>(x_dims[x_dims.size() - 1]);
Expand All @@ -113,6 +112,7 @@ void SvdKernel(const Context& dev_ctx,
0,
cols,
errors::InvalidArgument("The col of Input(X) should be greater than 0."));
auto* x_data = trans_x.data<T>();
int batches = static_cast<int>(numel / (rows * cols));
auto* U_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(U);
auto* VH_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(VH);
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/reshape_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ void Reshape(const Context& dev_ctx,
DenseTensor* out) {
MetaTensor meta_out(out);
InferMetaFromVecValue(x, shape, &meta_out);
ReshapeInferKernel<Context>(dev_ctx, x, IntArray(shape), out);
if (x.initialized()) {
ReshapeInferKernel<Context>(dev_ctx, x, IntArray(shape), out);
}
}

template <typename T, typename Context>
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/split_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ void Split(const Context& dev_ctx,
outs.push_back(&result->at(i));
}

SplitKernel<T, Context>(dev_ctx, x, sections, axis, outs);
if (x.initialized()) {
SplitKernel<T, Context>(dev_ctx, x, sections, axis, outs);
}
}

template <typename T, typename Context>
Expand Down
5 changes: 4 additions & 1 deletion paddle/phi/kernels/transpose_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/empty_kernel.h"

namespace phi {

template <typename T, typename Context>
Expand All @@ -43,7 +44,9 @@ void Transpose(const Context& dev_ctx,

// do not call TransposeStridedKernel, because some other kernels call
// Transpose directly
TransposeKernel<T, Context>(dev_ctx, x, axis, dense_out);
if (x.initialized()) {
TransposeKernel<T, Context>(dev_ctx, x, axis, dense_out);
}
}

template <typename T, typename Context>
Expand Down
147 changes: 147 additions & 0 deletions test/auto_parallel/semi_auto_parallel_for_elementwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
import paddle.distributed as dist


class TestElementwiseApiForSemiAutoParallel:
    """Checks that elementwise binary ops (add/subtract) produce the same
    forward result and gradients under semi-auto parallel sharding as in
    plain single-device execution, including broadcast cases."""

    def __init__(self):
        self._dtype = os.getenv("dtype")
        self._backend = os.getenv("backend")
        # NOTE(review): eval on an env var is risky; int(...) would be safer
        # if the seed is always an integer literal — confirm before changing.
        self._seed = eval(os.getenv("seed"))
        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

    def check_tensor_eq(self, a, b):
        """Assert two tensors are numerically close (rtol=1e-05)."""
        np1 = a.numpy()
        np2 = b.numpy()
        np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True)

    def test_binary_body(
        self, x_shape, y_shape, out_shape, x_specs, y_specs, binary_func
    ):
        """Run `binary_func` on sharded and unsharded inputs and compare
        outputs and gradients between the two executions."""
        paddle.seed(self._seed)
        np.random.seed(self._seed)

        x = paddle.randn(x_shape, self._dtype)
        y = paddle.randn(y_shape, self._dtype)
        x.stop_gradient = False
        y.stop_gradient = False

        x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs)
        y_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=y_specs)

        dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr)
        dist_y = dist.shard_tensor(y, dist_attr=y_dist_attr)
        dist_x.stop_gradient = False
        dist_y.stop_gradient = False

        dist_out = binary_func(dist_x, dist_y)
        out = binary_func(x, y)
        self.check_tensor_eq(out, dist_out)

        dist_out.backward()
        out.backward()
        self.check_tensor_eq(x.grad, dist_x.grad)
        self.check_tensor_eq(y.grad, dist_y.grad)

    def test_add_x_shard(self):
        self.test_binary_body(
            x_shape=[16, 32],
            y_shape=[16, 32],
            out_shape=[16, 32],
            x_specs=['x', None],
            y_specs=[None, None],
            binary_func=paddle.add,
        )

    def test_sub_x_shard(self):
        self.test_binary_body(
            x_shape=[16, 32],
            y_shape=[16, 32],
            out_shape=[16, 32],
            x_specs=['x', None],
            y_specs=[None, None],
            binary_func=paddle.subtract,
        )

    def test_add_x_shard_broadcast(self):
        self.test_binary_body(
            x_shape=[16, 32],
            y_shape=[2, 16, 32],
            out_shape=[2, 16, 32],
            x_specs=['x', None],
            y_specs=[None, None, None],
            binary_func=paddle.add,
        )

    def test_add_x_y_shard(self):
        if self._backend == "cpu":
            return

        self.test_binary_body(
            x_shape=[16, 32],
            y_shape=[16, 32],
            out_shape=[16, 32],
            x_specs=['x', None],
            y_specs=[None, 'x'],
            binary_func=paddle.add,
        )

    def test_add_x_y_shard_broadcast(self):
        if self._backend == "cpu":
            return

        self.test_binary_body(
            x_shape=[4, 16, 32],
            y_shape=[16, 32],
            out_shape=[4, 16, 32],
            x_specs=['x', None, None],
            y_specs=[None, None],
            binary_func=paddle.add,
        )

    def test_sub_x_y_shard_broadcast(self):
        if self._backend == "cpu":
            return

        self.test_binary_body(
            x_shape=[4, 16, 32],
            y_shape=[16, 32],
            out_shape=[4, 16, 32],
            x_specs=['x', None, None],
            y_specs=[None, None],
            binary_func=paddle.subtract,
        )

    def run_test_case(self):
        """Entry point: select the device, then run every test method.

        BUGFIX: test_sub_x_shard and test_sub_x_y_shard_broadcast were
        defined but never invoked; they are now part of the run.
        """
        if self._backend == "cpu":
            paddle.set_device("cpu")
        elif self._backend == "gpu":
            paddle.set_device("gpu:" + str(dist.get_rank()))
        else:
            raise ValueError("Only support cpu or gpu backend.")

        self.test_add_x_shard()
        self.test_sub_x_shard()
        self.test_add_x_shard_broadcast()
        self.test_add_x_y_shard()
        self.test_add_x_y_shard_broadcast()
        self.test_sub_x_y_shard_broadcast()


if __name__ == '__main__':
    # Instantiate the suite and execute all configured cases.
    runner = TestElementwiseApiForSemiAutoParallel()
    runner.run_test_case()
Loading