Skip to content

Commit

Permalink
Do not reset Loops total counters (#8475)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Jul 19, 2021
1 parent 3628c31 commit a6fd32a
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
* Add `{,load_}state_dict` to the progress tracking dataclasses ([#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140))
* Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
* Do not reset the progress tracking dataclasses total counters ([#8475](https://github.com/PyTorchLightning/pytorch-lightning/pull/8475))


- Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431))
Expand Down
9 changes: 2 additions & 7 deletions pytorch_lightning/loops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from deprecate import void

import pytorch_lightning as pl
from pytorch_lightning.trainer.progress import BaseProgress, Tracker
from pytorch_lightning.trainer.progress import BaseProgress, Progress
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException

Expand Down Expand Up @@ -195,11 +195,6 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress:
if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[prefix + k])
if restart_progress:

def restart(tracker: Tracker):
tracker.reset_on_restart()

apply_to_collection(v, Tracker, restart)

apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart())
self.on_load_checkpoint(state_dict[prefix + "state_dict"])
self.restarting = True
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ def __setattr__(self, key: str, value: int) -> None:
raise AttributeError(f"The '{key}' attribute is meant to be unused")
return super().__setattr__(key, value)

def __repr__(self):
def __repr__(self) -> str:
# hide `None` fields
args = [f"{k}={v}" for k, v in self.__dict__.items() if v is not None]
return f"{self.__class__.__name__}({', '.join(args)})"

def reset_on_restart(self):
def reset_on_restart(self) -> None:
"""Reset the progress on restart"""
value = self.completed if self.processed is None else self.processed

Expand Down
7 changes: 3 additions & 4 deletions tests/loops/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,8 @@ def val_dataloader(self):
trainer.fit_loop.load_state_dict(checkpoint)
expected = {
"total": {
"ready": total_val_batch,
"started": total_val_batch,
"ready": total_val_batch + 1,
"started": total_val_batch + 1,
"processed": total_val_batch,
"completed": total_val_batch
},
Expand Down Expand Up @@ -555,6 +555,5 @@ def configure_optimizers_multiple(self):
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
state_dict = trainer.fit_loop.state_dict()
assert state_dict != checkpoint["loops"]["fit_loop"]
# TODO(@carmocca): do not reset for total
assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch
assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1
assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch

0 comments on commit a6fd32a

Please sign in to comment.