support no grad in vjp (#57294)
Charles-hit authored Sep 14, 2023
1 parent 0424fb7 commit 5934505
Showing 4 changed files with 33 additions and 7 deletions.
25 changes: 25 additions & 0 deletions paddle/fluid/primitive/utils/static_utils.cc
@@ -21,9 +21,34 @@ void set_output<LazyTensor>(const paddle::Tensor& x_tmp, paddle::Tensor* x) {
x->set_impl(x_tmp.impl());
}

+/**
+ * @brief Set outputs that need no gradient in the new IR.
+ *
+ * In the new IR, the None type expresses that
+ * a value is not available.
+ * Some vjp outputs are marked as unnecessary by
+ * setting stop_gradient to true; the type of
+ * those unnecessary outputs is therefore set
+ * to None.
+ *
+ */
+void SetOutputWithNoGrads(
+    const std::vector<std::vector<Tensor>>& outputs,
+    const std::vector<std::vector<bool>>& stop_gradients) {
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    for (size_t j = 0; j < outputs[i].size(); ++j) {
+      if (stop_gradients[i][j]) {
+        std::static_pointer_cast<primitive::LazyTensor>(outputs[i][j].impl())
+            ->set_empty_type();
+      }
+    }
+  }
+}
+
std::vector<std::vector<Tensor>> ConstructVjpResultByStopGradients(
const std::vector<std::vector<Tensor>>& outputs,
const std::vector<std::vector<bool>>& stop_gradients) {
+  SetOutputWithNoGrads(outputs, stop_gradients);
std::vector<std::vector<Tensor>> vjp_results(outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
vjp_results[i].reserve(outputs[i].size());
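A minimal Python sketch of the masking rule described in the comment above; ToyLazyTensor and set_output_with_no_grads are hypothetical stand-ins written only for illustration, not Paddle APIs (the real code calls set_empty_type() on the paddle::Tensor's LazyTensor impl):

# Hypothetical stand-ins; not Paddle code.
class ToyLazyTensor:
    def __init__(self, name):
        self.name = name
        self.is_none_typed = False

    def set_empty_type(self):
        # Mirrors LazyTensor::set_empty_type(): mark the IR value as None-typed.
        self.is_none_typed = True


def set_output_with_no_grads(outputs, stop_gradients):
    # Outputs whose stop_gradient flag is True need no gradient,
    # so their type is switched to None.
    for out_row, stop_row in zip(outputs, stop_gradients):
        for out, stop in zip(out_row, stop_row):
            if stop:
                out.set_empty_type()


outs = [[ToyLazyTensor("dx")], [ToyLazyTensor("dscale"), ToyLazyTensor("dbias")]]
set_output_with_no_grads(outs, [[False], [True, False]])
print([[t.is_none_typed for t in row] for row in outs])  # [[False], [True, False]]
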
3 changes: 3 additions & 0 deletions paddle/fluid/primitive/utils/utils.h
@@ -88,6 +88,9 @@ static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
return get_reduce_dims_from_out(out_dims, x_dims);
}

+void SetOutputWithNoGrads(const std::vector<std::vector<Tensor>>& outputs,
+                          const std::vector<std::vector<bool>>& stop_gradients);
+
std::vector<std::vector<Tensor>> ConstructVjpResultByStopGradients(
const std::vector<std::vector<Tensor>>& outputs,
const std::vector<std::vector<bool>>& stop_gradients);
2 changes: 1 addition & 1 deletion test/cpp/prim/test_vjp.cc
@@ -421,7 +421,7 @@ TEST(VJP, SplitBackwardTest) {
paddle::dialect::FullOp op4 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());

-  std::vector<std::vector<bool>> stop_gradients{{false}, {true}, {true}};
+  std::vector<std::vector<bool>> stop_gradients{{false}};
std::vector<std::vector<pir::OpResult>> out_grads{{op3.result(0), op4.out()}};
pir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd_op.split");

10 changes: 4 additions & 6 deletions test/prim/new_ir_prim/test_vjp_prim.py
@@ -128,15 +128,14 @@ def test_sum_grad_prim(self):
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
dout = newir_program.global_block().ops[-3].result(0)
out_grads = [[dout]]
-        stop_gradients = [[False], [True]]
+        stop_gradients = [[False]]
sum_op = newir_program.global_block().ops[-1]
with paddle.ir.core.program_guard(newir_program):
grad_outs = call_vjp(sum_op, out_grads, stop_gradients)
expand_op = newir_program.global_block().ops[-1]
-        self.assertEqual(len(grad_outs), 2)
+        self.assertEqual(len(grad_outs), 1)
self.assertEqual(len(newir_program.global_block().ops), 8)
self.assertEqual(expand_op.result(0), grad_outs[0][0])
-        self.assertEqual(grad_outs[1][0], None)
all_op_names = [
"pd_op.full",
"pd_op.full",
@@ -157,15 +156,14 @@ def test_sum_grad_no_prim(self):
paddle.framework.core._set_prim_backward_enabled(False)
dout = newir_program.global_block().ops[-2].result(0)
out_grads = [[dout]]
-        stop_gradients = [[False], [True]]
+        stop_gradients = [[False]]
sum_op = newir_program.global_block().ops[-1]
with paddle.ir.core.program_guard(newir_program):
grad_outs = call_vjp(sum_op, out_grads, stop_gradients)
-        self.assertEqual(len(grad_outs), 2)
+        self.assertEqual(len(grad_outs), 1)
self.assertEqual(
grad_outs[0][0].get_defining_op().name(), "pd_op.sum_grad"
)
-        self.assertEqual(grad_outs[1][0], None)
self.assertEqual(len(newir_program.global_block().ops), 5)
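The assertion changes above boil down to a shape contract: with stop_gradients of [[False]], the test now expects a single gradient list instead of a second [None] placeholder. A toy sketch of that behavior (mask_vjp_results is a hypothetical helper written for this example, not Paddle's call_vjp):

def mask_vjp_results(raw_grads, stop_gradients):
    # Hypothetical helper: keep a gradient only when its stop_gradient flag
    # is False; otherwise return None in its place.
    return [
        [None if stop else grad for grad, stop in zip(grads, stops)]
        for grads, stops in zip(raw_grads, stop_gradients)
    ]


# Old-style test inputs: a second, unwanted slot was requested and came back
# as [None], hence the removed assertions len(grad_outs) == 2 and
# grad_outs[1][0] is None.
old = mask_vjp_results([["dx"], ["unused"]], [[False], [True]])
assert len(old) == 2 and old[1][0] is None

# New-style test inputs: only the needed gradient is requested, so the
# result has a single entry, matching assertEqual(len(grad_outs), 1).
new = mask_vjp_results([["dx"]], [[False]])
assert len(new) == 1 and new[0][0] == "dx"
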

