[PIR / dy2static] Fix tostatic unittest bugs. #58959

Merged (5 commits) on Nov 14, 2023
Changes from 2 commits
10 changes: 6 additions & 4 deletions paddle/fluid/pybind/pir.cc
@@ -1015,10 +1015,10 @@ void AppendSetParameter(Program *forward_program,
}
}

void AppendSetParameters(Program *forward_program,
const std::vector<pir::OpResult> &outputs_op_result,
int start_point,
std::string name_prefix) {
int AppendSetParameters(Program *forward_program,
const std::vector<pir::OpResult> &outputs_op_result,
int start_point,
std::string name_prefix) {
int counter = 0;
std::unordered_set<pir::OpResult> added_op_result;

@@ -1032,6 +1032,8 @@ void AppendSetParameters(Program *forward_program,
added_op_result.insert(result);
}
}
// Return the number of set_parameter ops inserted.
return counter;
}

SplitedResult SplitForwardBackward(
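The behavioral change in this file is the return type: AppendSetParameters now reports how many set_parameter ops it actually inserted (duplicates are skipped via added_op_result). Below is a minimal Python-side sketch of how a caller could use that count to keep a running insertion point when registering several groups of values; program, outputs, and output_grads are assumed to already exist, the "fwd_output_" prefix is made up, and only append_set_parameters and the "grad_input_" prefix appear in this diff.

import paddle

# Sketch only: `program` is assumed to be a pir Program, and `outputs` /
# `output_grads` are assumed to be lists of OpResult values.
insertion_point = len(program.global_block().ops)

# append_set_parameters now returns the number of set_parameter ops it added,
# which can be smaller than len(outputs) when duplicates are skipped.
num_inserted = paddle.base.libpaddle.pir.append_set_parameters(
    program, outputs, insertion_point, "fwd_output_"
)

# The next group starts right after the ops that were just inserted.
insertion_point += num_inserted
paddle.base.libpaddle.pir.append_set_parameters(
    program, output_grads, insertion_point, "grad_input_"
)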
204 changes: 96 additions & 108 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -25,11 +25,9 @@
from paddle.base.compiler import BuildStrategy
from paddle.base.data_feeder import check_type, convert_dtype
from paddle.base.dygraph.base import switch_to_static_graph
from paddle.framework import use_pir_api
from paddle.optimizer.lr import LRScheduler
from paddle.pir import OpResult, fake_op_result, is_fake_op_result

from . import logging_utils
from .utils import RETURN_NO_VALUE_MAGIC_NUM, backend_guard

__all__ = []
@@ -52,60 +50,57 @@ def __get__(self, instance, cls):
class NestSequence:
"""
A wrapper class that easily to flatten and restore the nest structure of
given sequence.
given sequence. It also removes duplicate variables from the sequence.
For example:
>>> t = [v1, v2, v1]
>>> ns = NestSequence(t)
>>> ns.var_list
[v1, v2]
>>> ns.restore([t1, t2])
[t1, t2, t1]
"""

def __init__(self, raw_input, need_check=False):
self.__raw_input = raw_input
self.__input_list = self.tolist()
self.__var_ids = self._get_var_ids()
self._check_non_variable(need_check)
def __init__(self, raw_input):
self._raw_input = raw_input
self._var_map, self._var_list = self._tolist()

def tolist(self):
@property
def var_list(self):
return self._var_list

def _tolist(self):
"""
Flattens the nested sequences into single list.
Flattens the nested sequence into a single list, removing duplicate variables and non-variable elements.
"""
return paddle.utils.flatten(self.__raw_input)
variable_map = {} # opresult -> list idx
variable_list = []
for value in paddle.utils.flatten(self._raw_input):
if not isinstance(value, OpResult):
continue
if value in variable_map:
# remove duplicate opresults.
continue
variable_map[value] = len(variable_list)
variable_list.append(value)
return variable_map, variable_list

def restore(self, value_list):
"""
Restores the nested sequence from value list.
"""
assert len(self.__input_list) == len(value_list)
return paddle.utils.pack_sequence_as(self.__raw_input, value_list)

def _get_var_ids(self):
var_ids = []
for idx, var in enumerate(self.__input_list):
if isinstance(var, (OpResult, core.eager.Tensor)):
var_ids.append(idx)
assert len(self._var_list) == len(value_list)

return var_ids
def to_value(x):
if isinstance(x, OpResult):
return value_list[self._var_map[x]]
return x

def _check_non_variable(self, need_check):
"""
Raises warning if output of traced function contains non-tensor type values.
"""
if need_check:
warning_types = set()
for var in self.__input_list:
if not isinstance(var, (framework.Variable, core.eager.Tensor)):
warning_types.add(type(var))
if warning_types:
logging_utils.warn(
"Output of traced function contains non-tensor type values: {}. "
"Currently, We don't support to update them while training and will return "
"what we first saw. Please try to return them as tensor.".format(
list(warning_types)
)
)

@property
def var_ids(self):
return self.__var_ids
return paddle.utils.pack_sequence_as(
self._raw_input,
list(map(to_value, paddle.utils.flatten(self._raw_input))),
)

def __getitem__(self, item):
return self.__input_list[item]
return self._var_list[item]


class RunableProgram:
@@ -146,10 +141,7 @@ def convert_name(self, values):
return []
if isinstance(values[0], str):
return values
try:
return [self.get_value_name_map[v] for v in values]
except:
breakpoint()
return [self.get_value_name_map[v] for v in values]

@cached_property
def x_values(self):
@@ -415,7 +407,7 @@ def __init__(
):
super().__init__()
self._inputs = NestSequence(inputs)
self._outputs = NestSequence(outputs, need_check=True)
self._outputs = NestSequence(outputs)
self._params, self._param_values = (
parameters if parameters is not None else ([], [])
)
@@ -458,15 +450,7 @@ def __call__(self, inputs):
"""
in_vars, out_vars = self._prepare(inputs)
attrs = self._prepare_attributes()

# self._sync_lr_value_with_scheduler()

c_run_program_fn = None
if use_pir_api():
c_run_program_fn = _legacy_C_ops.pir_run_program
else:
c_run_program_fn = _legacy_C_ops.run_program
c_run_program_fn(
_legacy_C_ops.pir_run_program(
self._valid_vars(in_vars),
self._valid_vars(self._params),
self._valid_vars(out_vars),
@@ -482,12 +466,8 @@ def __call__(self, inputs):

@cached_property
def origin_runable_program(self):
inputs = list(
filter(lambda x: isinstance(x, OpResult), self._inputs.tolist())
)
outputs = list(
filter(lambda x: isinstance(x, OpResult), self._outputs.tolist())
)
inputs = list(self._inputs.var_list)
outputs = list(self._outputs.var_list)
params = self._param_values
paddle.base.libpaddle.pir.append_set_parameters(
self._origin_main_program,
@@ -638,7 +618,7 @@ def _need_aggregation(var):
"""
If there exists an op whose input is var, return True.
"""
if not isinstance(var, framework.Variable) or var.type not in [
if var.type not in [
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS,
]:
@@ -692,7 +672,7 @@ def _insert_aggregation_ops_for_var(target_program, var):
return None

to_processed_vars = list(
filter(_need_aggregation, self._outputs.tolist())
filter(_need_aggregation, self._outputs.var_list)
)
for _var in to_processed_vars:
_insert_aggregation_ops_for_var(target_program, _var)
@@ -710,31 +690,56 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
params = train_runnable_program.param_values
combined_inputs = list(itertools.chain(inputs, params))
forward_end_idx = len(program.global_block().ops)
if targets:
with backend_guard(self._backend):
check_type(
targets,
'targets',
(OpResult, list, tuple),
'paddle.static.gradients',
)
with ir_static.program_guard(program, None):
grad_info_map = grad(
inputs=combined_inputs, outputs=targets
)
grad_info_map = [None] * len(combined_inputs)
with backend_guard(self._backend):
check_type(
targets,
'targets',
(OpResult, list, tuple),
'paddle.static.gradients',
)
with ir_static.program_guard(program, None):
# Create outputs_grad for backward to avoid inserting full and full_like ops.
forward_outputs_grads = []
not_stop_gradient_num = 0
for out_op_result in self._outputs.tolist():
for out_op_result in targets:
if out_op_result.stop_gradient is True:
forward_outputs_grads.append(None)
continue
opres = (
program.global_block()
.ops[forward_end_idx + 2 * not_stop_gradient_num + 1]
.results()[0]
forward_outputs_grads.append(fake_op_result())
else:
value = paddle.full_like(
out_op_result,
fill_value=1.0,
dtype=out_op_result.dtype,
)
forward_outputs_grads.append(value)
paddle.base.libpaddle.pir.append_set_parameters(
program,
forward_outputs_grads,
len(program.global_block().ops),
"grad_input_",
)
backward_start_op_index = len(program.global_block().ops)

# call grad to get backward ops.
if (
len(
list(
filter(lambda x: x.stop_gradient is False, targets)
)
)
> 0
):
grad_info_map = grad(
inputs=combined_inputs,
outputs=list(
filter(lambda x: x.stop_gradient is False, targets)
),
grad_outputs=list(
filter(
lambda x: not is_fake_op_result(x),
forward_outputs_grads,
)
),
)
forward_outputs_grads.append(opres)
not_stop_gradient_num += 1

if self._hooker:
(
@@ -759,9 +764,6 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
)
p_grad_value = list(map(mapping_op_result, grad_info_map[inputs_size:]))
o_grad_value = list(map(mapping_op_result, forward_outputs_grads))
backward_start_op_index = forward_end_idx + 2 * len(
list(filter(lambda r: r.stop_gradient is False, self._outputs))
)

# Insert grad names for RunableProgram (we need names for grad_inputs and grad_outputs).
input_grads_to_append = list(
@@ -772,13 +774,6 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
lambda x: not is_fake_op_result(x), x_grad_value + p_grad_value
)
)
paddle.base.libpaddle.pir.append_set_parameters(
program,
input_grads_to_append,
backward_start_op_index,
"grad_input_",
)
backward_start_op_index += len(input_grads_to_append)
backward_end_op_index = len(program.global_block().ops)
paddle.base.libpaddle.pir.append_set_parameters(
program,
@@ -878,8 +873,7 @@ def _prepare(self, inputs):
# mapping from name(string) -> Tensor
out_tensor_map = {}

def create_out(var_id):
var = self._outputs[var_id]
def create_out(var):
assert isinstance(var, OpResult)

if id(var) in out_tensor_map:
@@ -901,7 +895,7 @@ def create_out(var_id):
return out

# Create Tensor to receive output data.
out_vars = list(map(create_out, self._outputs.var_ids))
out_vars = list(map(create_out, self._outputs.var_list))
return input_vars, out_vars

def _create_scope_vec(self, program_id=None, use_scope_cache=False):
@@ -923,26 +917,20 @@ def _create_cuda_graph_vec(self):

def _update_stop_gradient(self, out_vars):
# Update stop_gradient for all outputs
def set_stop_gradient(var_id, eager_tensor):
var = self._outputs[var_id]
def set_stop_gradient(var, eager_tensor):
assert isinstance(var, OpResult)
eager_tensor.stop_gradient = var.stop_gradient

for idx, var in zip(self._outputs.var_ids, out_vars):
for idx, var in zip(self._outputs.var_list, out_vars):
set_stop_gradient(idx, var)

def _restore_out(self, out_vars):
"""
Restores same nested outputs by only replacing the Variable with Tensor.
"""

flatten_outputs = self._outputs.tolist()
for i, idx in enumerate(self._outputs.var_ids):
flatten_outputs[idx] = out_vars[i]
outs = self._outputs.restore(flatten_outputs)
outs = self._outputs.restore(out_vars)
if outs is not None and len(outs) == 1:
outs = outs[0]

return outs

@switch_to_static_graph
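The rewritten NestSequence records, for each unique OpResult in the flattened input, its index in _var_list; restore then maps every occurrence (including duplicates) back to the corresponding runtime value and repacks the original nesting. Below is a small standalone sketch of that dedupe/restore contract, using a made-up FakeVar placeholder instead of OpResult; only paddle.utils.flatten and paddle.utils.pack_sequence_as are real Paddle APIs here.

import paddle


class FakeVar:
    """Stand-in for OpResult: hashable and compared by identity."""


class TinyNestSequence:
    """Toy model of NestSequence's flatten / dedupe / restore behavior."""

    def __init__(self, raw_input):
        self._raw_input = raw_input
        self._var_map = {}   # variable -> index into _var_list
        self._var_list = []  # unique variables, in first-seen order
        for value in paddle.utils.flatten(raw_input):
            if isinstance(value, FakeVar) and value not in self._var_map:
                self._var_map[value] = len(self._var_list)
                self._var_list.append(value)

    def restore(self, value_list):
        # Every occurrence of the same variable maps to the same new value;
        # non-variable elements are passed through unchanged.
        assert len(self._var_list) == len(value_list)
        replaced = [
            value_list[self._var_map[x]] if isinstance(x, FakeVar) else x
            for x in paddle.utils.flatten(self._raw_input)
        ]
        return paddle.utils.pack_sequence_as(self._raw_input, replaced)


v1, v2 = FakeVar(), FakeVar()
ns = TinyNestSequence([v1, [v2, v1], 3])    # 3 is a non-variable element
assert ns._var_list == [v1, v2]             # duplicates and non-vars dropped
assert ns.restore(["t1", "t2"]) == ["t1", ["t2", "t1"], 3]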
4 changes: 2 additions & 2 deletions test/dygraph_to_static/test_duplicate_output.py
@@ -15,7 +15,7 @@
import unittest

import numpy as np
from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir
from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir_api

import paddle

@@ -55,7 +55,7 @@ def _run_static(self):

self.assertEqual(param[0].grad.numpy(), 1.0)

@test_legacy_and_pir
@test_legacy_and_pir_api
def test_ast_to_func(self):
self._run_static()

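The change above only swaps the test decorator to test_legacy_and_pir_api, extending coverage to the PIR API path. The collapsed test body is not shown in this diff; a duplicate-output forward of the kind this test (and the NestSequence de-duplication above) targets looks roughly like the following sketch, where DuplicateOutputNet is a made-up name.

import paddle


class DuplicateOutputNet(paddle.nn.Layer):
    """Hypothetical layer that returns the same tensor twice, which is the
    pattern NestSequence's output de-duplication has to handle."""

    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(4, 4)

    def forward(self, x):
        y = self.linear(x)
        return y, y  # the same OpResult shows up twice in the output structure


net = paddle.jit.to_static(DuplicateOutputNet(), full_graph=True)
out1, out2 = net(paddle.randn((2, 4)))
# Both outputs refer to the same computed tensor value.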
16 changes: 16 additions & 0 deletions test/ir/pir/test_pir_to_static.py
@@ -216,5 +216,21 @@ def func(x):
np.testing.assert_allclose(out.numpy(), ans.numpy())


class TestDy2staticPir7(unittest.TestCase):
# test returning a plain Python constant from a to_static function
def test_basic_network(self):
def func(x):
x = x * 2
x = x + 1
return 1

static_func = paddle.jit.to_static(func, full_graph=True)
x = paddle.randn((2, 3, 4))
x.stop_gradient = False
ans = func(x)
out = static_func(x)
np.testing.assert_allclose(out, ans)


if __name__ == "__main__":
unittest.main()