diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
index a7479bc7166694..8cabcd812d1d34 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
@@ -412,7 +412,8 @@ void multiply_double_grad(const Tensor& x,
       }
     } else {
-      x_grad = nullptr;
+      auto dx = full(phi::vectorize(x.dims()), 0.0, x.dtype());
+      set_output(dx, x_grad);
     }
   }
 
   if (y_grad) {
@@ -433,22 +434,22 @@ void multiply_double_grad(const Tensor& x,
         set_output(dy, y_grad);
       }
     } else {
-      y_grad = nullptr;
+      auto dy = full(phi::vectorize(y.dims()), 0.0, y.dtype());
+      set_output(dy, y_grad);
     }
   }
 
   if (grad_out_grad) {
+    Tensor ddout;
     if (grad_x_grad && grad_y_grad) {
-      auto ddout = grad_x_grad.get() * y + grad_y_grad.get() * x;
-      set_output(ddout, grad_out_grad);
+      ddout = grad_x_grad.get() * y + grad_y_grad.get() * x;
     } else if (grad_x_grad) {
-      auto ddout = grad_x_grad.get() * y;
-      set_output(ddout, grad_out_grad);
+      ddout = grad_x_grad.get() * y;
     } else if (grad_y_grad) {
-      auto ddout = grad_y_grad.get() * x;
-      set_output(ddout, grad_out_grad);
+      ddout = grad_y_grad.get() * x;
     } else {
-      grad_out_grad = nullptr;
+      ddout = full(phi::vectorize(grad_out.dims()), 0.0, grad_out.dtype());
     }
+    set_output(ddout, grad_out_grad);
   }
 }
@@ -461,10 +462,10 @@ void add_double_grad(const Tensor& y,
                      Tensor* grad_out_grad) {
   if (grad_out_grad) {
     // ddout = ddx + ddy
+    Tensor ddout = full(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
     if (!grad_x_grad && !grad_y_grad) {
-      grad_out_grad = nullptr;
+      set_output(ddout, grad_out_grad);
     } else {
-      Tensor ddout = full(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
       if (grad_x_grad) {
         ddout = ddout + grad_x_grad.get();
       }
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 2a3404d95e0ffc..c737134f474ec1 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -2489,9 +2489,11 @@ def calc_gradient_helper(
         block, targets, inputs, block_no_grad_set, op_path_dict
     )
 
-    # only for composite to add grad_op input,
-    # tmp_targets includes targets and other outputs
-    # of the same forward op who create targets
+    # only for composite: add the grad_var of the last forward op
+    # that has more than one output while targets holds only one of them,
+    # so targets_gradients contains only one grad_var,
+    # e.g. op1 -> op2 -> var1 / var2, targets = var1,
+    # targets_gradients = var1_grad; var2_grad needs to be added here.
     tmp_targets = targets
 
     if core._is_bwd_prim_enabled():
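
Reviewer note (not part of the patch): the C++ hunks above change the composite double-grad rules for multiply and add so that branches which previously left a gradient pointer as nullptr now write a zero-filled tensor through set_output. The sketch below is a minimal, illustrative way to exercise a second-order multiply gradient from Python; it assumes paddle.fluid.core exposes a _set_prim_backward_enabled toggle (named by analogy with the core._is_bwd_prim_enabled() check in the backward.py hunk), and whether the composite double-grad rule is actually taken still depends on that switch and the execution mode.

    # Minimal sketch, illustrative only; the prim toggle name is an assumption
    # mirrored from core._is_bwd_prim_enabled() used in the patched backward.py.
    import paddle
    from paddle.fluid import core

    core._set_prim_backward_enabled(True)  # route backward through composite rules

    x = paddle.rand([3, 4])
    y = paddle.rand([3, 4])
    x.stop_gradient = False
    y.stop_gradient = False

    out = paddle.multiply(x, y)
    # First-order grad w.r.t. x; keep the graph so a second backward is possible.
    (dx,) = paddle.grad(out, x, create_graph=True)
    # Second-order grad w.r.t. y. With this patch, double-grad branches that get
    # no incoming grad_x_grad/grad_y_grad emit zero-filled tensors instead of
    # null outputs, so downstream ops always receive a defined value.
    (ddy,) = paddle.grad(dx, y, retain_graph=True, allow_unused=True)
    print(ddy)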