Bug and question about logging -- missing epoch, validation before train? #1520

Closed
guydav opened this issue Apr 17, 2020 · 15 comments
Labels
bug Something isn't working help wanted Open to be worked on

Comments

@guydav
Contributor

guydav commented Apr 17, 2020

🐛 Bug

First, the clear bug: in TrainerLoggingMixin.log_metrics() the epoch is added to the metrics variable (line 70), which is never accessed again. Shouldn't it be added to scalar_metrics instead?
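To make the first bug concrete, here is a simplified, hypothetical model of the pattern (not the actual Lightning source): the metrics are copied into a new scalar_metrics dict, and only that dict is handed to the logger, so a key added to the original metrics dict afterwards never reaches it. The function names are illustrative.

```python
from typing import Dict


def to_scalars(metrics: Dict[str, float]) -> Dict[str, float]:
    # Copy each value into a fresh dict of plain floats, as the logger would receive them.
    return {key: float(value) for key, value in metrics.items()}


def log_metrics_buggy(metrics: Dict[str, float], epoch: int) -> Dict[str, float]:
    scalar_metrics = to_scalars(metrics)
    metrics["epoch"] = epoch  # written to a dict that is never read again
    return scalar_metrics  # what the logger actually receives


def log_metrics_fixed(metrics: Dict[str, float], epoch: int) -> Dict[str, float]:
    scalar_metrics = to_scalars(metrics)
    scalar_metrics["epoch"] = epoch  # added to the dict that is passed on
    return scalar_metrics


print(log_metrics_buggy({"train_loss": 0.5}, epoch=0))  # {'train_loss': 0.5}
print(log_metrics_fixed({"train_loss": 0.5}, epoch=0))  # {'train_loss': 0.5, 'epoch': 0}
```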

Second, a question: I implemented a very primitive logger (to stdout) and log to it. I don't get the training results for the first epoch until after that epoch's validation step, and consequently I never get the training metrics for the last epoch. See the code and sample output below. Does this make sense?

To Reproduce

Add the following code to a Lightning Module and run a trainer with the following logger:

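import torch
from pytorch_lightning import Trainer

use_gpu = int(torch.cuda.is_available())
print_logger = PrintLogger()  # the minimal logger defined below
trainer = Trainer(gpus=use_gpu, max_epochs=5, logger=print_logger)
trainer.fit(model)  # model: the LightningModule containing the hooks below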

Code sample

Minimal logging in the LightningModule:

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['acc'] for x in outputs]).mean()
        logs = dict(train_loss=avg_loss, train_acc=avg_acc)
        return dict(log=logs)

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['acc'] for x in outputs]).mean()
        logs = dict(val_loss=avg_loss, val_acc=avg_acc)
        return dict(log=logs)

A minimal logger:

from pytorch_lightning.loggers import LightningLoggerBase, rank_zero_only

class PrintLogger(LightningLoggerBase):
    
    def __init__(self):
        super(PrintLogger, self).__init__()
    
    @property
    def name(self):
        return 'Test'
    
    @property
    def experiment(self):
        return self.name  # name is a property, so it must not be called
    
    @property
    def version(self):
        return '0.0.1'
    
    @rank_zero_only
    def log_hyperparams(self, params):
        print(f'Hyperparameters:\n{params}')

    @rank_zero_only
    def log_metrics(self, metrics, step):
        if metrics is not None and len(metrics.keys()) > 0:
            print(f'{step}: {metrics}')

    def save(self):
        # Optional. Any code necessary to save logger data goes here
        pass

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to be run after training
        # finishes goes here
        pass

Expected behavior

I would expect to see the training output for each epoch followed by the validation output for that epoch, for each of the five epochs. Instead, I see the following: four training outputs and five validation outputs, with the validation output appearing first:

Observed behavior:

63: {'val_loss': 0.6922042369842529, 'val_acc': 0.51458740234375}
64: {'train_acc': 0.503265380859375, 'train_loss': 1.0884952545166016}
127: {'val_loss': 0.6919643878936768, 'val_acc': 0.51861572265625}
128: {'train_acc': 0.51318359375, 'train_loss': 0.6927268505096436}
191: {'val_loss': 0.6915570497512817, 'val_acc': 0.526611328125}
192: {'train_acc': 0.5161285400390625, 'train_loss': 0.6924755573272705}
255: {'val_loss': 0.6915992498397827, 'val_acc': 0.52325439453125}
256: {'train_acc': 0.5159149169921875, 'train_loss': 0.6921626329421997}
319: {'val_loss': 0.6915264129638672, 'val_acc': 0.521240234375}

Expected behavior:

Where n is the number of steps/batches per epoch:

n-1: {'train_acc': ..., 'train_loss': ...}
n-1: {'val_loss': ..., 'val_acc': ...}
2n-1: {'train_acc': ..., 'train_loss': ...}
2n-1: {'val_loss': ..., 'val_acc': ...}
3n-1: {'train_acc': ..., 'train_loss': ...}
3n-1: {'val_loss': ..., 'val_acc': ...}
...
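The expected schedule can be written out for a given number of batches per epoch n. The observed validation steps above (63, 127, 191, ...) are 64 apart, which suggests n = 64 in this run. The helper name is illustrative, not part of Lightning.

```python
from typing import List, Tuple


def expected_log_steps(n: int, epochs: int) -> List[Tuple[int, str]]:
    """Return (step, split) pairs: train, then val, both at each epoch's last step."""
    schedule = []
    for epoch in range(1, epochs + 1):
        last_step = epoch * n - 1  # global index of the epoch's final batch
        schedule.append((last_step, "train"))
        schedule.append((last_step, "val"))
    return schedule


print(expected_log_steps(64, 2))
# [(63, 'train'), (63, 'val'), (127, 'train'), (127, 'val')]
```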

Environment

cuda:
	GPU:
	available:           False
	version:             10.0.130
packages:
	numpy:               1.18.2
	pyTorch_debug:       False
	pyTorch_version:     1.3.1
	pytorch-lightning:   0.7.3
	tensorboard:         2.2.0
	tqdm:                4.45.0
system:
	OS:                  Linux
	architecture:        64bit
	processor:           x86_64
	python:              3.7.4
	version:             #1 SMP Tue Feb 4 23:02:59 UTC 2020

Additional context

@guydav guydav added bug Something isn't working help wanted Open to be worked on labels Apr 17, 2020
@github-actions
Contributor

Hi! Thanks for your contribution, great first issue!

@Borda
Member

Borda commented Apr 17, 2020

Just to clarify your bug: you are missing some metrics? If I understand correctly, #1459 should be your fix.

@guydav
Contributor Author

guydav commented Apr 17, 2020

Hi @Borda, I don't think that captures it. I'm reporting one clear bug: the epoch field doesn't reach the logger, since it's added to the metrics dict rather than the scalar_metrics dict in TrainerLoggingMixin.log_metrics().

The second bug I'm reporting may not be a bug, but at the very least it is confusing behavior. As a user, I would expect my logger to get the training results for an epoch first, and then the validation results. The PR you're referencing might solve the problem of the last training results not being returned at all, but it won't fix the fact that something in the way steps are used in the logger causes validation results to arrive before training results.

@guydav
Contributor Author

guydav commented Apr 17, 2020

I updated under 'Expected Behavior' to clarify the discrepancy I found.

@Borda
Member

Borda commented Apr 17, 2020

@guydav mind sending a PR? :]

@guydav
Contributor Author

guydav commented Apr 17, 2020

@Borda for the first issue, sure, it's a one-line fix, but I'd need to set things up to run the tests.

For the second issue, I don't quite know where it originates. I was hoping someone with a better understanding of how logging is structured would chime in before I try to learn my way around the entire codebase.

@stale

stale bot commented Jun 16, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the won't fix This will not be worked on label Jun 16, 2020
@guydav
Contributor Author

guydav commented Jun 17, 2020

Hi @Borda, I think there's still an underlying issue with how metrics are reported, unless someone has touched this code recently.

If you look at the observed behavior and expected behavior steps above, do you agree there's an issue? That it makes no sense for the validation metrics for a particular epoch to report before the training metrics for that epoch, and with a different step at that?

@stale stale bot removed the won't fix This will not be worked on label Jun 17, 2020
@awaelchli
Contributor

awaelchli commented Jul 14, 2020

This behaviour seems to have changed.
I copied your logger into pl_examples/gpu_template.py and launched it with

    trainer = Trainer(
        max_epochs=2, 
        gpus=1,
        logger=PrintLogger(),
        limit_train_batches=10,
        limit_val_batches=10,
        row_log_interval=1,
        progress_bar_refresh_rate=0
    )

The output is:

0: {'train_loss': 2.503892660140991, 'epoch': 0}
1: {'train_loss': 2.096820831298828, 'epoch': 0}
2: {'train_loss': 8.215052604675293, 'epoch': 0}
3: {'train_loss': 5.370606422424316, 'epoch': 0}
4: {'train_loss': 5.988080978393555, 'epoch': 0}
5: {'train_loss': 2.3805108070373535, 'epoch': 0}
6: {'train_loss': 4.3501176834106445, 'epoch': 0}
7: {'train_loss': 9.668755531311035, 'epoch': 0}
8: {'train_loss': 6.58243465423584, 'epoch': 0}

# this is the last step of the epoch, metrics get combined and logged together
9: {'epoch': 0.0, 'val_loss': 4.287566661834717, 'train_loss': 12.217967987060547, 'val_acc': 0.515625}

10: {'train_loss': 1.7836229801177979, 'epoch': 1}
11: {'train_loss': 1.7488218545913696, 'epoch': 1}
12: {'train_loss': 2.221280097961426, 'epoch': 1}
13: {'train_loss': 3.4499270915985107, 'epoch': 1}
14: {'train_loss': 3.5983619689941406, 'epoch': 1}
15: {'train_loss': 2.813007116317749, 'epoch': 1}
16: {'train_loss': 3.2659897804260254, 'epoch': 1}
17: {'train_loss': 4.156956672668457, 'epoch': 1}
18: {'train_loss': 2.931321859359741, 'epoch': 1}
# no val logs here :( we expect a dict as in step 9

The original problem you describe seems to be gone, but I notice two other issues:

  • at step 9, the epoch is a float 0.0
  • the validation metrics of epoch 2 (last one) do not get logged
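One plausible explanation for the float epoch at step 9, offered as an assumption rather than a diagnosis: when several metric dicts arrive for the same step, the logger merges them and averages each key (as far as I can tell, the logger base class defaulted to a mean for this). Averaging integer epochs with true division yields a float. merge_same_step is a hypothetical helper, not a Lightning function.

```python
from typing import Dict, List


def merge_same_step(dicts: List[Dict[str, float]]) -> Dict[str, float]:
    """Merge metric dicts logged at one step by averaging the values of each key."""
    collected: Dict[str, List[float]] = {}
    for metrics in dicts:
        for key, value in metrics.items():
            collected.setdefault(key, []).append(value)
    # True division always returns a float, so an int epoch 0 becomes 0.0.
    return {key: sum(values) / len(values) for key, values in collected.items()}


merged = merge_same_step([
    {"train_loss": 12.2, "epoch": 0},
    {"val_loss": 4.29, "val_acc": 0.52, "epoch": 0},
])
print(merged["epoch"])  # 0.0
```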

@guydav
Contributor Author

guydav commented Jul 15, 2020

@awaelchli, thank you for looking into it again! I agree this looks better. It's been a while since I dug through this code, but I have a hunch about at least one of these issues. Note that in the second issue you point out, we're also missing the last set of train metrics, which should arrive together with the validation metrics.

Reading through LightningLoggerBase, its API seems to be the function agg_and_log_metrics, which is called from the TrainerLoggingMixin. agg_and_log_metrics calls _aggregate_metrics, which only emits the buffered metrics once the current step differs from the previous step. Since this function is never called again after the last validation epoch, it never sees a new step and therefore never emits the last output. It looks like a call to any of save, finalize, or close should result in a call to finalize_agg_metrics, which should do the trick. So either it's not getting called, or, for some reason, it doesn't do what it should.
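The flush-on-step-change behavior described here can be modeled with a small buffer. This is a simplified sketch of the mechanism as described, not Lightning's actual _aggregate_metrics / finalize_agg_metrics code; the class and method names are hypothetical, and same-step dicts are merged by a plain update. Metrics are held until a call arrives with a different step, so the last step is only emitted if something explicitly flushes the buffer at the end.

```python
from typing import Dict, List, Optional, Tuple


class StepBufferedLogger:
    """Hold metrics for the current step; emit them only when the step changes."""

    def __init__(self) -> None:
        self.emitted: List[Tuple[Optional[int], Dict[str, float]]] = []
        self._step: Optional[int] = None
        self._buffer: Dict[str, float] = {}

    def agg_and_log(self, metrics: Dict[str, float], step: int) -> None:
        if self._step is not None and step != self._step:
            self._flush()  # a new step means the previous one is complete
        self._step = step
        self._buffer.update(metrics)

    def finalize(self) -> None:
        self._flush()  # without this, the final step is never emitted

    def _flush(self) -> None:
        if self._buffer:
            self.emitted.append((self._step, dict(self._buffer)))
            self._buffer.clear()


logger = StepBufferedLogger()
logger.agg_and_log({"train_loss": 1.0}, step=9)
logger.agg_and_log({"val_loss": 2.0}, step=9)
logger.agg_and_log({"train_loss": 0.5}, step=19)
print(logger.emitted)  # [(9, {'train_loss': 1.0, 'val_loss': 2.0})]
logger.finalize()
print(logger.emitted[-1])  # (19, {'train_loss': 0.5})
```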

The float epoch is probably the smaller issue. I'll try to debug both of these later today if I have time.

@awaelchli
Contributor

@guydav I checked again: step 8 was missing from my post by accident. I had to copy and paste around some warnings that were printed to the console, and it seems I missed a line. Step 8 is there, and I have edited my post.

@guydav
Contributor Author

guydav commented Jul 15, 2020

Oh, I mean that we're missing step 19, which contains both the 10th training batch from the second epoch and the validation metrics for that epoch.

@awaelchli
Contributor

Yes, I agree, that's the big one :) It should definitely log a dict like the one at step 9.

@awaelchli awaelchli self-assigned this Jul 15, 2020
@guydav
Contributor Author

guydav commented Jul 18, 2020

Update: it appears that I am the problem. For some reason, I overrode save and finalize above to do nothing. Removing those overrides (commenting them out, or replacing their bodies with a super() call) makes everything work. I honestly have no idea how that happened, but now everything looks fine. I don't see the floating-point epoch either:

Here's a printout:

0: {'train_loss': 2.517963171005249, 'epoch': 0}
1: {'train_loss': 2.1298298835754395, 'epoch': 0}
2: {'train_loss': 8.561811447143555, 'epoch': 0}
3: {'train_loss': 5.23430871963501, 'epoch': 0}
4: {'train_loss': 6.442159175872803, 'epoch': 0}
5: {'train_loss': 2.1811487674713135, 'epoch': 0}
6: {'train_loss': 4.158588409423828, 'epoch': 0}
7: {'train_loss': 10.028255462646484, 'epoch': 0}
8: {'train_loss': 6.593491077423096, 'epoch': 0}
9: {'val_loss': 4.531818389892578, 'val_acc': 0.453125, 'epoch': 0}
9: {'train_loss': 10.541756629943848, 'epoch': 0}
10: {'train_loss': 1.6655378341674805, 'epoch': 1}
11: {'train_loss': 2.284700393676758, 'epoch': 1}
12: {'train_loss': 2.4957871437072754, 'epoch': 1}
13: {'train_loss': 4.456875324249268, 'epoch': 1}
14: {'train_loss': 4.337017059326172, 'epoch': 1}
15: {'train_loss': 3.4667391777038574, 'epoch': 1}
16: {'train_loss': 3.3742592334747314, 'epoch': 1}
17: {'train_loss': 3.353729248046875, 'epoch': 1}
18: {'train_loss': 2.8706002235412598, 'epoch': 1}
19: {'val_loss': 4.367581844329834, 'val_acc': 0.571875, 'epoch': 1}
19: {'train_loss': 4.163558483123779, 'epoch': 1}
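The root cause generalizes to any base class that does real work in finalize: overriding it with pass silently discards that work, while calling super() keeps it. Below is a minimal, self-contained illustration using a hypothetical BufferingBase class rather than Lightning itself. In the PrintLogger above, the equivalent fix is, as found here, to delete the save/finalize overrides or have them call super().save() and super().finalize(status).

```python
from typing import Dict, List


class BufferingBase:
    """Hypothetical base logger: metrics wait in a buffer until finalize() flushes them."""

    def __init__(self) -> None:
        self.buffer: List[Dict[str, float]] = []
        self.written: List[Dict[str, float]] = []

    def log(self, metrics: Dict[str, float]) -> None:
        self.buffer.append(metrics)

    def finalize(self, status: str) -> None:
        self.written.extend(self.buffer)  # flush whatever is still pending
        self.buffer.clear()


class NoOpFinalize(BufferingBase):
    def finalize(self, status: str) -> None:
        pass  # the problematic override: pending metrics are never flushed


class SuperFinalize(BufferingBase):
    def finalize(self, status: str) -> None:
        # Any custom teardown would go here; then defer to the base class.
        super().finalize(status)


for cls in (NoOpFinalize, SuperFinalize):
    demo = cls()
    demo.log({"val_loss": 4.37, "epoch": 1})
    demo.finalize("success")
    print(cls.__name__, demo.written)
# NoOpFinalize []
# SuperFinalize [{'val_loss': 4.37, 'epoch': 1}]
```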

@guydav guydav closed this as completed Jul 18, 2020
@awaelchli
Contributor

awaelchli commented Jul 18, 2020

Oh great, you found this. Last time I tried to debug it I got stuck, because we actually have tests for these things and I was very confused about why it wouldn't work :)
Thanks @guydav

3 participants