From f8034dc7cc0cfc4b8d86e6a926117b49cf4765eb Mon Sep 17 00:00:00 2001
From: ruitao zhang <3160391266@qq.com>
Date: Sat, 2 Nov 2024 11:47:40 +0000
Subject: [PATCH] Fix composite kron_grad rule and enable prim checks in kron tests

---
 paddle/fluid/primitive/rule/vjp/details.h | 195 ++++++++++++----------
 test/legacy_test/test_kron_op.py          |  51 +++++-
 2 files changed, 155 insertions(+), 91 deletions(-)

diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h
index 99e18bc089e28a..68684bb1015d42 100644
--- a/paddle/fluid/primitive/rule/vjp/details.h
+++ b/paddle/fluid/primitive/rule/vjp/details.h
@@ -3109,111 +3109,136 @@ void kron_grad(const Tensor& x,
                const Tensor& y,
                const Tensor& out_grad,
                Tensor* x_grad,
                Tensor* y_grad) {
-  if (x_grad && y_grad) {
-    // auto vector_x = split(backend::flatten(x, 0, -1), {x.numel()}, -1);
-    std::vector<Tensor> vector_x;
-    auto flat_x = backend::flatten(x, 0, -1);
-    for (int i = 0; i < x.numel(); i++) {
-      vector_x.push_back(get_slice_vec(flat_x, i, i + 1));  // slice element-wise
-    }
-    // auto vector_y = split(backend::flatten(y, 0, -1), {y.numel()}, 0);
-    std::vector<Tensor> vector_y;
-    auto flat_y = backend::flatten(y, 0, -1);
-    for (int i = 0; i < y.numel(); i++) {
-      vector_y.push_back(get_slice_vec(flat_y, i, i + 1));  // slice element-wise
-    }
-    // auto vector_out_grad = split(backend::flatten(out_grad, 0, -1),
-    // {out_grad.numel()}, -1);
-
-    std::vector<Tensor> vector_out_grad;
-    auto flat_out_grad = backend::flatten(out_grad, 0, -1);
-    for (int i = 0; i < out_grad.numel(); i++) {
-      vector_out_grad.push_back(
-          get_slice_vec(flat_out_grad, i, i + 1));  // slice element-wise
-    }
+  if (x_grad) {
+    auto x_shape = x.shape();
+    auto y_shape = y.shape();
+    auto out_grad_shape = out_grad.shape();
     Tensor x_ = x;
     Tensor y_ = y;
-    int64_t x_dim_size = x.dims().size();
-    int64_t y_dim_size = y.dims().size();
-    auto diff = std::abs(x_dim_size - y_dim_size);
-    if (diff != 0) {
-      std::vector<int64_t> range_diff(diff);
-      for (int64_t i = 0; i < diff; i++) {
-        range_diff[i] = i;
-      }
-
-      if (x_dim_size < y_dim_size) {
-        x_ = unsqueeze(x, IntArray(range_diff));
-      } else if (x_dim_size > y_dim_size) {
-        y_ = unsqueeze(y, IntArray(range_diff));
+    auto diff = std::abs(static_cast<int>(x_shape.size()) -
+                         static_cast<int>(y_shape.size()));
+    for (int i = 0; i < diff; i++) {
+      if (x_shape.size() > y_shape.size()) {
+        y_shape.insert(y_shape.begin(), 1);
+      } else {
+        x_shape.insert(x_shape.begin(), 1);
       }
     }
+    x_ = expand<T>(x, x_shape);
+    y_ = expand<T>(y, y_shape);
 
-    std::vector<Tensor> x_grad_data(x.numel(), full_scalar(0.0, x.dtype()));
-    std::vector<Tensor> y_grad_data(y.numel(), full_scalar(0.0, y.dtype()));
+    std::vector<int64_t> y_dim = common::vectorize(y_.dims());
 
-    // compute stride info
-    int64_t size = x_.shape().size();
-    auto x_stride = std::vector<int64_t>(size, 0);
-    auto y_stride = std::vector<int64_t>(size, 0);
-    auto out_grad_stride = std::vector<int64_t>(size, 0);
+    // split out_grad into blocks
+    std::vector<int64_t> block_shape(out_grad_shape.size());
+    for (size_t i = 0; i < block_shape.size(); i++) {
+      block_shape[i] = out_grad_shape[i] / y_shape[i];
+    }
 
-    x_stride[size - 1] = 1;
-    y_stride[size - 1] = 1;
-    out_grad_stride[size - 1] = 1;
+    std::vector<int64_t> new_shape;
+    for (size_t i = 0; i < out_grad_shape.size(); i++) {
+      new_shape.push_back(block_shape[i]);
+      new_shape.push_back(y_shape[i]);
+    }
+    auto new_out_grad = reshape<T>(out_grad, new_shape);
 
-    for (int64_t i = size - 2; i >= 0; i--) {
-      x_stride[i] = x_stride[i + 1] * x_.shape()[i + 1];
-      y_stride[i] = y_stride[i + 1] * y_.shape()[i + 1];
-      out_grad_stride[i] = out_grad_stride[i + 1] * out_grad.shape()[i + 1];
+    // transpose out_grad
+    std::vector<size_t> permute_order;
+    for (size_t i = 0; i < out_grad_shape.size(); ++i) {
+      permute_order.push_back(i * 2);
+    }
+    for (size_t i = 0; i < out_grad_shape.size(); ++i) {
+      permute_order.push_back(i * 2 + 1);
     }
 
-    // auto* out_grad_data = out_grad.data();
-    // auto* y_data = y_.data();
-    std::vector<Tensor> x_grad_tmp_;
-    std::vector<Tensor> y_grad_tmp_;
+    std::vector<int> permute_order_int(permute_order.begin(),
+                                       permute_order.end());
+    auto transposed_out_grad = transpose<T>(new_out_grad, permute_order_int);
 
-    // compute the gradient of the Kronecker product
-    for (int64_t i = 0; i < out_grad.numel(); i++) {
-      auto idx = i;
-      auto x_idx = 0;
-      auto y_idx = 0;
+    // collapse the block dims into one leading dim
+    y_dim.insert(y_dim.begin(), -1);
+    auto blocks = reshape<T>(transposed_out_grad, y_dim);
 
-      for (int64_t j = 0; j < size; j++) {
-        auto pos_j = idx / out_grad_stride[j];
-        idx = idx % out_grad_stride[j];
-        auto pos_xj = pos_j / y_.shape()[j];
-        auto pos_yj = pos_j % y_.shape()[j];
-        x_idx += x_stride[j] * pos_xj;
-        y_idx += y_stride[j] * pos_yj;
-      }
-      if (x_grad_data[x_idx].defined() && vector_out_grad[i].defined() &&
-          vector_y[y_idx].defined()) {
-        x_grad_data[x_idx] =
-            x_grad_data[x_idx] + vector_out_grad[i] * vector_y[y_idx];
-      } else {
-        PADDLE_THROW(::common::errors::InvalidArgument(
-            "One of the tensors in x_grad_data, vector_out_grad or vector_y is "
-            "not defined."));
+    std::vector<size_t> range;
+    for (size_t i = 0; i <= y_shape.size(); i++) {
+      if (i != 0) {
+        range.push_back(i);
       }
+    }
 
-      if (y_grad_data[y_idx].defined() && vector_out_grad[i].defined() &&
-          vector_x[x_idx].defined()) {
-        y_grad_data[y_idx] =
-            y_grad_data[y_idx] + vector_out_grad[i] * vector_x[x_idx];
+    std::vector<int64_t> range_int(range.begin(), range.end());
+    auto sum_tensor = sum<T>(blocks * y_, IntArray(range_int));
+    auto x_grad_tmp = reshape<T>(sum_tensor, x.shape());
+
+    set_output<T>(x_grad_tmp, x_grad);
+  }
+
+  if (y_grad) {
+    auto x_cast = ConverToMT<T>(x);
+    auto out_grad_cast = ConverToMT<T>(out_grad);
+
+    auto x_shape = x_cast.shape();
+    auto y_shape = y.shape();
+
+    auto diff = std::abs(static_cast<int>(x_shape.size()) -
+                         static_cast<int>(y_shape.size()));
+    for (int i = 0; i < diff; i++) {
+      if (x_shape.size() > y_shape.size()) {
+        y_shape.insert(y_shape.begin(), 1);
       } else {
-        PADDLE_THROW(::common::errors::InvalidArgument(
-            "One of the tensors in y_grad_data, vector_out_grad or vector_x is "
-            "not defined."));
+        x_shape.insert(x_shape.begin(), 1);
       }
     }
 
-    Tensor x_grad_tmp = reshape(concat(x_grad_data), x.shape());
-    Tensor y_grad_tmp = reshape(concat(y_grad_data), y.shape());
+    auto x_ = expand<T>(x_cast, x_shape);
 
-    set_output(x_grad_tmp, x_grad);
-    set_output(y_grad_tmp, y_grad);
+    std::vector<int64_t> x_dim = common::vectorize(x_.dims());
+    std::vector<int64_t> out_grad_shape(out_grad_cast.shape());
+    Tensor out_grad_tmp = out_grad_cast;
+
+    if (x_dim.size() != 0) {
+      while (true) {
+        std::vector<int64_t> expand_shape(out_grad_tmp.shape());
+
+        int num_reduce = 0;
+        while (x_dim.size() != 0 && expand_shape.size() <= 8) {
+          int64_t repeat = x_dim.back();
+          int64_t orig_size = out_grad_shape.back() / repeat;
+          size_t out_grad_last_index = out_grad_shape.size() - 1;
+
+          expand_shape[out_grad_last_index] = repeat;
+          expand_shape.insert(
+              expand_shape.begin() + out_grad_shape.size(), 1, orig_size);
+
+          x_dim.pop_back();
+          out_grad_shape.pop_back();
+          ++num_reduce;
+        }
+
+        int64_t axis = static_cast<int64_t>(out_grad_shape.size());
+        std::vector<int64_t> reduce_axes;
+        for (int i = 0; i < num_reduce; ++i) {
+          reduce_axes.push_back(axis);
+          axis += 2;
+        }
+
+        auto x_tmp_dim = common::vectorize(x_.dims());
+        for (size_t i = 0; i < x_tmp_dim.size(); ++i) {
+          x_ = repeat_interleave<T>(x_, y_shape[i], i);
+        }
+
+        out_grad_tmp = reshape<T>(out_grad_tmp * x_, expand_shape);
+        out_grad_tmp = sum<T>(out_grad_tmp, reduce_axes);
+
+        if (x_dim.size() == 0) {
+          break;
+        }
+      }
+    }
+    set_output<T>(
+        reshape<T>(ConverToOrig<T>(out_grad_tmp, out_grad.dtype()), y.shape()),
+        y_grad);
   }
 }
diff --git a/test/legacy_test/test_kron_op.py b/test/legacy_test/test_kron_op.py
index 2026bf26b42210..f928c2e86afbbf 100644
--- a/test/legacy_test/test_kron_op.py
+++ b/test/legacy_test/test_kron_op.py
@@ -26,7 +26,9 @@ class TestKronOp(OpTest):
     def setUp(self):
         self.op_type = "kron"
+        self.prim_op_type = "prim"
         self.python_api = paddle.kron
+        self.public_python_api = paddle.kron
         self.dtype = self._init_dtype()
         x = np.random.uniform(size=(10, 10)).astype(self.dtype)
         y = np.random.uniform(size=(10, 10)).astype(self.dtype)
@@ -41,19 +43,38 @@ def test_check_output(self):
         self.check_output(check_pir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X', 'Y'], 'Out', check_pir=True)
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+            check_prim_pir=True,
+        )
 
     def test_check_grad_ignore_x(self):
-        self.check_grad(['Y'], 'Out', no_grad_set=set('X'), check_pir=True)
+        self.check_grad(
+            ['Y'],
+            'Out',
+            no_grad_set=set('X'),
+            check_pir=True,
+            check_prim_pir=True,
+        )
 
     def test_check_grad_ignore_y(self):
-        self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_pir=True)
+        self.check_grad(
+            ['X'],
+            'Out',
+            no_grad_set=set('Y'),
+            check_pir=True,
+            check_prim_pir=True,
+        )
 
 
 class TestKronOp2(TestKronOp):
     def setUp(self):
         self.op_type = "kron"
+        self.prim_op_type = "prim"
         self.python_api = paddle.kron
+        self.public_python_api = paddle.kron
         self.dtype = self._init_dtype()
         x = np.random.uniform(size=(5, 5, 4)).astype(self.dtype)
         y = np.random.uniform(size=(10, 10)).astype(self.dtype)
@@ -65,7 +86,9 @@ def setUp(self):
 class TestKronOp3(TestKronOp):
     def setUp(self):
         self.op_type = "kron"
+        self.prim_op_type = "prim"
         self.python_api = paddle.kron
+        self.public_python_api = paddle.kron
         self.dtype = self._init_dtype()
         x = np.random.uniform(size=(10, 10)).astype(self.dtype)
         y = np.random.uniform(size=(5, 5, 4)).astype(self.dtype)
@@ -87,7 +110,9 @@ def _init_dtype(self):
 class TestKronBF16Op(TestKronOp):
     def setUp(self):
         self.op_type = "kron"
+        self.prim_op_type = "prim"
         self.python_api = paddle.kron
+        self.public_python_api = paddle.kron
         self.dtype = np.uint16
         self.np_dtype = "float32"
         x = np.random.uniform(size=(10, 10)).astype(self.np_dtype)
@@ -106,17 +131,31 @@ def test_check_output(self):
 
     def test_check_grad(self):
         self.check_grad_with_place(
-            self.place, ['X', 'Y'], 'Out', check_pir=True
+            self.place,
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+            check_prim_pir=True,
         )
 
     def test_check_grad_ignore_x(self):
         self.check_grad_with_place(
-            self.place, ['Y'], 'Out', no_grad_set=set('X'), check_pir=True
+            self.place,
+            ['Y'],
+            'Out',
+            no_grad_set=set('X'),
+            check_pir=True,
+            check_prim_pir=True,
        )
 
    def test_check_grad_ignore_y(self):
        self.check_grad_with_place(
-            self.place, ['X'], 'Out', no_grad_set=set('Y'), check_pir=True
+            self.place,
+            ['X'],
+            'Out',
+            no_grad_set=set('Y'),
+            check_pir=True,
+            check_prim_pir=True,
        )
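
The composite x_grad path added above relies on the block-reduction identity for the Kronecker product: each (p, q) block of out_grad is weighted by y and reduced to a single entry of x_grad. Below is a minimal NumPy sketch of that identity for plain 2-D inputs; the shapes, seed, and variable names are illustrative only and are not part of the patch.

import numpy as np

# For z = kron(x, y) with x of shape (m, n), y of shape (p, q) and upstream
# gradient g = dL/dz of shape (m*p, n*q):
#   dL/dx[i, j] = sum_{k, l} g[i*p + k, j*q + l] * y[k, l]
rng = np.random.default_rng(0)
m, n, p, q = 3, 4, 2, 5
x = rng.standard_normal((m, n))
y = rng.standard_normal((p, q))
g = rng.standard_normal((m * p, n * q))

# Split g into (p, q) blocks, weight each block by y, reduce each block.
blocks = g.reshape(m, p, n, q).transpose(0, 2, 1, 3)  # shape (m, n, p, q)
x_grad = (blocks * y).sum(axis=(-2, -1))              # shape (m, n)

# Reference: dL/dx[i, j] = <g, kron(E_ij, y)> with E_ij an elementary matrix.
x_grad_ref = np.empty_like(x)
for i in range(m):
    for j in range(n):
        e = np.zeros_like(x)
        e[i, j] = 1.0
        x_grad_ref[i, j] = np.sum(g * np.kron(e, y))
assert np.allclose(x_grad, x_grad_ref)

The reshape/transpose/sum sequence introduced in details.h applies the same pattern, generalized to broadcasted ranks, inside the prim framework.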