Skip to content

Commit

Permalink
modified details.h
Browse files Browse the repository at this point in the history
  • Loading branch information
eggman-1024 committed Nov 13, 2024
1 parent 93da659 commit ac75583
Showing 1 changed file with 4 additions and 38 deletions.
42 changes: 4 additions & 38 deletions paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
Original file line number Diff line number Diff line change
Expand Up @@ -3149,51 +3149,17 @@ void kron_grad(const Tensor& x,
}

auto x_ = reshape<T>(x_cast, x_shape);
auto y_ = reshape<T>(y, y_shape);

std::vector<int64_t> x_dim = common::vectorize<int64_t>(x_.dims());
for (size_t i = 0; i < x_dim.size(); ++i) {
x_ = repeat_interleave<T>(x_, y_shape[i], i);
}
out_grad_tmp = out_grad_cast * x_;

std::vector<int64_t> out_grad_shape(out_grad_cast.shape());

if (x_dim.size() != 0) {
while (true) {
std::vector<int64_t> expand_shape(out_grad_tmp.shape());

int num_reduce = 0;
while (x_dim.size() != 0 && expand_shape.size() <= 8) {
int64_t repeat = x_dim.back();
int64_t orig_size = out_grad_shape.back() / repeat;
size_t out_grad_last_index = out_grad_shape.size() - 1;

expand_shape[out_grad_last_index] = repeat;
expand_shape.insert(
expand_shape.begin() + out_grad_shape.size(), 1, orig_size);

x_dim.pop_back();
out_grad_shape.pop_back();
++num_reduce;
}

int64_t axis = static_cast<int64_t>(out_grad_shape.size());
std::vector<int64_t> reduce_axes;
for (int i = 0; i < num_reduce; ++i) {
reduce_axes.push_back(axis);
axis += 2;
}

out_grad_tmp = reshape<T>(out_grad_tmp, expand_shape);
out_grad_tmp = sum<T>(out_grad_tmp, reduce_axes);

if (x_dim.size() == 0) {
break;
}
}
}
y_grad_tmp = reshape<T>(ConverToOrig<T>(out_grad_tmp, out_grad.dtype()),
y.shape());
tile_grad<T>(y_, out_grad_tmp, IntArray(x_dim), &y_grad_tmp);
y_grad_tmp =
reshape<T>(ConverToOrig<T>(y_grad_tmp, y.dtype()), y.shape());
}
set_output<T>(y_grad_tmp, y_grad);
}
Expand Down

0 comments on commit ac75583

Please sign in to comment.