Skip to content

Commit

Permalink
fix failed to save tft model. (PaddlePaddle#355)
Browse files Browse the repository at this point in the history
  • Loading branch information
kehuo authored Jan 30, 2023
1 parent 057653b commit 308db62
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 89 deletions.
5 changes: 5 additions & 0 deletions paddlets/models/anomaly/dl/anomaly_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,7 @@ def save(self, path: str, network_model: bool = False, dygraph_to_static: bool =
optimizer = self._optimizer
network = self._network
callback_container = self._callback_container
loss_fn = self._loss_fn

# _network is inherited from a paddle-related pickle-not-serializable object, so needs to set to None.
self._network = None
Expand All @@ -828,6 +829,9 @@ def save(self, path: str, network_model: bool = False, dygraph_to_static: bool =
# _callback_container contains PaddleBaseModel instances, as PaddleBaseModel contains pickle-not-serializable
# objects `_network` and `_optimizer`, so also needs to set to None.
self._callback_container = None
# loss_fn could possibly contain paddle.Tensor when it is a bound method of a class, thus needs to set to
# None to avoid pickle.dumps failure.
self._loss_fn = None
try:
with open(abs_model_path, "wb") as f:
pickle.dump(self, f)
Expand All @@ -838,6 +842,7 @@ def save(self, path: str, network_model: bool = False, dygraph_to_static: bool =
self._optimizer = optimizer
self._network = network
self._callback_container = callback_container
self._loss_fn = loss_fn
return

@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions paddlets/models/forecasting/dl/paddle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def save(self, path: str, network_model: bool = False, dygraph_to_static = True,
optimizer = self._optimizer
network = self._network
callback_container = self._callback_container
loss_fn = self._loss_fn

# _network is inherited from a paddle-related pickle-not-serializable object, so needs to set to None.
self._network = None
Expand All @@ -228,6 +229,9 @@ def save(self, path: str, network_model: bool = False, dygraph_to_static = True,
# _callback_container contains PaddleBaseModel instances, as PaddleBaseModel contains pickle-not-serializable
# objects `_network` and `_optimizer`, so also needs to set to None.
self._callback_container = None
# loss_fn could possibly contain paddle.Tensor when it is a bound method of a class, thus needs to set to
# None to avoid pickle.dumps failure.
self._loss_fn = None
try:
with open(abs_model_path, "wb") as f:
pickle.dump(self, f)
Expand All @@ -238,6 +242,7 @@ def save(self, path: str, network_model: bool = False, dygraph_to_static = True,
self._optimizer = optimizer
self._network = network
self._callback_container = callback_container
self._loss_fn = loss_fn

return

Expand Down
Loading

0 comments on commit 308db62

Please sign in to comment.