Commit

update
feifei-111 committed Nov 16, 2023
1 parent a041eb0 commit fbc5064
Showing 1 changed file with 40 additions and 32 deletions.
python/paddle/jit/dy2static/partial_program.py (72 changes: 40 additions & 32 deletions)
@@ -223,10 +223,49 @@ def __init__(
        for var in self._inputs:
            if isinstance(var, framework.Variable):
                self._in_var_names.append(var.desc.name())

        self._out_var_descs = [
            self._outputs[var_id].desc for var_id in self._outputs.var_ids
        ]

        self._attrs = [
            'is_test',
            not self.training,
            'program_id',
            self.program_id,
        ]

        if self.training:
            # NOTE: In the case of higher-order gradient, the names of the parameter grads may be like
            # `grad/grad/grad/linear_0.w_0@GRAD` instead of simply `linear_0.w_0@GRAD`, so we get
            # the correct names of the parameter grads from program. And out grads are similar to above.
            self._attrs.extend(
                (
                    'param_grad_names',
                    self._grad_var_names.get('param', []),
                    'out_grad_names',
                    self._grad_var_names.get('out', []),
                    'x_grad_names',
                    self._grad_var_names.get('x', []),
                )
            )
        if self._cuda_graph_capture_mode:
            self._attrs.extend(
                (
                    'cuda_graph_capture_mode',
                    self._cuda_graph_capture_mode,
                    'cuda_graph_pool_id',
                    self._cuda_graph_pool_id,
                )
            )

        self._attrs.extend(
            [
                "x_names",
                self._in_var_names,
            ]
        )

    def __call__(self, inputs):
        """
        Execute static graph by Interpreter and Return dynamic Tensors.
@@ -237,7 +276,6 @@ def __call__(self, inputs):
        out_vars = self._prepare_outputs()
        self._cast_fp16_if_pure_fp16(in_vars)
        attrs = self._prepare_attributes()
        attrs.extend(["x_names", in_var_names])

        self._sync_lr_value_with_scheduler()

@@ -267,8 +305,6 @@ def sot_call(self, inputs):
        out_vars = self._prepare_outputs()
        self._cast_fp16_if_pure_fp16(inputs)
        attrs = self._prepare_attributes()
        attrs.extend(["x_names", self._in_var_names])
        self._sync_lr_value_with_scheduler()

        _legacy_C_ops.run_program(
            self._valid_vars(inputs),
@@ -770,35 +806,7 @@ def _prepare_attributes(self):
            self.forward_program.desc.block(0),
            'backward_global_block',
            self.backward_program.desc.block(0),
            'is_test',
            not self.training,
            'program_id',
            self.program_id,
        ]

        if self.training:
            # NOTE: In the case of higher-order gradient, the names of the parameter grads may be like
            # `grad/grad/grad/linear_0.w_0@GRAD` instead of simply `linear_0.w_0@GRAD`, so we get
            # the correct names of the parameter grads from program. And out grads are similar to above.
            attrs.extend(
                (
                    'param_grad_names',
                    self._grad_var_names.get('param', []),
                    'out_grad_names',
                    self._grad_var_names.get('out', []),
                    'x_grad_names',
                    self._grad_var_names.get('x', []),
                )
            )
        if self._cuda_graph_capture_mode:
            attrs.extend(
                (
                    'cuda_graph_capture_mode',
                    self._cuda_graph_capture_mode,
                    'cuda_graph_pool_id',
                    self._cuda_graph_pool_id,
                )
            )
        ] + self._attrs
        return attrs

    @switch_to_static_graph
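For context, the change above follows a common caching pattern: the key/value attribute pairs passed to _legacy_C_ops.run_program that do not vary between calls ('is_test', 'program_id', the grad-name lists, the CUDA-graph settings, and 'x_names') are now assembled once in __init__ and stored on the layer as self._attrs, while _prepare_attributes only builds the per-call block entries and concatenates the cached tail. Below is a minimal, self-contained sketch of that pattern; CachedAttrsLayer and its attribute names are illustrative assumptions, not Paddle APIs.

class CachedAttrsLayer:
    """Illustrative only: cache call-invariant attrs once, reuse them per call."""

    def __init__(self, program_id, training=True):
        self.program_id = program_id
        self.training = training
        # Built once in the constructor; every call reuses this flat
        # name/value list instead of rebuilding it.
        self._attrs = [
            'is_test',
            not self.training,
            'program_id',
            self.program_id,
        ]

    def _prepare_attributes(self, forward_block, backward_block):
        # Only the per-call pieces are assembled here; the cached tail
        # is appended unchanged.
        return [
            'forward_global_block',
            forward_block,
            'backward_global_block',
            backward_block,
        ] + self._attrs


layer = CachedAttrsLayer(program_id=0, training=False)
print(layer._prepare_attributes('fwd_block', 'bwd_block'))
# ['forward_global_block', 'fwd_block', 'backward_global_block', 'bwd_block',
#  'is_test', True, 'program_id', 0]

Splitting the attribute list this way avoids rebuilding the invariant entries on every forward call while keeping the final interleaved name/value layout unchanged.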
