Commit

fixed codes
eggman-1024 committed Nov 2, 2024
1 parent 7fe9309 commit f8034dc
Showing 2 changed files with 155 additions and 91 deletions.
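For context: kron(x, y) scales a full copy of y by each element of x, so both gradients can be formed with block reshapes rather than the per-element slicing the old composite rule used. A minimal NumPy sketch of the 2-D case (the function name and shapes are illustrative, not Paddle APIs):

import numpy as np

def kron_vjp_2d(x, y, out_grad):
    """Illustrative VJP of np.kron for 2-D inputs (not the Paddle kernel)."""
    m, n = x.shape
    p, q = y.shape
    # View out_grad as an (m, n) grid of (p, q) blocks.
    blocks = out_grad.reshape(m, p, n, q).transpose(0, 2, 1, 3)   # (m, n, p, q)
    x_grad = (blocks * y).sum(axis=(2, 3))                        # reduce each block against y
    y_grad = (blocks * x[:, :, None, None]).sum(axis=(0, 1))      # weight each block by x
    return x_grad, y_grad

x, y = np.random.rand(3, 4), np.random.rand(2, 5)
gx, gy = kron_vjp_2d(x, y, np.random.rand(6, 20))
assert gx.shape == x.shape and gy.shape == y.shape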
195 changes: 110 additions & 85 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -3109,111 +3109,136 @@ void kron_grad(const Tensor& x,
const Tensor& out_grad,
Tensor* x_grad,
Tensor* y_grad) {
if (x_grad && y_grad) {
// auto vector_x = split<T>(backend::flatten<T>(x, 0, -1), {x.numel()}, -1);
std::vector<Tensor> vector_x;
auto flat_x = backend::flatten<T>(x, 0, -1);
for (int i = 0; i < x.numel(); i++) {
vector_x.push_back(get_slice_vec<T>(flat_x, i, i + 1)); // element-wise slicing
}
// auto vector_y = split<T>(backend::flatten<T>(y, 0, -1), {y.numel()}, 0);
std::vector<Tensor> vector_y;
auto flat_y = backend::flatten<T>(y, 0, -1);
for (int i = 0; i < y.numel(); i++) {
vector_y.push_back(get_slice_vec<T>(flat_y, i, i + 1)); // element-wise slicing
}
// auto vector_out_grad = split<T>(backend::flatten<T>(out_grad, 0, -1),
// {out_grad.numel()}, -1);

std::vector<Tensor> vector_out_grad;
auto flat_out_grad = backend::flatten<T>(out_grad, 0, -1);
for (int i = 0; i < out_grad.numel(); i++) {
vector_out_grad.push_back(
get_slice_vec<T>(flat_out_grad, i, i + 1)); // element-wise slicing
}
if (x_grad) {
auto x_shape = x.shape();
auto y_shape = y.shape();
auto out_grad_shape = out_grad.shape();

Tensor x_ = x;
Tensor y_ = y;
int64_t x_dim_size = x.dims().size();
int64_t y_dim_size = y.dims().size();
auto diff = std::abs(x_dim_size - y_dim_size);
if (diff != 0) {
std::vector<int64_t> range_diff(diff);
for (int64_t i = 0; i < diff; i++) {
range_diff[i] = i;
}

if (x_dim_size < y_dim_size) {
x_ = unsqueeze<T>(x, IntArray(range_diff));
} else if (x_dim_size > y_dim_size) {
y_ = unsqueeze<T>(y, IntArray(range_diff));
auto diff = std::abs(static_cast<int>(x_shape.size()) -
static_cast<int>(y_shape.size()));
for (int i = 0; i < diff; i++) {
if (x_shape.size() > y_shape.size()) {
y_shape.insert(y_shape.begin(), 1);
} else {
x_shape.insert(x_shape.begin(), 1);
}
}
x_ = expand<T>(x, x_shape);
y_ = expand<T>(y, y_shape);

std::vector<Tensor> x_grad_data(x.numel(), full_scalar<T>(0.0, x.dtype()));
std::vector<Tensor> y_grad_data(y.numel(), full_scalar<T>(0.0, y.dtype()));
std::vector<int64_t> y_dim = common::vectorize<int64_t>(y_.dims());

// compute stride information
int64_t size = x_.shape().size();
auto x_stride = std::vector<int>(size, 0);
auto y_stride = std::vector<int>(size, 0);
auto out_grad_stride = std::vector<int>(size, 0);
// split out_grad into blocks
std::vector<int64_t> block_shape(out_grad_shape.size());
for (size_t i = 0; i < block_shape.size(); i++) {
block_shape[i] = out_grad_shape[i] / y_shape[i];
}

x_stride[size - 1] = 1;
y_stride[size - 1] = 1;
out_grad_stride[size - 1] = 1;
std::vector<int64_t> new_shape;
for (size_t i = 0; i < out_grad_shape.size(); i++) {
new_shape.push_back(block_shape[i]);
new_shape.push_back(y_shape[i]);
}
auto new_out_grad = reshape<T>(out_grad, new_shape);

for (int64_t i = size - 2; i >= 0; i--) {
x_stride[i] = x_stride[i + 1] * x_.shape()[i + 1];
y_stride[i] = y_stride[i + 1] * y_.shape()[i + 1];
out_grad_stride[i] = out_grad_stride[i + 1] * out_grad.shape()[i + 1];
// transpose out_grad
std::vector<size_t> permute_order;
for (size_t i = 0; i < out_grad_shape.size(); ++i) {
permute_order.push_back(i * 2);
}
for (size_t i = 0; i < out_grad_shape.size(); ++i) {
permute_order.push_back(i * 2 + 1);
}

// auto* out_grad_data = out_grad.data<float>();
// auto* y_data = y_.data<float>();
std::vector<Tensor> x_grad_tmp_;
std::vector<Tensor> y_grad_tmp_;
std::vector<int> permute_order_int(permute_order.begin(),
permute_order.end());
auto transposed_out_grad = transpose<T>(new_out_grad, permute_order_int);

// compute the gradient of the Kronecker product
for (int64_t i = 0; i < out_grad.numel(); i++) {
auto idx = i;
auto x_idx = 0;
auto y_idx = 0;
// unsqueeze
y_dim.insert(y_dim.begin(), -1);
auto blocks = reshape<T>(transposed_out_grad, y_dim);

for (int64_t j = 0; j < size; j++) {
auto pos_j = idx / out_grad_stride[j];
idx = idx % out_grad_stride[j];
auto pos_xj = pos_j / y_.shape()[j];
auto pos_yj = pos_j % y_.shape()[j];
x_idx += x_stride[j] * pos_xj;
y_idx += y_stride[j] * pos_yj;
}
if (x_grad_data[x_idx].defined() && vector_out_grad[i].defined() &&
vector_y[y_idx].defined()) {
x_grad_data[x_idx] =
x_grad_data[x_idx] + vector_out_grad[i] * vector_y[y_idx];
} else {
PADDLE_THROW(::common::errors::InvalidArgument(
"One of the tensors in x_grad_data, vector_out_grad or vector_y is "
"not defined."));
std::vector<size_t> range;
for (size_t i = 0; i <= y_shape.size(); i++) {
if (i != 0) {
range.push_back(i);
}
}

if (y_grad_data[y_idx].defined() && vector_out_grad[i].defined() &&
vector_x[x_idx].defined()) {
y_grad_data[y_idx] =
y_grad_data[y_idx] + vector_out_grad[i] * vector_x[x_idx];
std::vector<int> range_int(range.begin(), range.end());
auto sum_tensor = sum<T>(blocks * y_, IntArray(range_int));
auto x_grad_tmp = reshape<T>(sum_tensor, x.shape());

set_output<T>(x_grad_tmp, x_grad);
}

if (y_grad) {
auto x_cast = ConverToMT<T>(x);
auto out_grad_cast = ConverToMT<T>(out_grad);

auto x_shape = x_cast.shape();
auto y_shape = y.shape();

auto diff = std::abs(static_cast<int>(x_shape.size()) -
static_cast<int>(y_shape.size()));
for (int i = 0; i < diff; i++) {
if (x_shape.size() > y_shape.size()) {
y_shape.insert(y_shape.begin(), 1);
} else {
PADDLE_THROW(::common::errors::InvalidArgument(
"One of the tensors in y_grad_data, vector_out_grad or vector_x is "
"not defined."));
x_shape.insert(x_shape.begin(), 1);
}
}

Tensor x_grad_tmp = reshape<T>(concat<T>(x_grad_data), x.shape());
Tensor y_grad_tmp = reshape<T>(concat<T>(y_grad_data), y.shape());
auto x_ = expand<T>(x_cast, x_shape);

set_output<T>(x_grad_tmp, x_grad);
set_output<T>(y_grad_tmp, y_grad);
std::vector<int64_t> x_dim = common::vectorize<int64_t>(x_.dims());
std::vector<int64_t> out_grad_shape(out_grad_cast.shape());
Tensor out_grad_tmp = out_grad_cast;

if (x_dim.size() != 0) {
while (true) {
std::vector<int64_t> expand_shape(out_grad_tmp.shape());

int num_reduce = 0;
while (x_dim.size() != 0 && expand_shape.size() <= 8) {
int64_t repeat = x_dim.back();
int64_t orig_size = out_grad_shape.back() / repeat;
size_t out_grad_last_index = out_grad_shape.size() - 1;

expand_shape[out_grad_last_index] = repeat;
expand_shape.insert(
expand_shape.begin() + out_grad_shape.size(), 1, orig_size);

x_dim.pop_back();
out_grad_shape.pop_back();
++num_reduce;
}

int64_t axis = static_cast<int64_t>(out_grad_shape.size());
std::vector<int64_t> reduce_axes;
for (int i = 0; i < num_reduce; ++i) {
reduce_axes.push_back(axis);
axis += 2;
}

auto x_tmp_dim = common::vectorize<int64_t>(x_.dims());
for (size_t i = 0; i < x_tmp_dim.size(); ++i) {
x_ = repeat_interleave<T>(x_, y_shape[i], i);
}

out_grad_tmp = reshape<T>(out_grad_tmp * x_, expand_shape);
out_grad_tmp = sum<T>(out_grad_tmp, reduce_axes);

if (x_dim.size() == 0) {
break;
}
}
}
set_output<T>(
reshape<T>(ConverToOrig<T>(out_grad_tmp, out_grad.dtype()), y.shape()),
y_grad);
}
}
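The new x_grad branch above follows the same idea for inputs of arbitrary (rank-aligned) shape: split out_grad into blocks of y's shape, move the block indices to the front, and reduce each block against y. A rough NumPy sketch of that path (kron_x_grad is an illustrative name, not a Paddle helper):

import numpy as np

def kron_x_grad(x, y, out_grad):
    """Sketch of the block-reshape x_grad path; assumes x and y have equal rank."""
    y_shape = list(y.shape)
    r = len(y_shape)
    block_shape = [o // ys for o, ys in zip(out_grad.shape, y_shape)]  # equals x.shape
    # Interleave block dims with y dims: (b0, y0, b1, y1, ...).
    new_shape = [d for pair in zip(block_shape, y_shape) for d in pair]
    blocks = out_grad.reshape(new_shape)
    # Block indices first, y indices last.
    order = list(range(0, 2 * r, 2)) + list(range(1, 2 * r, 2))
    blocks = blocks.transpose(order)
    # Reduce every block against y, then restore x's shape.
    return (blocks * y).sum(axis=tuple(range(r, 2 * r))).reshape(x.shape)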

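The y_grad branch takes the complementary route: tile x up to out_grad's shape (the repeat_interleave loop), weight out_grad by it, and reduce the x positions away — done in chunks in the C++ code to stay under the rank limit. A rough single-pass NumPy equivalent, again only a sketch under the equal-rank assumption:

import numpy as np

def kron_y_grad(x, y, out_grad):
    """Sketch of the y_grad path; assumes x and y have equal rank."""
    x_shape, y_shape = list(x.shape), list(y.shape)
    x_tiled = x
    for axis, reps in enumerate(y_shape):
        x_tiled = np.repeat(x_tiled, reps, axis=axis)  # layout of kron(x, ones_like(y))
    weighted = out_grad * x_tiled
    # Split every dim into its (x, y) factors and sum over the x factors.
    new_shape = [d for pair in zip(x_shape, y_shape) for d in pair]
    x_axes = tuple(range(0, 2 * len(x_shape), 2))
    return weighted.reshape(new_shape).sum(axis=x_axes)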
51 changes: 45 additions & 6 deletions test/legacy_test/test_kron_op.py
@@ -26,7 +26,9 @@
class TestKronOp(OpTest):
def setUp(self):
self.op_type = "kron"
self.prim_op_type = "prim"
self.python_api = paddle.kron
self.public_python_api = paddle.kron
self.dtype = self._init_dtype()
x = np.random.uniform(size=(10, 10)).astype(self.dtype)
y = np.random.uniform(size=(10, 10)).astype(self.dtype)
@@ -41,19 +43,38 @@ def test_check_output(self):
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(['X', 'Y'], 'Out', check_pir=True)
self.check_grad(
['X', 'Y'],
'Out',
check_pir=True,
check_prim_pir=True,
)

def test_check_grad_ignore_x(self):
self.check_grad(['Y'], 'Out', no_grad_set=set('X'), check_pir=True)
self.check_grad(
['Y'],
'Out',
no_grad_set=set('X'),
check_pir=True,
check_prim_pir=True,
)

def test_check_grad_ignore_y(self):
self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_pir=True)
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
check_pir=True,
check_prim_pir=True,
)


class TestKronOp2(TestKronOp):
def setUp(self):
self.op_type = "kron"
self.prim_op_type = "prim"
self.python_api = paddle.kron
self.public_python_api = paddle.kron
self.dtype = self._init_dtype()
x = np.random.uniform(size=(5, 5, 4)).astype(self.dtype)
y = np.random.uniform(size=(10, 10)).astype(self.dtype)
@@ -65,7 +86,9 @@ def setUp(self):
class TestKronOp3(TestKronOp):
def setUp(self):
self.op_type = "kron"
self.prim_op_type = "prim"
self.python_api = paddle.kron
self.public_python_api = paddle.kron
self.dtype = self._init_dtype()
x = np.random.uniform(size=(10, 10)).astype(self.dtype)
y = np.random.uniform(size=(5, 5, 4)).astype(self.dtype)
@@ -87,7 +110,9 @@ def _init_dtype(self):
class TestKronBF16Op(TestKronOp):
def setUp(self):
self.op_type = "kron"
self.prim_op_type = "prim"
self.python_api = paddle.kron
self.public_python_api = paddle.kron
self.dtype = np.uint16
self.np_dtype = "float32"
x = np.random.uniform(size=(10, 10)).astype(self.np_dtype)
@@ -106,17 +131,31 @@ def test_check_output(self):

def test_check_grad(self):
self.check_grad_with_place(
self.place, ['X', 'Y'], 'Out', check_pir=True
self.place,
['X', 'Y'],
'Out',
check_pir=True,
check_prim_pir=True,
)

def test_check_grad_ignore_x(self):
self.check_grad_with_place(
self.place, ['Y'], 'Out', no_grad_set=set('X'), check_pir=True
self.place,
['Y'],
'Out',
no_grad_set=set('X'),
check_pir=True,
check_prim_pir=True,
)

def test_check_grad_ignore_y(self):
self.check_grad_with_place(
self.place, ['X'], 'Out', no_grad_set=set('Y'), check_pir=True
self.place,
['X'],
'Out',
no_grad_set=set('Y'),
check_pir=True,
check_prim_pir=True,
)


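Outside the OpTest harness, the composite rule can also be spot-checked in dygraph through the public paddle.kron API: with an all-ones upstream gradient, dL/dx is y.sum() everywhere and dL/dy is x.sum() everywhere. A small script (not part of this commit, purely illustrative):

import numpy as np
import paddle

x_np = np.random.rand(3, 4).astype("float32")
y_np = np.random.rand(2, 5).astype("float32")

x = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.to_tensor(y_np, stop_gradient=False)
out = paddle.kron(x, y)
out.sum().backward()  # equivalent to an all-ones upstream gradient

# d sum(kron(x, y)) / dx[i, j] == y.sum();  / dy[k, l] == x.sum()
np.testing.assert_allclose(x.grad.numpy(), np.full_like(x_np, y_np.sum()), rtol=1e-5)
np.testing.assert_allclose(y.grad.numpy(), np.full_like(y_np, x_np.sum()), rtol=1e-5)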
