Skip to content

Commit

Permalink
add version control for export (#552)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 authored Jan 8, 2025
1 parent 91885bc commit aeb5b59
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 15 deletions.
17 changes: 12 additions & 5 deletions paddlets/models/anomaly/dl/anomaly_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,8 +949,8 @@ def save(self,
layer = paddle.jit.to_static(
self._network, input_spec=input_spec)
save_name = internal_filename_map["network_model"]
paddle_version = version.parse(paddle.__version__)
if export_with_pir:
paddle_version = version.parse(paddle.__version__)
assert (
paddle_version >= version.parse('3.0.0b2') or
paddle_version == version.parse('0.0.0')
Expand All @@ -959,10 +959,17 @@ def save(self,
paddle.jit.save(layer,
os.path.join(abs_root_path, save_name))
else:
layer.forward.rollback()
with paddle.pir_utils.OldIrGuard():
layer = paddle.jit.to_static(
self._network, input_spec=input_spec)
if paddle_version >= version.parse(
'3.0.0b2') or paddle_version == version.parse(
'0.0.0'):
layer.forward.rollback()
with paddle.pir_utils.OldIrGuard():
layer = paddle.jit.to_static(
self._network, input_spec=input_spec)
paddle.jit.save(layer,
os.path.join(abs_root_path,
save_name))
else:
paddle.jit.save(layer,
os.path.join(abs_root_path,
save_name))
Expand Down
17 changes: 12 additions & 5 deletions paddlets/models/classify/dl/paddle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,8 +786,8 @@ def save(self,
layer = paddle.jit.to_static(
self._network, input_spec=input_spec)
save_name = internal_filename_map["network_model"]
paddle_version = version.parse(paddle.__version__)
if export_with_pir:
paddle_version = version.parse(paddle.__version__)
assert (
paddle_version >= version.parse('3.0.0b2') or
paddle_version == version.parse('0.0.0')
Expand All @@ -796,10 +796,17 @@ def save(self,
paddle.jit.save(layer,
os.path.join(abs_root_path, save_name))
else:
layer.forward.rollback()
with paddle.pir_utils.OldIrGuard():
layer = paddle.jit.to_static(
self._network, input_spec=input_spec)
if paddle_version >= version.parse(
'3.0.0b2') or paddle_version == version.parse(
'0.0.0'):
layer.forward.rollback()
with paddle.pir_utils.OldIrGuard():
layer = paddle.jit.to_static(
self._network, input_spec=input_spec)
paddle.jit.save(layer,
os.path.join(abs_root_path,
save_name))
else:
paddle.jit.save(layer,
os.path.join(abs_root_path,
save_name))
Expand Down
17 changes: 12 additions & 5 deletions paddlets/models/forecasting/dl/paddle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ def save(self,
layer = paddle.jit.to_static(
self._network, input_spec=input_spec)
save_name = internal_filename_map["network_model"]
paddle_version = version.parse(paddle.__version__)
if export_with_pir:
paddle_version = version.parse(paddle.__version__)
assert (
paddle_version >= version.parse('3.0.0b2') or
paddle_version == version.parse('0.0.0')
Expand All @@ -207,10 +207,17 @@ def save(self,
paddle.jit.save(layer,
os.path.join(abs_root_path, save_name))
else:
layer.forward.rollback()
with paddle.pir_utils.OldIrGuard():
layer = paddle.jit.to_static(
self._network, input_spec=input_spec)
if paddle_version >= version.parse(
'3.0.0b2') or paddle_version == version.parse(
'0.0.0'):
layer.forward.rollback()
with paddle.pir_utils.OldIrGuard():
layer = paddle.jit.to_static(
self._network, input_spec=input_spec)
paddle.jit.save(layer,
os.path.join(abs_root_path,
save_name))
else:
paddle.jit.save(layer,
os.path.join(abs_root_path,
save_name))
Expand Down

0 comments on commit aeb5b59

Please sign in to comment.