
[HotFix] Logging - One epoch delay on training epoch metrics. #4913

Merged (9 commits, Dec 1, 2020)
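In short: metrics logged during training with `on_epoch=True` were only aggregated after the `on_epoch_end` / `on_train_epoch_end` hooks had fired, so code in those hooks saw the previous epoch's values. The sketch below illustrates the symptom; it is adapted from the regression tests added in this PR, not part of the diff itself, and the `BoringModel` import assumes it runs inside the Lightning repo's test tree:

```python
import torch

from pytorch_lightning import Trainer
from tests.base.boring_model import BoringModel


class EpochMetricModel(BoringModel):
    def training_step(self, *args):
        # Aggregated at epoch end because of on_epoch=True.
        self.log("foo", torch.tensor(self.current_epoch), on_step=False, on_epoch=True, prog_bar=True)
        return super().training_step(*args)

    def on_train_epoch_end(self, *_):
        # Before this fix, "foo" here still held the previous epoch's
        # aggregate: the one-epoch delay the hotfix removes.
        assert self.trainer.progress_bar_dict["foo"] == self.current_epoch


trainer = Trainer(max_epochs=2, limit_train_batches=1, limit_val_batches=0,
                  logger=False, checkpoint_callback=False, weights_summary=None)
trainer.fit(EpochMetricModel())
```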
pytorch_lightning/trainer/training_loop.py (3 additions, 1 deletion)
@@ -822,9 +822,11 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None):
         self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics)
 
     def run_on_epoch_end_hook(self, epoch_output):
+        # inform logger the batch loop has finished
+        self.trainer.logger_connector.on_train_epoch_end()
+
         self.trainer.call_hook('on_epoch_end', capture=True)
         self.trainer.call_hook('on_train_epoch_end', epoch_output, capture=True)
-        self.trainer.logger_connector.on_train_epoch_end()
 
     def increment_accumulated_grad_global_step(self):
         num_accumulated_batches_reached = self._accumulated_batches_reached()
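The fix is the reordering above: `logger_connector.on_train_epoch_end()` now runs before the `on_epoch_end` and `on_train_epoch_end` hooks instead of after them, so the epoch-level aggregates of `on_epoch=True` metrics are already computed when user code in those hooks reads `trainer.progress_bar_dict` or `trainer.logged_metrics`. That ordering change is the whole hotfix; the remaining changes are test updates.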
tests/trainer/logging/test_logger_connector.py (10 additions, 14 deletions)
@@ -15,17 +15,18 @@
 Tests to ensure that the training loop works with a dict (1.0)
 """
 import os
+from copy import deepcopy
 from unittest import mock
 
-import torch
 import pytest
-from copy import deepcopy
-from pytorch_lightning.trainer import Trainer
+import torch
+
+from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.trainer import Trainer
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
-from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore
 from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
-from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base.boring_model import BoringModel, RandomDataset
 
@@ -68,13 +69,14 @@ def training_step(self, batch, batch_idx):
             self.train_losses.append(loss)
 
             self.log("train_loss", loss, on_step=True, on_epoch=True)
+
             return {"loss": loss}
 
-        def on_train_epoch_end(self, outputs):
-            # save objects as it will be reset at the end of epoch.
+        def training_step_end(self, *_):
             self.train_results = deepcopy(self.trainer.logger_connector.cached_results)
 
     model = TestModel()
+    model.training_epoch_end = None
     model.val_dataloader = None
 
     trainer = Trainer(
@@ -144,11 +146,6 @@ def __init__(self):

         @Helper.decorator_with_arguments(fx_name="training_step")
         def training_step(self, batch, batch_idx, hiddens):
-            try:
-                assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-            except Exception as e:
-                print(e)
-
             self.test_hidden = torch.rand(1)
 
             x_tensor, y_list = batch
@@ -178,8 +175,7 @@ def train_dataloader(self):
                 sampler=None,
             )
 
-        def on_train_epoch_end(self, outputs):
-            # save objects as it will be reset at the end of epoch.
+        def training_step_end(self, *_):
             self.train_results = deepcopy(self.trainer.logger_connector.cached_results)
 
     model = TestModel()
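Note on the test changes above: the fix moves the logger connector's epoch-end processing ahead of the model hooks, and the cached results are reset at the end of the epoch (as the removed comment notes), so `on_train_epoch_end` is now too late to snapshot them. Both tests therefore take their `deepcopy` of `logger_connector.cached_results` in `training_step_end`, while the cache is still populated. The tbptt hidden-state assertion, previously wrapped in a try/except that only printed failures instead of failing the test, was removed outright, and `model.training_epoch_end = None` disables the unused epoch-end path in the first test.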
tests/trainer/logging_tests/test_eval_loop_logging_1_0.py (31 additions, 0 deletions)
@@ -886,3 +886,34 @@ def get_metrics_at_idx(idx):

     expected = torch.stack(model.val_losses[4:]).mean()
     assert get_metrics_at_idx(6)["valid_loss_1"] == expected
+
+
+def test_progress_bar_dict_contains_values_on_test_epoch_end(tmpdir):
+    class TestModel(BoringModel):
+        def test_step(self, *args):
+            self.log("foo", torch.tensor(self.current_epoch), on_step=False, on_epoch=True, prog_bar=True)
+
+        def test_epoch_end(self, *_):
+            self.epoch_end_called = True
+            self.log('foo_2', torch.tensor(self.current_epoch), prog_bar=True,
+                     on_epoch=True, sync_dist=True, sync_dist_op='sum')
+
+        def on_test_epoch_end(self, *_):
+            self.on_test_epoch_end_called = True
+            assert self.trainer.progress_bar_dict["foo"] == self.current_epoch
+            assert self.trainer.progress_bar_dict["foo_2"] == self.current_epoch
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        limit_train_batches=1,
+        num_sanity_val_steps=2,
+        checkpoint_callback=False,
+        logger=False,
+        weights_summary=None,
+        progress_bar_refresh_rate=0,
+    )
+    model = TestModel()
+    trainer.test(model)
+    assert model.epoch_end_called
+    assert model.on_test_epoch_end_called
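This new test pins the same guarantee on the eval side: by the time `on_test_epoch_end` runs, `trainer.progress_bar_dict` must already hold the current epoch's values, both for a metric logged per step with `on_epoch=True` ("foo") and for one logged directly from `test_epoch_end` ("foo_2", which also exercises the `sync_dist` reduction path).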
tests/trainer/logging_tests/test_train_loop_logging_1_0.py (32 additions, 0 deletions)
@@ -753,3 +753,35 @@ def validation_step(self, batch, batch_idx):

     assert trainer.logged_metrics['foo'] == fake_result
     assert trainer.logged_metrics['bar'] == fake_result
+
+
+def test_progress_bar_dict_contains_values_on_train_epoch_end(tmpdir):
+    class TestModel(BoringModel):
+        def training_step(self, *args):
+            self.log("foo", torch.tensor(self.current_epoch), on_step=False, on_epoch=True, prog_bar=True)
+            return super().training_step(*args)
+
+        def on_epoch_end(self):
+            self.epoch_end_called = True
+            self.log('foo_2', torch.tensor(self.current_epoch), prog_bar=True,
+                     on_epoch=True, sync_dist=True, sync_dist_op='sum')
+
+        def on_train_epoch_end(self, *_):
+            self.on_train_epoch_end_called = True
+            assert self.trainer.progress_bar_dict["foo"] == self.current_epoch
+            assert self.trainer.progress_bar_dict["foo_2"] == self.current_epoch
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        limit_train_batches=1,
+        limit_val_batches=0,
+        checkpoint_callback=False,
+        logger=False,
+        weights_summary=None,
+        progress_bar_refresh_rate=0,
+    )
+    model = TestModel()
+    trainer.fit(model)
+    assert model.epoch_end_called
+    assert model.on_train_epoch_end_called
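The training-side counterpart of the previous test: it additionally verifies that a value logged from the generic `on_epoch_end` hook shows up in `progress_bar_dict` within the same epoch, and `limit_val_batches=0` keeps validation out of the run so only the training epoch-end path is exercised.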