
Ordering of hooks #8670

Closed
mmgxa opened this issue Aug 2, 2021 · 9 comments
Labels
bug (Something isn't working), help wanted (Open to be worked on), working as intended (Working as intended)

Comments

@mmgxa

mmgxa commented Aug 2, 2021

🐛 Bug

In PL 1.4, the order of hooks has changed.

In PL 1.3.8, it was:

on_train_epoch_start
training_step
training_step
training_step
training_step
training_epoch_end
on_epoch_end
on_validation_epoch_start
validation_step
validation_step
validation_step
validation_step
validation_epoch_end
on_epoch_end

Now, in PL 1.4, it is:

on_train_epoch_start
training_step
training_step
training_step
training_step
on_validation_epoch_start
validation_step
validation_step
validation_step
validation_step
validation_epoch_end
on_epoch_end
training_epoch_end
on_epoch_end

i.e. training_epoch_end now runs after validation_epoch_end instead of right after the last training_step, which doesn't make sense since on_epoch_end is 'just next to it'. Also, note the proximity of the two on_epoch_end calls in PL 1.4.
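The reordering described above can be illustrated with a small framework-agnostic sketch. This is a hypothetical driver that just records hook names in the PL 1.4 order listed in this report; it is not Lightning's actual loop code:

```python
# Hypothetical sketch: simulate the PL 1.4 hook call order reported above.
def simulate_fit_epoch_pl14(num_train_batches, num_val_batches):
    calls = ["on_train_epoch_start"]
    calls += ["training_step"] * num_train_batches
    # Validation now runs BEFORE training_epoch_end:
    calls.append("on_validation_epoch_start")
    calls += ["validation_step"] * num_val_batches
    calls.append("validation_epoch_end")
    calls.append("on_epoch_end")        # validation's on_epoch_end
    calls.append("training_epoch_end")  # fires only after validation
    calls.append("on_epoch_end")        # training's on_epoch_end
    return calls
```

Running this makes the complaint concrete: `training_epoch_end` appears later in the list than `validation_epoch_end`, and the two `on_epoch_end` entries sit back to back at the end.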

To Reproduce

You can use the following Colab link:
https://colab.research.google.com/github/mmg10/pl_bug/blob/main/pl_bug_138.ipynb

https://colab.research.google.com/github/mmg10/pl_bug/blob/main/pl_bug_140.ipynb

Environment

PyTorch Lightning 1.3.8 and 1.4.0 respectively

Significance

In PL 1.3.8, we could get the average of training loss across batches via

def training_epoch_end(self, outputs):
    self.avg_train_loss = torch.stack([x['loss'] for x in outputs]).mean().item()

but now we can't. Note that we can still run the following:

def validation_epoch_end(self, outputs):
     avg_valid_loss = torch.stack([x['loss'] for x in outputs]).mean().item()

since validation_epoch_end is still preceded by the last validation_step.

@mmgxa mmgxa added the bug (Something isn't working) and help wanted (Open to be worked on) labels Aug 2, 2021
@tchaton
Contributor

tchaton commented Aug 2, 2021

@carmocca Any idea here?

@tchaton tchaton added the priority: 0 High priority task label Aug 2, 2021
@Borda
Member

Borda commented Aug 2, 2021

looks similar to #8654

@carmocca carmocca added the working as intended label and removed the priority: 0 (High priority task) label Aug 2, 2021
@carmocca
Contributor

carmocca commented Aug 2, 2021

The order was changed in #7357. See the linked PR for its reasoning.

but now we can't

Can you elaborate on why you can't anymore? Is it because you use the loss keyword during both training and validation?

@mmgxa
Author

mmgxa commented Aug 3, 2021

@carmocca
I mentioned it above. The outputs dictionary contains the losses for all batches. It still does for validation_epoch_end, since that runs right after the last validation_step, but not anymore for training_epoch_end.

@ananthsub
Contributor

@mmgxa - is it preferable for you to track the per-step results and reduce them in on_train_epoch_end or on_validation_epoch_end as you please? What are your thoughts on #8690?

@tchaton
Contributor

tchaton commented Aug 3, 2021

Dear @mmgxa,

I am not sure I follow how you can't get the loss in training_epoch_end.

This seems to work fine:

from pytorch_lightning import Trainer, seed_everything
# BoringModel ships with Lightning's test helpers; the import path has
# varied across versions (e.g. tests.helpers.BoringModel).


def test_epoch_end_hooks(tmpdir):
    seed_everything(42)

    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            loss = super().training_step(batch, batch_idx)
            loss["batch_idx"] = batch_idx
            return loss

        def validation_step(self, batch, batch_idx):
            loss = super().training_step(batch, batch_idx)
            loss["batch_idx"] = -1 * batch_idx
            return loss

        def training_epoch_end(self, outputs) -> None:
            assert sum(x["loss"] for x in outputs).item() == 12.22606086730957
            assert sum(x["batch_idx"] for x in outputs) == sum(range(5))

        def validation_epoch_end(self, outputs) -> None:
            assert sum(x["loss"] for x in outputs).item() == 10.310195922851562
            assert sum(x["batch_idx"] for x in outputs) == -1 * sum(range(3))

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=5,
        limit_val_batches=3,
        num_sanity_val_steps=0,
    )
    trainer.fit(model)

@mmgxa
Author

mmgxa commented Aug 4, 2021

@tchaton Not sure what this code does. But take a look at the following two notebooks and please try to explain why the results are different:

https://colab.research.google.com/github/mmg10/pl_bug/blob/main/pl_test_138.ipynb
https://colab.research.google.com/github/mmg10/pl_bug/blob/main/pl_test_140.ipynb

(Both train/valid loss should be the same as in the third cell, which is the case for 1.3.8 but not for 1.4.0. In PL 1.4, the train loss is 0 for the first epoch, which is wrong, and in the second epoch it reports the loss for the second step/batch only!)

@mmgxa
Author

mmgxa commented Aug 4, 2021

@ananthsub But on_train_epoch_end doesn't support the outputs parameter, so the loss can't be averaged like in training_epoch_end. One needs to add a variable in __init__ to keep track of it.

My thoughts on #8690? Well, since it doesn't make sense to have training_epoch_end run after the validation steps (just think about the blocks in the first comment), yeah, it's better to remove them once and for all 😏
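The "variable in `__init__`" workaround mentioned above can be sketched in plain Python (hypothetical names, no Lightning dependency): accumulate per-step losses in an attribute, then reduce and reset in the epoch-end hook, which receives no `outputs` argument.

```python
# Hedged sketch of manual loss tracking, since on_train_epoch_end
# receives no `outputs` argument. Class and method names are hypothetical.
class LossTracker:
    def __init__(self):
        self.train_losses = []  # filled by training_step each epoch

    def training_step(self, loss):
        self.train_losses.append(loss)
        return loss

    def on_train_epoch_end(self):
        # Reduce the accumulated losses, then reset for the next epoch.
        avg = sum(self.train_losses) / len(self.train_losses)
        self.train_losses.clear()
        return avg
```

In a real LightningModule the same pattern would append `loss.detach()` in `training_step` and compute the mean in `on_train_epoch_end`.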

@awaelchli
Contributor

@mmgxa So the reason you are seeing different behavior is, as you said, that the hook order changed. You were computing self.avg_train_loss and then referencing it (printing it) in validation_epoch_end. The only solution I can suggest right now is to compute a running average directly in training_step, so the value is already available in the validation hooks.
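That suggestion can be sketched framework-agnostically: maintain the mean incrementally inside the step so no epoch-end reduction is needed. The names here are hypothetical, not a Lightning API.

```python
# Hedged sketch: incremental running mean updated on every training step,
# so the current average is readable at any point (e.g. during validation).
class RunningMean:
    def __init__(self):
        self.mean = 0.0
        self.count = 0

    def update(self, value):
        # Incremental mean: m_k = m_{k-1} + (x_k - m_{k-1}) / k
        self.count += 1
        self.mean += (value - self.mean) / self.count
        return self.mean
```

Inside a `training_step`, one would call `update(loss.item())` and read `.mean` from the validation hooks.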

@mmgxa mmgxa closed this as completed Aug 5, 2021
davidgill97 added a commit to davidgill97/LightlySSL that referenced this issue Nov 10, 2023