Skip to content

Commit

Permalink
Add amp to static fix (#543)
Browse files Browse the repository at this point in the history
* add amp and to_static

* fix amp and to_static
  • Loading branch information
Sunting78 authored Nov 29, 2024
1 parent 6b09113 commit e163352
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
6 changes: 3 additions & 3 deletions paddlets/models/anomaly/dl/anomaly_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ def save(self,
network = self._network
callback_container = self._callback_container
loss_fn = self._loss_fn
if self.use_amp:
if hasattr(self, 'use_amp') and self.use_amp:
scaler = self.scaler
self.scaler = None

Expand Down Expand Up @@ -1034,8 +1034,8 @@ def save(self,
self._network = network
self._callback_container = callback_container
self._loss_fn = loss_fn
if self.use_amp:
scaler = self.scaler
if hasattr(self, 'use_amp') and self.use_amp:
self.scaler = scaler
return

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions paddlets/models/classify/dl/paddle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ def save(self,
network = self._network
callback_container = self._callback_container
loss_fn = self._loss_fn
if self.use_amp:
if hasattr(self, 'use_amp') and self.use_amp:
scaler = self.scaler
self.scaler = None

Expand Down Expand Up @@ -859,7 +859,7 @@ def save(self,
self._network = network
self._callback_container = callback_container
self._loss_fn = loss_fn
if self.use_amp:
if hasattr(self, 'use_amp') and self.use_amp:
self.scaler = scaler
return

Expand Down
5 changes: 2 additions & 3 deletions paddlets/models/forecasting/dl/paddle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def save(self,
network = self._network
callback_container = self._callback_container
loss_fn = self._loss_fn
if self.use_amp:
if hasattr(self, 'use_amp') and self.use_amp:
scaler = self.scaler
self.scaler = None

Expand Down Expand Up @@ -283,9 +283,8 @@ def save(self,
self._network = network
self._callback_container = callback_container
self._loss_fn = loss_fn
if self.use_amp:
if hasattr(self, 'use_amp') and self.use_amp:
self.scaler = scaler

return

@staticmethod
Expand Down

0 comments on commit e163352

Please sign in to comment.