From be5f0b92f3fce4fc8ade9fd24cf26a64552b1564 Mon Sep 17 00:00:00 2001
From: doggy-tao <3160391266@qq.com>
Date: Thu, 14 Nov 2024 10:54:03 +0800
Subject: [PATCH] [Prim][PIR] Add back_decomp and support dynamic shape for
 kron_grad (#69108)

* add back_decomp for kron_grad (unfinished)

* fixed code

* modified code and added test files

* modified test file

* modified code

* modified code

* modified kron_grad() and added some test files

* modified code of x_grad (kron_grad) in details.h

* modified details.h and added dynamic shape for kron_grad

* modified details.h
---
 .../fluid/primitive/codegen/decomp_vjp_gen.py |   1 +
 .../decomp_rule/decomp_vjp/details.h          | 253 ++++++++++++++++++
 python/paddle/autograd/backward_utils.py      |   1 +
 test/legacy_test/test_kron_op.py              |  95 ++++++-
 ..._sub_graph_klmno_backward_dynamic_shape.py |  52 ++++
 5 files changed, 395 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/primitive/codegen/decomp_vjp_gen.py b/paddle/fluid/primitive/codegen/decomp_vjp_gen.py
index 5684adb3fbabd3..4780c90e191f2a 100644
--- a/paddle/fluid/primitive/codegen/decomp_vjp_gen.py
+++ b/paddle/fluid/primitive/codegen/decomp_vjp_gen.py
@@ -93,6 +93,7 @@
     'fmax_grad',
     'fmin_grad',
     'dot_grad',
+    'kron_grad',
 ]
 
 OTHER_PRIM_VJP_OPS = [
diff --git a/paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h b/paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
index 5907310950f345..fe6504025fbe51 100644
--- a/paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
+++ b/paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
@@ -2916,6 +2916,259 @@ void argsort_grad(const Tensor& indices,
   }
 }
 
+template <typename T>
+void kron_grad(const Tensor& x,
+               const Tensor& y,
+               const Tensor& out_grad,
+               Tensor* x_grad,
+               Tensor* y_grad) {
+  if (x_grad) {
+    Tensor zero = full<T>({1}, 0, DataType::INT32);
+    Tensor x_grad_tmp;
+    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(y.shape())) {
+      Tensor x_ = x;
+      Tensor y_ = y;
+      auto diff = std::abs(static_cast<int>(x.dims().size()) -
+                           static_cast<int>(y.dims().size()));
+      while (diff--) {
+        if (x_.dims().size() > y_.dims().size()) {
+          y_ = backend::unsqueeze<T>(y_, zero);
+        } else {
+          x_ = backend::unsqueeze<T>(x_, zero);
+        }
+      }
+
+      // tile
+      std::vector<Tensor> x_shape_vec;
+      for (int64_t i = 0; i < x_.dims().size(); ++i) {
+        auto x_shape_slice = get_slice<T>(shape<T>(x_), i);
+        x_shape_vec.push_back(x_shape_slice);
+      }
+
+      auto y_tile = backend::tile<T>(y_, shape<T>(x_));
+
+      auto out_grad_tmp = y_tile * out_grad;
+
+      std::vector<Tensor> out_grad_shape_vec;
+      for (int64_t i = 0; i < out_grad.dims().size(); ++i) {
+        auto out_grad_shape_slice = get_slice<T>(shape<T>(out_grad), i);
+        out_grad_shape_vec.push_back(out_grad_shape_slice);
+      }
+      if (x_shape_vec.size() != 0) {
+        while (true) {
+          std::vector<Tensor> expand_shape_vec;
+          for (int64_t i = 0; i < out_grad_tmp.dims().size(); ++i) {
+            auto expand_shape = get_slice<T>(shape<T>(out_grad_tmp), i);
+            expand_shape_vec.push_back(expand_shape);
+          }
+          int num_reduce = 0;
+          while (x_shape_vec.size() != 0 && expand_shape_vec.size() <= 8) {
+            Tensor repeat = x_shape_vec.back();
+            auto orig_size =
+                cast<T>(out_grad_shape_vec.back() / repeat, DataType::INT32);
+            size_t out_grad_last_index = out_grad_shape_vec.size() - 1;
+            expand_shape_vec[out_grad_last_index] = repeat;
+            expand_shape_vec.insert(
+                expand_shape_vec.begin() + out_grad_shape_vec.size(),
+                orig_size);
+
+            x_shape_vec.pop_back();
+            out_grad_shape_vec.pop_back();
+            ++num_reduce;
+          }
+
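+          // Each pass of this loop folds up to 8 trailing, not-yet-folded
+          // out_grad dimensions into (x_i, y_i) pairs. After the reshape
+          // below, the y_i halves sit at the odd offsets behind the unfolded
+          // prefix, which is why `axis` starts at
+          // out_grad_shape_vec.size() + 1 and advances by 2; summing them
+          // away leaves one axis per x dimension.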
+          int axis = static_cast<int>(out_grad_shape_vec.size()) + 1;
+          std::vector<Tensor> reduce_axes_vec;
+          for (int i = 0; i < num_reduce; ++i) {
+            reduce_axes_vec.push_back(full<T>({1}, axis, DataType::INT32));
+            axis += 2;
+          }
+
+          out_grad_tmp =
+              backend::reshape<T>(out_grad_tmp, concat<T>(expand_shape_vec));
+          out_grad_tmp =
+              backend::sum<T>(out_grad_tmp, concat<T>(reduce_axes_vec));
+          if (x_shape_vec.size() == 0) {
+            break;
+          }
+        }
+      }
+      x_grad_tmp = backend::reshape<T>(out_grad_tmp, shape<T>(x));
+    } else {
+      auto x_shape = x.shape();
+      auto y_shape = y.shape();
+
+      auto diff = std::abs(static_cast<int>(x_shape.size()) -
+                           static_cast<int>(y_shape.size()));
+      for (int i = 0; i < diff; i++) {
+        if (x_shape.size() > y_shape.size()) {
+          y_shape.insert(y_shape.begin(), 1);
+        } else {
+          x_shape.insert(x_shape.begin(), 1);
+        }
+      }
+
+      auto x_ = reshape<T>(x, x_shape);
+      auto y_ = reshape<T>(y, y_shape);
+
+      // tile
+      std::vector<int64_t> x_dim = common::vectorize<int64_t>(x_.dims());
+      auto y_tile = tile<T>(y_, x_dim);
+
+      auto out_grad_tmp = y_tile * out_grad;
+
+      std::vector<int64_t> out_grad_shape(out_grad_tmp.shape());
+
+      if (x_dim.size() != 0) {
+        while (true) {
+          std::vector<int64_t> expand_shape(out_grad_tmp.shape());
+
+          int num_reduce = 0;
+          while (x_dim.size() != 0 && expand_shape.size() <= 8) {
+            int64_t repeat = x_dim.back();
+            int64_t orig_size = out_grad_shape.back() / repeat;
+            size_t out_grad_last_index = out_grad_shape.size() - 1;
+
+            expand_shape[out_grad_last_index] = repeat;
+            expand_shape.insert(
+                expand_shape.begin() + out_grad_shape.size(), 1, orig_size);
+
+            x_dim.pop_back();
+            out_grad_shape.pop_back();
+            ++num_reduce;
+          }
+
+          int64_t axis = static_cast<int64_t>(out_grad_shape.size()) + 1;
+          std::vector<int64_t> reduce_axes;
+          for (int i = 0; i < num_reduce; ++i) {
+            reduce_axes.push_back(axis);
+            axis += 2;
+          }
+
+          out_grad_tmp = reshape<T>(out_grad_tmp, expand_shape);
+          out_grad_tmp = sum<T>(out_grad_tmp, reduce_axes);
+
+          if (x_dim.size() == 0) {
+            break;
+          }
+        }
+      }
+      x_grad_tmp = reshape<T>(out_grad_tmp, x.shape());
+    }
+    set_output<T>(x_grad_tmp, x_grad);
+  }
+  if (y_grad) {
+    Tensor zero = full<T>({1}, 0, DataType::INT32);
+    auto x_cast = ConverToMT<T>(x);
+    auto out_grad_cast = ConverToMT<T>(out_grad);
+    Tensor out_grad_tmp;
+    Tensor y_grad_tmp;
+
+    if (has_dynamic_shape(x_cast.shape()) || has_dynamic_shape(y.shape())) {
+      Tensor x_ = x_cast;
+      Tensor y_ = y;
+      auto diff = std::abs(static_cast<int>(x_cast.dims().size()) -
+                           static_cast<int>(y.dims().size()));
+      while (diff--) {
+        if (x_.dims().size() > y_.dims().size()) {
+          y_ = backend::unsqueeze<T>(y_, zero);
+        } else {
+          x_ = backend::unsqueeze<T>(x_, zero);
+        }
+      }
+
+      std::vector<Tensor> x_shape_vec;
+      for (int64_t i = 0; i < x_.dims().size(); ++i) {
+        auto x_shape_slice = get_slice<T>(shape<T>(x_), i);
+        x_shape_vec.push_back(x_shape_slice);
+      }
+
+      for (int64_t i = 0; i < x_.dims().size(); ++i) {
+        auto y_shape_slice = get_slice<T>(shape<T>(y_), i);
+        auto x_shape_slice = get_slice<T>(shape<T>(x_), i);
+        auto y_shape_tile = backend::tile<T>(y_shape_slice, x_shape_slice);
+        x_ = backend::repeat_interleave_with_tensor_index<T>(
+            x_, y_shape_tile, i);
+      }
+      out_grad_tmp = out_grad_cast * x_;
+
+      std::vector<Tensor> out_grad_shape_vec;
+      for (int64_t i = 0; i < out_grad.dims().size(); ++i) {
+        auto out_grad_shape_slice = get_slice<T>(shape<T>(out_grad_cast), i);
+        out_grad_shape_vec.push_back(out_grad_shape_slice);
+      }
+
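+      // Mirror of the x_grad branch above: x_ has been repeat-interleaved so
+      // that it lines up element-wise with out_grad, and the loop below folds
+      // out_grad * x_ into (x_i, y_i) pairs, this time summing away the x_i
+      // halves (`axis` starts at out_grad_shape_vec.size(), without the + 1),
+      // which leaves a tensor with the shape of y.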
+      if (x_shape_vec.size() != 0) {
+        while (true) {
+          std::vector<Tensor> expand_shape_vec;
+          for (int64_t i = 0; i < out_grad_tmp.dims().size(); ++i) {
+            auto expand_shape = get_slice<T>(shape<T>(out_grad_tmp), i);
+            expand_shape_vec.push_back(expand_shape);
+          }
+          int num_reduce = 0;
+          while (x_shape_vec.size() != 0 && expand_shape_vec.size() <= 8) {
+            auto repeat = x_shape_vec.back();
+            auto orig_size =
+                cast<T>(out_grad_shape_vec.back() / repeat, DataType::INT32);
+            size_t out_grad_last_index = out_grad_shape_vec.size() - 1;
+            expand_shape_vec[out_grad_last_index] = repeat;
+            expand_shape_vec.insert(
+                expand_shape_vec.begin() + out_grad_shape_vec.size(),
+                orig_size);
+
+            x_shape_vec.pop_back();
+            out_grad_shape_vec.pop_back();
+            ++num_reduce;
+          }
+          int axis = static_cast<int>(out_grad_shape_vec.size());
+          std::vector<Tensor> reduce_axes_vec;
+          for (int i = 0; i < num_reduce; ++i) {
+            reduce_axes_vec.push_back(full<T>({1}, axis, DataType::INT32));
+            axis += 2;
+          }
+          out_grad_tmp =
+              backend::reshape<T>(out_grad_tmp, concat<T>(expand_shape_vec));
+          out_grad_tmp =
+              backend::sum<T>(out_grad_tmp, concat<T>(reduce_axes_vec));
+
+          if (x_shape_vec.size() == 0) {
+            break;
+          }
+        }
+      }
+      y_grad_tmp = backend::reshape<T>(
+          ConverToOrig<T>(out_grad_tmp, out_grad.dtype()), shape<T>(y));
+    } else {
+      auto x_shape = x_cast.shape();
+      auto y_shape = y.shape();
+
+      auto diff = std::abs(static_cast<int>(x_shape.size()) -
+                           static_cast<int>(y_shape.size()));
+      for (int i = 0; i < diff; i++) {
+        if (x_shape.size() > y_shape.size()) {
+          y_shape.insert(y_shape.begin(), 1);
+        } else {
+          x_shape.insert(x_shape.begin(), 1);
+        }
+      }
+
+      auto x_ = reshape<T>(x_cast, x_shape);
+      auto y_ = reshape<T>(y, y_shape);
+
+      std::vector<int64_t> x_dim = common::vectorize<int64_t>(x_.dims());
+      for (size_t i = 0; i < x_dim.size(); ++i) {
+        x_ = repeat_interleave<T>(x_, y_shape[i], i);
+      }
+      out_grad_tmp = out_grad_cast * x_;
+
+      tile_grad<T>(y_, out_grad_tmp, IntArray(x_dim), &y_grad_tmp);
+      y_grad_tmp =
+          reshape<T>(ConverToOrig<T>(y_grad_tmp, y.dtype()), y.shape());
+    }
+    set_output<T>(y_grad_tmp, y_grad);
+  }
+}
+
 }  // namespace details
 }  // namespace primitive
 }  // namespace paddle
diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py
index 666c9f061315a3..9ba326d1637ae6 100644
--- a/python/paddle/autograd/backward_utils.py
+++ b/python/paddle/autograd/backward_utils.py
@@ -53,6 +53,7 @@
     "pd_op.gather_nd",
     "pd_op.gelu",
     "pd_op.hardswish",
+    "pd_op.kron",
     "pd_op.kthvalue",
     "pd_op.layer_norm",
     "pd_op.leaky_relu",
diff --git a/test/legacy_test/test_kron_op.py b/test/legacy_test/test_kron_op.py
index 2026bf26b42210..1f16a05884db23 100644
--- a/test/legacy_test/test_kron_op.py
+++ b/test/legacy_test/test_kron_op.py
@@ -26,7 +26,9 @@
 class TestKronOp(OpTest):
     def setUp(self):
         self.op_type = "kron"
+        self.prim_op_type = "prim"
         self.python_api = paddle.kron
+        self.public_python_api = paddle.kron
         self.dtype = self._init_dtype()
         x = np.random.uniform(size=(10, 10)).astype(self.dtype)
         y = np.random.uniform(size=(10, 10)).astype(self.dtype)
@@ -41,22 +43,41 @@ def test_check_output(self):
         self.check_output(check_pir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X', 'Y'], 'Out', check_pir=True)
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+            check_prim_pir=True,
+        )
 
     def test_check_grad_ignore_x(self):
-        self.check_grad(['Y'], 'Out', no_grad_set=set('X'), check_pir=True)
+        self.check_grad(
+            ['Y'],
+            'Out',
+            no_grad_set=set('X'),
+            check_pir=True,
+            check_prim_pir=True,
+        )
 
     def test_check_grad_ignore_y(self):
-        self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_pir=True)
+        self.check_grad(
+            ['X'],
+            'Out',
+            no_grad_set=set('Y'),
+            check_pir=True,
+            check_prim_pir=True,
+        )
 
 
 class TestKronOp2(TestKronOp):
     def setUp(self):
         self.op_type = "kron"
+        self.prim_op_type = "prim"
         self.python_api = paddle.kron
+        self.public_python_api = paddle.kron
         self.dtype = self._init_dtype()
         x = np.random.uniform(size=(5, 5, 4)).astype(self.dtype)
-        y = np.random.uniform(size=(10, 10)).astype(self.dtype)
+        y = np.random.uniform(size=(100)).astype(self.dtype)
         out_ref = np.kron(x, y)
         self.inputs = {'X': x, 'Y': y}
         self.outputs = {'Out': out_ref}
@@ -65,7 +86,9 @@ def setUp(self):
 class TestKronOp3(TestKronOp):
     def setUp(self):
         self.op_type = "kron"
+        self.prim_op_type = "prim"
         self.python_api = paddle.kron
+        self.public_python_api = paddle.kron
         self.dtype = self._init_dtype()
         x = np.random.uniform(size=(10, 10)).astype(self.dtype)
         y = np.random.uniform(size=(5, 5, 4)).astype(self.dtype)
@@ -74,6 +97,48 @@ def setUp(self):
         self.outputs = {'Out': out_ref}
 
 
+class TestKronOp4(TestKronOp):
+    def setUp(self):
+        self.op_type = "kron"
+        self.prim_op_type = "prim"
+        self.python_api = paddle.kron
+        self.public_python_api = paddle.kron
+        self.dtype = self._init_dtype()
+        x = np.random.uniform(size=(5, 5, 5)).astype(self.dtype)
+        y = np.random.uniform(size=(2, 3, 4, 2, 2, 3)).astype(self.dtype)
+        out_ref = np.kron(x, y)
+        self.inputs = {'X': x, 'Y': y}
+        self.outputs = {'Out': out_ref}
+
+
+class TestKronOp5(TestKronOp):
+    def setUp(self):
+        self.op_type = "kron"
+        self.prim_op_type = "prim"
+        self.python_api = paddle.kron
+        self.public_python_api = paddle.kron
+        self.dtype = self._init_dtype()
+        x = np.random.uniform(size=(2, 3, 4, 3, 2)).astype(self.dtype)
+        y = np.random.uniform(size=(10, 10)).astype(self.dtype)
+        out_ref = np.kron(x, y)
+        self.inputs = {'X': x, 'Y': y}
+        self.outputs = {'Out': out_ref}
+
+
+class TestKronOp6(TestKronOp):
+    def setUp(self):
+        self.op_type = "kron"
+        self.prim_op_type = "prim"
+        self.python_api = paddle.kron
+        self.public_python_api = paddle.kron
+        self.dtype = self._init_dtype()
+        x = np.random.uniform(size=(10, 10, 2)).astype(self.dtype)
+        y = np.random.uniform(size=(2, 3, 4, 3, 2)).astype(self.dtype)
+        out_ref = np.kron(x, y)
+        self.inputs = {'X': x, 'Y': y}
+        self.outputs = {'Out': out_ref}
+
+
 class TestKronFP16Op(TestKronOp):
     def _init_dtype(self):
         return "float16"
@@ -87,7 +152,9 @@ def _init_dtype(self):
 class TestKronBF16Op(TestKronOp):
     def setUp(self):
         self.op_type = "kron"
+        self.prim_op_type = "prim"
         self.python_api = paddle.kron
+        self.public_python_api = paddle.kron
         self.dtype = np.uint16
         self.np_dtype = "float32"
         x = np.random.uniform(size=(10, 10)).astype(self.np_dtype)
@@ -106,17 +173,31 @@ def test_check_output(self):
 
     def test_check_grad(self):
         self.check_grad_with_place(
-            self.place, ['X', 'Y'], 'Out', check_pir=True
+            self.place,
+            ['X', 'Y'],
+            'Out',
+            check_pir=True,
+            check_prim_pir=True,
         )
 
     def test_check_grad_ignore_x(self):
         self.check_grad_with_place(
-            self.place, ['Y'], 'Out', no_grad_set=set('X'), check_pir=True
+            self.place,
+            ['Y'],
+            'Out',
+            no_grad_set=set('X'),
+            check_pir=True,
+            check_prim_pir=True,
         )
 
     def test_check_grad_ignore_y(self):
         self.check_grad_with_place(
-            self.place, ['X'], 'Out', no_grad_set=set('Y'), check_pir=True
+            self.place,
+            ['X'],
+            'Out',
+            no_grad_set=set('Y'),
+            check_pir=True,
+            check_prim_pir=True,
         )
diff --git a/test/prim/pir_prim/test_prim_sub_graph_klmno_backward_dynamic_shape.py b/test/prim/pir_prim/test_prim_sub_graph_klmno_backward_dynamic_shape.py
index 5c442a84918e6f..65fa755841777e 100644
--- a/test/prim/pir_prim/test_prim_sub_graph_klmno_backward_dynamic_shape.py
+++ b/test/prim/pir_prim/test_prim_sub_graph_klmno_backward_dynamic_shape.py
@@ -24,6 +24,10 @@
 import paddle
 
 
+def kron_net(x, y):
+    return paddle.kron(x, y)
+
+
 def kthvalue_net1(x):
     return paddle.kthvalue(x, k=2)[0]
 
@@ -136,6 +140,54 @@ def multiply_net(x, y):
     return x * y
 
 
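+# The cases below pair operands of different ranks, so both rank-alignment
+# branches of the kron_grad decomposition (unsqueezing x vs. unsqueezing y)
+# are exercised under dynamic shapes.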
+class TestPrimKronWithGrad1(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2024)
+        self.op_name = "pd_op.kron"
+        self.dtype = "float32"
+        self.x_shape = [10, 10]
+        self.init_x_shape = [None, None]
+        self.y_shape = [5, 5, 4]
+        self.init_y_shape = [None, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = kron_net
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimKronWithGrad2(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2024)
+        self.op_name = "pd_op.kron"
+        self.dtype = "float32"
+        self.x_shape = [10, 10]
+        self.init_x_shape = [None, None]
+        self.y_shape = [5, 5, 4, 3, 2]
+        self.init_y_shape = [None, None, None, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = kron_net
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimKronWithGrad3(TestPrimTwoWithGrad):
+    def setUp(self):
+        np.random.seed(2024)
+        self.op_name = "pd_op.kron"
+        self.dtype = "float32"
+        self.x_shape = [5, 5, 4, 3, 5, 6]
+        self.init_x_shape = [None, None, None, None, None, None]
+        self.y_shape = [3, 5, 4]
+        self.init_y_shape = [None, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.y = np.random.random(self.y_shape).astype(self.dtype)
+        self.net = kron_net
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
 class TestPrimKthvalueWithGrad1(TestPrimBaseWithGrad):
     def setUp(self):
         np.random.seed(2024)
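
For reference, the fold-and-sum strategy used by the kron_grad decomposition above can be summarized in NumPy for the plain 2-D case. This is only a minimal sketch: the helper names kron_grad_x_ref and kron_grad_y_ref are illustrative and not part of Paddle's API, and it assumes 2-D operands with no rank alignment, mixed precision, or dynamic shapes.

import numpy as np


def kron_grad_x_ref(x, y, out_grad):
    # out[i*p + k, j*q + l] = x[i, j] * y[k, l], so
    # dL/dx[i, j] = sum_{k, l} out_grad[i*p + k, j*q + l] * y[k, l].
    # Tiling y and folding out_grad into (m, p, n, q) blocks mirrors the
    # tile / reshape / sum steps of the decomposition.
    m, n = x.shape
    p, q = y.shape
    t = np.tile(y, (m, n)) * out_grad
    return t.reshape(m, p, n, q).sum(axis=(1, 3))


def kron_grad_y_ref(x, y, out_grad):
    # dL/dy[k, l] = sum_{i, j} out_grad[i*p + k, j*q + l] * x[i, j];
    # element-wise repetition of x plays the role of repeat_interleave.
    m, n = x.shape
    p, q = y.shape
    x_rep = np.repeat(np.repeat(x, p, axis=0), q, axis=1)
    return (x_rep * out_grad).reshape(m, p, n, q).sum(axis=(0, 2))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((3, 4))
    y = rng.standard_normal((2, 5))
    g = rng.standard_normal((6, 20))
    blocks = g.reshape(3, 2, 4, 5)
    # Brute-force gradients taken straight from the definition of np.kron.
    dx = np.einsum("ikjl,kl->ij", blocks, y)
    dy = np.einsum("ikjl,ij->kl", blocks, x)
    assert np.allclose(kron_grad_x_ref(x, y, g), dx)
    assert np.allclose(kron_grad_y_ref(x, y, g), dy)

The n-D code in details.h applies the same idea dimension by dimension, folding trailing axes in batches of at most eight and summing away the y-side (for x_grad) or x-side (for y_grad) halves of each folded pair.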