
Refactor callbacks #776

Merged · 13 commits · Feb 16, 2020

Conversation

kuynzereb (Contributor)

This PR does some refactoring of callbacks:

  1. Added on_validation_begin() and on_validation_end() to callbacks. The point is that ModelCheckpoint implements on_epoch_end() but is actually called at validation end; this is now fixed.
  2. All callback calls are unified and take no additional arguments. That is, from now on it is on_epoch_end() instead of on_epoch_end(epoch, logs). Instead of these extra arguments, each callback now holds a reference to the trainer, so it has access to current_epoch, global_step, callback_metrics, and so on. We just need to call self.callback.set_trainer(self) while initializing callbacks in the trainer.

With these modifications it becomes easy to implement things like an additional checkpointing callback that uses on_epoch_end() instead of on_validation_end() and can therefore checkpoint training runs that have no validation loop (#596, #652). It also becomes easy to start using global_step in checkpoint names. And in general, all callbacks become more uniform.
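A minimal sketch of the resulting interface (the hook names are the ones this PR adds; the trainer attributes are the ones listed above):

```python
class Callback:
    """Abstract base class used to build new callbacks."""

    def __init__(self):
        self._trainer = None

    def set_trainer(self, trainer):
        # Called by the Trainer while initializing callbacks; every hook can
        # then read trainer state instead of receiving it as arguments.
        self._trainer = trainer

    def on_epoch_end(self):
        # No more (epoch, logs): read self._trainer.current_epoch,
        # self._trainer.callback_metrics, etc. instead.
        pass

    def on_validation_begin(self):
        pass

    def on_validation_end(self):
        # ModelCheckpoint now runs here instead of in on_epoch_end().
        pass
```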

Borda added the labels feature (Is an improvement or enhancement) and help wanted (Open to be worked on) on Jan 31, 2020
Borda added this to the 0.6.1 milestone on Jan 31, 2020
williamFalcon (Contributor)

@Borda @neggert good to go?

Borda (Member) commented Feb 1, 2020

It is quite an extensive change; I'll have a look tomorrow...

Seven review threads on pytorch_lightning/callbacks/pt_callbacks.py (five marked outdated, all resolved).
hadim mentioned this pull request on Feb 10, 2020
hadim (Contributor) commented Feb 11, 2020

I am planning to move the progress bar to a callback and I'll need this PR. Here are a few suggestions:

  • Could we have a test_start and test_end?
  • Could we move each callback into a separate file so maintenance and history are easier?
  • Could we unify how callbacks are passed to the trainer? It would be useful to generalize and pass a list of callbacks instead of individual callbacks. During each event, the trainer should call all the callbacks; each callback is then free to add the logic to act or not (see the sketch after this comment).

Let me know what you think.
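An illustrative sketch of that third suggestion (hypothetical names, not code from this PR): the trainer keeps a single list of callbacks and fires every event on all of them.

```python
class Trainer:
    def __init__(self, callbacks=None):
        # one generic list instead of individual callback arguments
        self.callbacks = callbacks or []
        for callback in self.callbacks:
            callback.set_trainer(self)

    def _fire_event(self, hook_name):
        # every callback receives every event; each one decides
        # internally whether to act or do nothing
        for callback in self.callbacks:
            getattr(callback, hook_name)()

# e.g. Trainer(callbacks=[ModelCheckpoint(...), ProgressBar()]) and then
# trainer._fire_event('on_epoch_end') at the right point in the loop
```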

Borda (Member) commented Feb 11, 2020

I am planning to move the progress bar to a callback and I'll need this PR. Here are a few suggestions.

  • Could we have a test_start and test_end?

Yes

  • Could we move each callback into a separate file so maintenance and history are easier?

Good one :]

  • Could we unify how callbacks are passed to the trainer? It would be useful to generalize and pass a list of callbacks instead of individual callbacks. During each event, the trainer should call all the callbacks; each callback is then free to add the logic to act or not.

That would be nice...

Let me know what you think.

@williamFalcon ^^

kuynzereb (Contributor, Author)

I can implement the first two suggestions. I fully agree with the third one and have thought about it myself, but I am afraid it is not very straightforward with the current implementation of callbacks, and right now I don't have time to delve into it. So I would like to implement only the first two suggestions in this PR.

hadim (Contributor) commented Feb 11, 2020

That would be great. I make no promises, but I can try to tackle the third one.

Borda (Member) commented Feb 11, 2020

@kuynzereb @hadim I think it would be much easier to get these in if each suggestion were its own PR... large/complex PRs are not so nice for reviewers (they take much longer to check) nor for the author (debugging may become quite complex) :]

kuynzereb (Contributor, Author)

@Borda that also sounds very reasonable :)

hadim (Contributor) commented Feb 11, 2020

No problem making separate PRs, but I think something like test_start and test_end could easily be added to this one without too much burden.

Borda (Member) left a review comment:

Really valuable contribution, just a few (rather formatting) comments


```diff
 class Callback(object):
-    r"""Abstract base class used to build new callbacks.
-    """
+    """Abstract base class used to build new callbacks."""
```
Borda (Member):

As abstract, inherit from ABC... class Callback(ABC)? But if it is ABC, then you have to implement all the methods all the time... :/

kuynzereb (Contributor, Author):

I am not familiar with all this ABC stuff, so I don't really know :) But anyway, I think it would be better to do that in another PR (if it really should be done).
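For context, a minimal illustration of the trade-off being discussed (not code from this PR): with ABC, only methods marked @abstractmethod are mandatory to implement, so the practical choice is between forcing overrides and shipping no-op defaults.

```python
from abc import ABC, abstractmethod

class AbstractCallback(ABC):
    @abstractmethod
    def on_epoch_end(self):
        """Every subclass is forced to implement this hook."""

class PlainCallback:
    def on_epoch_end(self):
        pass  # no-op default: subclasses override only the hooks they need
```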


```python
    pass


_no_trainer_error_msg = ".set_trainer() should be called after the callback initialization"
```
Borda (Member) suggested:

```python
_NO_TRAINER_ERROR_MSG = "Missing trainer instance. The `.set_trainer(...)` should be called after the callback initialization."
```

kuynzereb (Contributor, Author):

Done

```diff
         # Allow instances to be re-used
         self.wait = 0
         self.stopped_epoch = 0
         self.best = np.Inf if self.monitor_op == np.less else -np.Inf

-    def on_epoch_end(self, epoch, logs=None):
+    def on_epoch_end(self):
+        assert self._trainer is not None, _no_trainer_error_msg
```
Borda (Member):

I would move this assert to a parent class and just call the parent method at the beginning of each hook?

kuynzereb (Contributor, Author):

I am not sure about this. Theoretically there may be some calls, or even entire callbacks, that don't use the trainer at all, so there would be no need for this assert.

Borda (Member):

True... let's keep it for now.
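For reference, a sketch of the set-aside alternative (hypothetical code): the assert would live once in the base class, and each hook that needs the trainer would call the parent method first.

```python
_no_trainer_error_msg = ".set_trainer() should be called after the callback initialization"

class Callback:
    def __init__(self):
        self._trainer = None

    def on_epoch_end(self):
        # base implementation only validates shared state
        assert self._trainer is not None, _no_trainer_error_msg

class ModelCheckpoint(Callback):
    def on_epoch_end(self):
        super().on_epoch_end()  # run the shared assert first
        # ... checkpointing logic using self._trainer ...
```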

```diff
@@ -231,8 +231,12 @@ def mock_save_function(filepath):
     # CASE K=-1 (all)
     w = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1)
     w.save_function = mock_save_function
+    trainer = Trainer()
+    w.set_trainer(trainer)
```
Borda (Member):

Use a longer variable name than one letter.

kuynzereb (Contributor, Author):

Done

```diff
     for i, loss in enumerate(losses):
-        w.on_epoch_end(i, logs={'val_loss': loss})
+        w._trainer.current_epoch = i
+        w._trainer.callback_metrics = {'val_loss': loss}
```
Borda (Member):

Add a comment that this is a kind of hack to simulate training...

kuynzereb (Contributor, Author):

Done
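With the comment added, the loop reads roughly as follows (a sketch: the no-argument w.on_epoch_end() call is implied by the new interface but not visible in the excerpt above):

```python
# hack: simulate training by writing the state the callback reads
# directly onto the trainer attached to the checkpoint callback
for i, loss in enumerate(losses):
    w._trainer.current_epoch = i
    w._trainer.callback_metrics = {'val_loss': loss}
    w.on_epoch_end()  # new signature: no (epoch, logs) arguments
```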

Borda (Member) commented Feb 13, 2020

@kuynzereb could you push a commit so it triggers the new GitHub CI?

Borda added the ready (PRs ready to be merged) label on Feb 14, 2020
Borda (Member) commented Feb 14, 2020

Great work, thx

Borda (Member) commented Feb 14, 2020

@williamFalcon it was nice to see that GH Actions were about twice as fast as Travis lol

williamFalcon (Contributor)

@Borda can you close the appropriate tickets related to this PR?

Borda (Member) commented Feb 16, 2020

I don't see any particular issue for this, but I will check the backlog later... @kuynzereb was this change requested in an issue?

kuynzereb (Contributor, Author)

  @kuynzereb was this change requested in an issue?

Nope, it was not.

kuynzereb deleted the new_callback_entry_points branch on February 16, 2020 08:50
jeremyjordan (Contributor)

@kuynzereb you mentioned:

  All callback calls are unified and take no additional arguments. That is, from now on it is on_epoch_end() instead of on_epoch_end(epoch, logs). Instead of these extra arguments, each callback now holds a reference to the trainer, so it has access to current_epoch, global_step, callback_metrics, and so on.

I'm wondering, was this on_batch_start intended to still have a batch argument?

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/training_loop.py#L462

kuynzereb (Contributor, Author)

@jeremyjordan, well, that is not a callback but a model hook, and I actually didn't think about those while doing this PR. But yes, the concept of model hooks is very similar to the concept of callbacks, so maybe they should be unified as well.
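To make the distinction concrete (an illustrative sketch with assumed class names, not the library's exact code): model hooks live on the LightningModule and still receive call-site arguments, while callbacks after this PR take no arguments and read state through their trainer reference.

```python
from pytorch_lightning import LightningModule

class MyModel(LightningModule):
    # model hook: defined on the module and handed the batch directly
    def on_batch_start(self, batch):
        ...

class MyCallback(Callback):  # the Callback base class from this PR
    # callback hook: no extra arguments; read state from the trainer
    def on_epoch_end(self):
        metrics = self._trainer.callback_metrics
```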

jeremyjordan (Contributor)

Ah yes, that's right. I was going through training_loop.py quickly, searching for on_* methods to fix a merge conflict, and didn't stop to think about the difference between the two.

Labels: feature (Is an improvement or enhancement) · help wanted (Open to be worked on) · ready (PRs ready to be merged)

5 participants