-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor callbacks #776
Refactor callbacks #776
Conversation
it is quite extensive change, ill have look tomorrow... |
I am planning to move the progress bar to a callback and I'll need this PR. Here are a few suggestions.
Let me know what do you think. |
Yes
Good one :]
That would be nice...
|
I can make first 2 suggestions. And I totally agree with the third suggestion and I thought about it myself. But I am afraid it is not very straightforward with the current implementation of callbacks and now I don't have enough time to delve into it. So I would like to implement only the first two suggestions in this PR. |
That would be great. I make no promise but I can try to tackle the third one. |
@kuynzereb @hadim I think that it would be much easier to pass if each suggestion will be single PR... the large/complex PRs are not so nice for review (takes much longer to check) nor the author (debugging may become quite complex) :] |
@Borda it also sounds very reasonable :) |
No problem to make PRs but something like |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really valuable contribution, just a few (rather formatting) comments
|
||
class Callback(object): | ||
r"""Abstract base class used to build new callbacks. | ||
""" | ||
"""Abstract base class used to build new callbacks.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as abstract, inherit from abstract ABC... class Callback(ABC):
?
but if it will be ABC, then you have to implement all methods all the time... :/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not familiar with all this ABC
stuff, so I don't really know :)
But anyway I think it would be better to do it in another PR (if it really should be done)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pass | ||
|
||
|
||
_no_trainer_error_msg = ".set_trainer() should be called after the callback initialization" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_NO_TRAINER_ERROR_MSG = "Missing trainer instance. The `.set_trainer(...)` should be called after the callback initialization."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
# Allow instances to be re-used | ||
self.wait = 0 | ||
self.stopped_epoch = 0 | ||
self.best = np.Inf if self.monitor_op == np.less else -np.Inf | ||
|
||
def on_epoch_end(self, epoch, logs=None): | ||
def on_epoch_end(self): | ||
assert self._trainer is not None, _no_trainer_error_msg |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would move this assert to a parent class and just at the beginning call parent method ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure about this. Theoretically there may be some calls or even entire callbacks which don't use trainer at all. So there will be no need in this assert.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
true... lets keep it for now
tests/test_trainer.py
Outdated
@@ -231,8 +231,12 @@ def mock_save_function(filepath): | |||
# CASE K=-1 (all) | |||
w = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1) | |||
w.save_function = mock_save_function | |||
trainer = Trainer() | |||
w.set_trainer(trainer) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use longer var name then one letter
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
tests/test_trainer.py
Outdated
for i, loss in enumerate(losses): | ||
w.on_epoch_end(i, logs={'val_loss': loss}) | ||
w._trainer.current_epoch = i | ||
w._trainer.callback_metrics = {'val_loss': loss} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add a comment that this is kind of hack to simulate training...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@kuynzereb may you push a commit so it triggers new GitHub CI |
Great work, thx |
@williamFalcon it was nice to see that GH actions were about twice faster than Travis lol |
@Borda can you close the appropriate tickets related to this PR? |
I don't see any particular issue for this, but I will check the backlog later... @kuynzereb was this change requested in a issue? |
Nope, it was not |
@kuynzereb you mentioned:
I'm wondering, was this |
@jeremyjordan, well, it is not a callback, but a model hook. And actually I didn't think about them while doing this PR. But yes, it seems that the concept of model hooks is very similar to the concept of callbacks. So maybe it should be unified as well. |
ah yes, that's right. i was going through |
Some refactoring of callbacks is done in this PR:
on_validation_begin()
andon_valdiation_end()
to callbacks. The point is thatModelCheckpoint
useson_epoch_end()
but actually is called in the validation end. It is fixed.on_epoch_end()
instead ofon_epoch_end(epoch, logs)
. Instead of these additional arguments now callback will have a link to the trainer, so it will have access to thecurrent_epoch
,global_step
,callback_metrics
and so on. We just need to callself.callback.set_trainer(self)
while initializing callbacks in the trainer.With these modifications it will be easy to implement such things as additional checkpointing callback which uses
on_epoch_end
instead ofon_validation_end
and so can be used to checkpointing training with no validation loops (#596, #652). Also it will be easy to start usingglobal_step
in checkpoints name. And in general all callbacks will be more unified.