
Commit

modified code and added test files
eggman-1024 committed Nov 5, 2024
1 parent 56240f5 commit 5554d08
Showing 2 changed files with 58 additions and 32 deletions.
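For context (standard Kronecker-product calculus, not spelled out in the diff itself): for 2-D x of shape m x n and y of shape p x q, with z = kron(x, y),

\frac{\partial L}{\partial x_{ij}} = \sum_{k=0}^{p-1} \sum_{l=0}^{q-1} \frac{\partial L}{\partial z_{ip+k,\; jq+l}} \, y_{kl}

so each entry of x collects its own y-shaped block of the output gradient, multiplied elementwise with y and summed. The rewritten kron_grad below computes exactly these blocks by splitting out_grad dimension by dimension, rather than by the previous reshape-and-transpose route.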
60 changes: 29 additions & 31 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -3121,44 +3121,42 @@ void kron_grad(const Tensor& x,
for (int i = 0; i < diff; i++) {
if (x_shape.size() > y_shape.size()) {
y_shape.insert(y_shape.begin(), 1);
y_ = expand<T>(y, y_shape);
} else {
x_shape.insert(x_shape.begin(), 1);
x_ = expand<T>(x, x_shape);
}
}
x_ = expand<T>(x, x_shape);
y_ = expand<T>(y, y_shape);
// x_ = expand<T>(x, x_shape);
// y_ = expand<T>(y, y_shape);

// unsqueeze
std::vector<int64_t> y_dim = common::vectorize<int64_t>(y_.dims());
y_dim.insert(y_dim.begin(), -1);

// split out_grad into blocks
std::vector<int64_t> block_shape(out_grad_shape.size());
for (size_t i = 0; i < block_shape.size(); i++) {
block_shape[i] = out_grad_shape[i] / y_shape[i];
}

std::vector<int64_t> new_shape;
for (size_t i = 0; i < out_grad_shape.size(); i++) {
new_shape.push_back(block_shape[i]);
new_shape.push_back(y_shape[i]);
}
auto new_out_grad = reshape<T>(out_grad, new_shape);

// transpose out_grad
std::vector<size_t> permute_order;
for (size_t i = 0; i < out_grad_shape.size(); ++i) {
permute_order.push_back(i * 2);
}
for (size_t i = 0; i < out_grad_shape.size(); ++i) {
permute_order.push_back(i * 2 + 1);
std::deque<Tensor> blocks;
blocks.push_front(out_grad);
for (size_t i = 0; i < x_shape.size(); i++) {
std::vector<Tensor> tmp_block;
while (!blocks.empty()) {
auto block = blocks.front();
blocks.pop_front();
std::vector<int> split_vec(static_cast<int>(x_shape[i]),
static_cast<int>(y_shape[i]));
std::vector<Tensor> block_split =
split<T>(block, IntArray(split_vec), i);

for (auto& b : block_split) {
tmp_block.push_back(b);
}
}
for (auto& tem_b : tmp_block) {
blocks.push_back(tem_b);
}
}

std::vector<int> permute_order_int(permute_order.begin(),
permute_order.end());
auto transposed_out_grad = transpose<T>(new_out_grad, permute_order_int);

// unsqueeze
y_dim.insert(y_dim.begin(), -1);
auto blocks = reshape<T>(transposed_out_grad, y_dim);
std::vector<Tensor> vec(blocks.begin(), blocks.end());
// auto out_grad_tmp = reshape<T>(concat<T>(vec, 0), y_dim);
auto out_grad_tmp = backend::stack<T>(vec, 0);

std::vector<size_t> range;
for (size_t i = 0; i <= y_shape.size(); i++) {
@@ -3168,7 +3166,7 @@ void kron_grad(const Tensor& x,
}

std::vector<int> range_int(range.begin(), range.end());
auto sum_tensor = sum<T>(blocks * y_, IntArray(range_int));
auto sum_tensor = sum<T>(out_grad_tmp * y_, IntArray(range_int));
auto x_grad_tmp = reshape<T>(sum_tensor, x.shape());

set_output<T>(x_grad_tmp, x_grad);
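As a reference for the block-splitting loop above, here is a minimal NumPy sketch of the same idea restricted to the 2-D case; kron_grad_x_ref is a hypothetical helper name used only for illustration, not part of Paddle's API:

import numpy as np

def kron_grad_x_ref(x, y, out_grad):
    # Gradient of z = np.kron(x, y) w.r.t. x (2-D case): entry (i, j)
    # takes the (i, j)-th y-shaped block of out_grad, multiplies it
    # elementwise with y, and sums.
    m, n = x.shape
    p, q = y.shape
    grad_x = np.empty_like(x)
    for i in range(m):
        for j in range(n):
            block = out_grad[i * p:(i + 1) * p, j * q:(j + 1) * q]
            grad_x[i, j] = np.sum(block * y)
    return grad_x

x = np.random.rand(3, 2)
y = np.random.rand(4, 5)
out_grad = np.random.rand(3 * 4, 2 * 5)
# cross-check one entry against d kron(x, y) / d x[i, j] = kron(e_ij, y)
i, j = 1, 0
basis = np.zeros_like(x)
basis[i, j] = 1.0
assert np.isclose(kron_grad_x_ref(x, y, out_grad)[i, j],
                  np.sum(out_grad * np.kron(basis, y)))

The C++ rule generalizes this to arbitrary rank by splitting along every dimension, stacking the resulting blocks, and then doing the multiply-and-sum against y.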
30 changes: 29 additions & 1 deletion test/legacy_test/test_kron_op.py
@@ -77,7 +77,7 @@ def setUp(self):
self.public_python_api = paddle.kron
self.dtype = self._init_dtype()
x = np.random.uniform(size=(5, 5, 4)).astype(self.dtype)
y = np.random.uniform(size=(10, 10)).astype(self.dtype)
y = np.random.uniform(size=(100)).astype(self.dtype)
out_ref = np.kron(x, y)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': out_ref}
@@ -97,6 +97,34 @@ def setUp(self):
self.outputs = {'Out': out_ref}


class TestKronOp4(TestKronOp):
def setUp(self):
self.op_type = "kron"
self.prim_op_type = "prim"
self.python_api = paddle.kron
self.public_python_api = paddle.kron
self.dtype = self._init_dtype()
x = np.random.uniform(size=(5, 10, 5, 5)).astype(self.dtype)
y = np.random.uniform(size=(10, 5, 4, 5)).astype(self.dtype)
out_ref = np.kron(x, y)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': out_ref}


class TestKronOp5(TestKronOp):
def setUp(self):
self.op_type = "kron"
self.prim_op_type = "prim"
self.python_api = paddle.kron
self.public_python_api = paddle.kron
self.dtype = self._init_dtype()
x = np.random.uniform(size=(2, 3, 4, 5, 4)).astype(self.dtype)
y = np.random.uniform(size=(2, 2, 4, 5, 4)).astype(self.dtype)
out_ref = np.kron(x, y)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': out_ref}


class TestKronFP16Op(TestKronOp):
def _init_dtype(self):
return "float16"
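The change to the base TestKronOp above makes y a 1-D operand, so the rule is now exercised with inputs of different rank; np.kron, which the tests use as the reference, left-pads the lower-rank operand with singleton dimensions, which is the case the shape-padding loop in kron_grad has to mirror. A quick sanity check with the same shapes (illustrative only, not part of the test suite):

import numpy as np

x = np.random.uniform(size=(5, 5, 4))
y = np.random.uniform(size=(100,))
out = np.kron(x, y)
# y is treated as shape (1, 1, 100), so the result is (5*1, 5*1, 4*100)
assert out.shape == (5, 5, 400)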
