Skip to content

Commit

Permalink
[Prim][PIR] Add back_decomp and support dynamic shape for kron_grad (#…
Browse files Browse the repository at this point in the history
…69108)

* add back_decomp for kron_grad (unfinished)

* fixed codes

* modified code and added test files

* modified test file

* modified code

* modified code

* modified kron_grad() and add some test files

* modified code of x_grad(kron_grad) in details.h

* modified details.h and add dynamic shape for kron_grad

* modified details.h
  • Loading branch information
eggman-1024 authored Nov 14, 2024
1 parent 30c0291 commit be5f0b9
Show file tree
Hide file tree
Showing 5 changed files with 395 additions and 7 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/primitive/codegen/decomp_vjp_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
'fmax_grad',
'fmin_grad',
'dot_grad',
'kron_grad',
]

OTHER_PRIM_VJP_OPS = [
Expand Down
253 changes: 253 additions & 0 deletions paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
Original file line number Diff line number Diff line change
Expand Up @@ -2916,6 +2916,259 @@ void argsort_grad(const Tensor& indices,
}
}

// Backward rule for `kron`: computes the gradients with respect to `x` and
// `y` from `out_grad`.
//
// Both gradients follow the same outline:
//   1. Bring x and y to the same rank by prepending size-1 dimensions to the
//      lower-rank operand.
//   2. Expand the *other* operand (tile y for x_grad, repeat-interleave x for
//      y_grad) and multiply it element-wise with out_grad.
//   3. Rewrite trailing axes of the product as (x_dim, out_dim / x_dim) pairs
//      and sum over one member of every pair. The static y_grad path hands
//      this reduction to tile_grad instead of doing it inline.
// Each gradient has two code paths: one taken when has_dynamic_shape() is true
// for x or y (shape arithmetic is done on tensors), and one for the remaining
// case (shape arithmetic is done on std::vector<int64_t>).
//
// @param x         First operand of the forward op.
// @param y         Second operand of the forward op.
// @param out_grad  Gradient with respect to the forward output.
// @param x_grad    Receives the gradient for x via set_output; the x branch is
//                  skipped entirely when this pointer is null.
// @param y_grad    Receives the gradient for y via set_output; the y branch is
//                  skipped entirely when this pointer is null.
template <typename T>
void kron_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
Tensor* x_grad,
Tensor* y_grad) {
if (x_grad) {
// Axis index 0 as a tensor; only consumed by the dynamic-shape path below.
Tensor zero = full<T>({1}, 0, DataType::INT32);
Tensor x_grad_tmp;
if (has_dynamic_shape(x.shape()) || has_dynamic_shape(y.shape())) {
Tensor x_ = x;
Tensor y_ = y;
auto diff = std::abs(static_cast<int>(x.dims().size()) -
static_cast<int>(y.dims().size()));
// Unsqueeze the lower-rank operand at axis 0 until both ranks match.
while (diff--) {
if (x_.dims().size() > y_.dims().size()) {
y_ = backend::unsqueeze<T>(y_, zero);
} else {
x_ = backend::unsqueeze<T>(x_, zero);
}
}

// Collect one shape entry of x_ per axis as a tensor (get_slice is assumed
// to extract the i-th entry of the shape tensor — TODO confirm).
std::vector<Tensor> x_shape_vec;
for (int64_t i = 0; i < x_.dims().size(); ++i) {
auto x_shape_slice = get_slice<T>(shape<T>(x_), i);
x_shape_vec.push_back(x_shape_slice);
}

// Tile y_ using x_'s shape as the repeat counts.
auto y_tile = backend::tile<T>(y_, shape<T>(x_));

auto out_grad_tmp = y_tile * out_grad;

// Per-axis shape entries of out_grad; consumed from the back in the loop below.
std::vector<Tensor> out_grad_shape_vec;
for (int64_t i = 0; i < out_grad.dims().size(); ++i) {
auto out_grad_shape_slice = get_slice<T>(shape<T>(out_grad), i);
out_grad_shape_vec.push_back(out_grad_shape_slice);
}
// Nothing to reduce when x_ has rank 0.
if (x_shape_vec.size() != 0) {
// Each pass handles as many trailing axes as the size limit below allows,
// then reshapes and sums; passes repeat until every axis of x_ is consumed.
while (true) {
// Start from the current shape of the partially reduced product.
std::vector<Tensor> expand_shape_vec;
for (int64_t i = 0; i < out_grad_tmp.dims().size(); ++i) {
auto expand_shape = get_slice<T>(shape<T>(out_grad_tmp), i);
expand_shape_vec.push_back(expand_shape);
}
int num_reduce = 0;
// Stop growing the target shape once it holds more than 8 entries
// (presumably a maximum-rank constraint of the reshape — TODO confirm).
while (x_shape_vec.size() != 0 && expand_shape_vec.size() <= 8) {
Tensor repeat = x_shape_vec.back();
auto orig_size =
cast<T>(out_grad_shape_vec.back() / repeat, DataType::INT32);
size_t out_grad_last_index = out_grad_shape_vec.size() - 1;
// Replace the last unprocessed out_grad entry with x_'s size and insert
// the quotient right after it, forming an (x_dim, quotient) pair.
expand_shape_vec[out_grad_last_index] = repeat;
expand_shape_vec.insert(
expand_shape_vec.begin() + out_grad_shape_vec.size(),
orig_size);

x_shape_vec.pop_back();
out_grad_shape_vec.pop_back();
++num_reduce;
}

// The quotient entries sit at every second position, starting one past
// the number of still-unprocessed out_grad axes.
int axis = static_cast<int>(out_grad_shape_vec.size()) + 1;
std::vector<Tensor> reduce_axes_vec;
for (int i = 0; i < num_reduce; ++i) {
reduce_axes_vec.push_back(full<T>({1}, axis, DataType::INT32));
axis += 2;
}

out_grad_tmp =
backend::reshape<T>(out_grad_tmp, concat<T>(expand_shape_vec));
out_grad_tmp =
backend::sum<T>(out_grad_tmp, concat<T>(reduce_axes_vec));
if (x_shape_vec.size() == 0) {
break;
}
}
}
// Restore the original (un-padded) shape of x.
x_grad_tmp = backend::reshape<T>(out_grad_tmp, shape<T>(x));
} else {
auto x_shape = x.shape();
auto y_shape = y.shape();

auto diff = std::abs(static_cast<int>(x_shape.size()) -
static_cast<int>(y_shape.size()));
// Prepend 1s to the shorter shape until both have the same length.
for (int i = 0; i < diff; i++) {
if (x_shape.size() > y_shape.size()) {
y_shape.insert(y_shape.begin(), 1);
} else {
x_shape.insert(x_shape.begin(), 1);
}
}

auto x_ = reshape<T>(x, x_shape);
auto y_ = reshape<T>(y, y_shape);

// Tile y_ using x_'s dimensions as the repeat counts.
std::vector<int64_t> x_dim = common::vectorize<int64_t>(x_.dims());
auto y_tile = tile<T>(y_, x_dim);

auto out_grad_tmp = y_tile * out_grad;

// Working copy of the product's shape; consumed from the back below.
std::vector<int64_t> out_grad_shape(out_grad_tmp.shape());

// Nothing to reduce when x_ has rank 0.
if (x_dim.size() != 0) {
// Same multi-pass reshape-and-sum scheme as the dynamic path, on integers.
while (true) {
std::vector<int64_t> expand_shape(out_grad_tmp.shape());

int num_reduce = 0;
// Stop growing the target shape once it holds more than 8 entries
// (presumably a maximum-rank constraint of the reshape — TODO confirm).
while (x_dim.size() != 0 && expand_shape.size() <= 8) {
int64_t repeat = x_dim.back();
int64_t orig_size = out_grad_shape.back() / repeat;
size_t out_grad_last_index = out_grad_shape.size() - 1;

// Turn the last unprocessed axis into an (x_dim, quotient) pair.
expand_shape[out_grad_last_index] = repeat;
expand_shape.insert(
expand_shape.begin() + out_grad_shape.size(), 1, orig_size);

x_dim.pop_back();
out_grad_shape.pop_back();
++num_reduce;
}

// Reduce over the quotient member of every pair created in this pass.
int64_t axis = static_cast<int64_t>(out_grad_shape.size()) + 1;
std::vector<int64_t> reduce_axes;
for (int i = 0; i < num_reduce; ++i) {
reduce_axes.push_back(axis);
axis += 2;
}

out_grad_tmp = reshape<T>(out_grad_tmp, expand_shape);
out_grad_tmp = sum<T>(out_grad_tmp, reduce_axes);

if (x_dim.size() == 0) {
break;
}
}
}
// Restore the original (un-padded) shape of x.
x_grad_tmp = reshape<T>(out_grad_tmp, x.shape());
}
set_output<T>(x_grad_tmp, x_grad);
}
if (y_grad) {
// Axis index 0 as a tensor; only consumed by the dynamic-shape path below.
Tensor zero = full<T>({1}, 0, DataType::INT32);
// NOTE(review): ConverToMT presumably promotes to a wider compute dtype and
// ConverToOrig casts back — confirm against the helper definitions. The
// x_grad branch above does not apply this conversion.
auto x_cast = ConverToMT<T>(x);
auto out_grad_cast = ConverToMT<T>(out_grad);
Tensor out_grad_tmp;
Tensor y_grad_tmp;

if (has_dynamic_shape(x_cast.shape()) || has_dynamic_shape(y.shape())) {
Tensor x_ = x_cast;
Tensor y_ = y;
auto diff = std::abs(static_cast<int>(x_cast.dims().size()) -
static_cast<int>(y.dims().size()));
// Unsqueeze the lower-rank operand at axis 0 until both ranks match.
while (diff--) {
if (x_.dims().size() > y_.dims().size()) {
y_ = backend::unsqueeze<T>(y_, zero);
} else {
x_ = backend::unsqueeze<T>(x_, zero);
}
}

// Snapshot x_'s per-axis shape entries before x_ is expanded below.
std::vector<Tensor> x_shape_vec;
for (int64_t i = 0; i < x_.dims().size(); ++i) {
auto x_shape_slice = get_slice<T>(shape<T>(x_), i);
x_shape_vec.push_back(x_shape_slice);
}

// Along every axis i, repeat-interleave x_ with a repeats tensor built by
// tiling y_'s size on that axis by x_'s size on that axis. Axis i of x_ has
// not been expanded yet when its size is read in iteration i.
for (int64_t i = 0; i < x_.dims().size(); ++i) {
auto y_shape_slice = get_slice<T>(shape<T>(y_), i);
auto x_shape_slice = get_slice<T>(shape<T>(x_), i);
auto y_shape_tile = backend::tile<T>(y_shape_slice, x_shape_slice);
x_ = backend::repeat_interleave_with_tensor_index<T>(
x_, y_shape_tile, i);
}
out_grad_tmp = out_grad_cast * x_;

// Per-axis shape entries of out_grad; consumed from the back in the loop below.
std::vector<Tensor> out_grad_shape_vec;
for (int64_t i = 0; i < out_grad.dims().size(); ++i) {
auto out_grad_shape_slice = get_slice<T>(shape<T>(out_grad_cast), i);
out_grad_shape_vec.push_back(out_grad_shape_slice);
}

// Nothing to reduce when x_ has rank 0.
if (x_shape_vec.size() != 0) {
// Same multi-pass reshape-and-sum scheme as the x_grad branch.
while (true) {
std::vector<Tensor> expand_shape_vec;
for (int64_t i = 0; i < out_grad_tmp.dims().size(); ++i) {
auto expand_shape = get_slice<T>(shape<T>(out_grad_tmp), i);
expand_shape_vec.push_back(expand_shape);
}
int num_reduce = 0;
// Stop growing the target shape once it holds more than 8 entries
// (presumably a maximum-rank constraint of the reshape — TODO confirm).
while (x_shape_vec.size() != 0 && expand_shape_vec.size() <= 8) {
auto repeat = x_shape_vec.back();
auto orig_size =
cast<T>(out_grad_shape_vec.back() / repeat, DataType::INT32);
size_t out_grad_last_index = out_grad_shape_vec.size() - 1;
// Turn the last unprocessed axis into an (x_dim, quotient) pair.
expand_shape_vec[out_grad_last_index] = repeat;
expand_shape_vec.insert(
expand_shape_vec.begin() + out_grad_shape_vec.size(),
orig_size);

x_shape_vec.pop_back();
out_grad_shape_vec.pop_back();
++num_reduce;
}
// Unlike the x_grad branch there is no "+ 1" here: the reduction runs over
// the x_dim member of every pair, leaving the quotient axes in place.
int axis = static_cast<int>(out_grad_shape_vec.size());
std::vector<Tensor> reduce_axes_vec;
for (int i = 0; i < num_reduce; ++i) {
reduce_axes_vec.push_back(full<T>({1}, axis, DataType::INT32));
axis += 2;
}
out_grad_tmp =
backend::reshape<T>(out_grad_tmp, concat<T>(expand_shape_vec));
out_grad_tmp =
backend::sum<T>(out_grad_tmp, concat<T>(reduce_axes_vec));

if (x_shape_vec.size() == 0) {
break;
}
}
}
// Convert back using out_grad's dtype and restore y's original shape.
y_grad_tmp = backend::reshape<T>(
ConverToOrig<T>(out_grad_tmp, out_grad.dtype()), shape<T>(y));
} else {
auto x_shape = x_cast.shape();
auto y_shape = y.shape();

auto diff = std::abs(static_cast<int>(x_shape.size()) -
static_cast<int>(y_shape.size()));
// Prepend 1s to the shorter shape until both have the same length.
for (int i = 0; i < diff; i++) {
if (x_shape.size() > y_shape.size()) {
y_shape.insert(y_shape.begin(), 1);
} else {
x_shape.insert(x_shape.begin(), 1);
}
}

auto x_ = reshape<T>(x_cast, x_shape);
auto y_ = reshape<T>(y, y_shape);

// Repeat-interleave x_ along every axis by y's size on that axis.
std::vector<int64_t> x_dim = common::vectorize<int64_t>(x_.dims());
for (size_t i = 0; i < x_dim.size(); ++i) {
x_ = repeat_interleave<T>(x_, y_shape[i], i);
}
out_grad_tmp = out_grad_cast * x_;

// Delegate the reduction to tile_grad, passing x_'s original dimensions as
// the repeat counts; the result is written into the local y_grad_tmp.
tile_grad<T>(y_, out_grad_tmp, IntArray(x_dim), &y_grad_tmp);
// Convert back using y's dtype and restore y's original shape.
y_grad_tmp =
reshape<T>(ConverToOrig<T>(y_grad_tmp, y.dtype()), y.shape());
}
set_output<T>(y_grad_tmp, y_grad);
}
}

} // namespace details
} // namespace primitive
} // namespace paddle
1 change: 1 addition & 0 deletions python/paddle/autograd/backward_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"pd_op.gather_nd",
"pd_op.gelu",
"pd_op.hardswish",
"pd_op.kron",
"pd_op.kthvalue",
"pd_op.layer_norm",
"pd_op.leaky_relu",
Expand Down
Loading

0 comments on commit be5f0b9

Please sign in to comment.