Zoneout LSTM Cell #1877
Conversation
You are an owner of some files modified in this pull request.
Sure, I will do it ASAP.
Thank you very much for your contribution; please check the suggested changes :)
"""LSTM cell with recurrent zoneout.

https://arxiv.org/abs/1606.01305
"""
Please, add the description of the arguments of the class (units, zoneout_h, zoneout_c, seed) in the docstring
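For context on what that docstring should describe, the recurrent zoneout update from the linked paper can be sketched in plain NumPy. This is an illustration only — the function name `zoneout_step` is made up here and is not the PR's implementation; note that at inference the implementation under review passes the new state straight through, while the paper instead mixes the states by their expectation.

```python
import numpy as np

def zoneout_step(h_prev, h_new, rate, training, rng):
    """Zoneout: each unit keeps its previous value with probability `rate`.

    During training a random binary mask decides, per unit, whether to
    keep the old state or take the new one; at inference this sketch
    simply returns the new state.
    """
    if training:
        keep = (rng.uniform(size=h_prev.shape) < rate).astype(h_prev.dtype)
        return keep * h_prev + (1.0 - keep) * h_new
    return h_new

rng = np.random.default_rng(0)
h_prev = np.zeros(4)
h_new = np.ones(4)
out = zoneout_step(h_prev, h_new, rate=0.5, training=True, rng=rng)
# Every unit is either the old value (0.0) or the new one (1.0).
assert set(np.unique(out).tolist()) <= {0.0, 1.0}
```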
**kwargs
):
    """
    """
Remove this empty docstring
)
return dt * t + (1 - dt) * tm1

def call(self, inputs, states, training=None):
To get the training value when it's None, first try base_layer_utils.call_context().training, which is a more specific approach. Then, if it's not found, fall back to the Keras backend as you were doing.
You can use this as a reference.
How does your approach outperform the current one? It's just a fallback, and your fallback still requires this fallback. :) I'm not even sure we need it here; I just took it from TensorFlow's Dropout implementation.
The call_context function seemed to be more specific. The approach I suggested is the one used by default; it is implemented here: in_train_phase.
However, I see your reference from the Dropout implementation. Seems OK to me, thanks again :).
def _zoneout(self, t, tm1, rate, training):
    dt = tf.cast(
        tf.random.uniform(t.shape, seed=self.seed) >= rate * training, t.dtype
Instead of multiplying by training, this could be wrapped in an if/else statement so that we avoid unnecessary computation (generating the random matrix).
If I understand things correctly, 99% of the execution time is spent training. No harm at all, and it looks much nicer.
I totally agree with you that it looks much nicer with the multiplication, but I am concerned that it could harm prediction performance. We sometimes forget the importance of prediction time because most trained models never reach production, but the main goal of training models is to use them in production, so prediction time matters.
Performance of production code is a much broader topic. If someone reports that the redundant random generation dramatically degrades their model's inference to an unacceptable level, I will definitely try to find the best possible solution. Otherwise, I see no reason to change it now and risk breaking something accidentally.
Also, just to note: branching is bad for performance, especially in the GPU world.
It's OK, we can leave it in this more elegant way and revisit it if someone reports a performance issue.
Please note that training can be a tensor, which means an if check will raise an error. If you want to disable zoneout in inference mode, please use tf.cond. See keras.backend.dropout as an example.
tf.cond is being phased out: #1794
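The `rate * training` trick debated above can be seen in plain NumPy: when `training` is 0 the threshold becomes 0, so the comparison against a uniform sample in [0, 1) is always true and the mask is all ones, reducing the update to the identity on the new state. A sketch under those assumptions (`zoneout_mask` is an invented name, and `training` is a plain int here, whereas in TF it may be a tensor):

```python
import numpy as np

def zoneout_mask(shape, rate, training, rng):
    # uniform >= rate * training: with training == 0 the comparison is
    # `uniform >= 0`, which is always true, so no unit is "zoned out".
    return (rng.uniform(size=shape) >= rate * training).astype(np.float64)

rng = np.random.default_rng(42)
dt = zoneout_mask((1000,), rate=0.3, training=0, rng=rng)
assert dt.all()                 # inference: the mask is all ones
dt = zoneout_mask((1000,), rate=0.3, training=1, rng=rng)
assert 0.6 < dt.mean() < 0.8    # training: roughly 70% keep the new value
```

This is why the multiplication needs no branch: the inference path falls out of the same expression, at the cost of generating a random matrix that is then ignored.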
@@ -558,3 +558,64 @@ def test_esn_config():
    restored_cell = rnn_cell.ESNCell.from_config(config)
    restored_config = restored_cell.get_config()
    assert config == restored_config

def test_zoneout():
Please add another test case for an e2e test with a Keras model and compile/fit. You can take test_esn_keras_rnn_e2e as an example.
np.testing.assert_allclose(lstm_output, zoneout_output)

def test_zoneout_config():
You should test that you can create the cell using the from_config function. It could be something like this:
restored_cell = rnn_cell.ZoneoutLSTMCell.from_config(config)
restored_config = restored_cell.get_config()
assert config == restored_config
We do not override from_config() nor any of its helpers, so why do we have to test it? When could it fail, and how could we help if it did? IMHO it's not the child class author's responsibility.
I don't think the inherited from_config() method is going to fail; it just runs cls(**config), where cls is ZoneoutLSTMCell, so it is calling the constructor. In this case I'm sure it will not fail, but it is better to have good tests in place so that we avoid future mistakes while modifying this class (adding new arguments or changing get_config).
A possible case where testing the from_config method can fail is if a constructor argument is not properly defined in the get_config method.
Remember, the only thing that may happen to a public class is the addition of an optional argument. Any other update breaks backward compatibility and must be prohibited in the first place. At the end of the day, it's all about maintaining appropriate tests. Sorry, I personally see no reason to complicate things.
I was not talking about modifications to the LSTM layer but to the child ZoneoutLSTM.
I personally see no complication in adding those three lines, which ensure the complete get_config/from_config round trip works fine.
Apart from its utility, which I really believe in, it has been tested for every other recurrent cell in tensorflow_addons, so you can take consistency as another reason.
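The round-trip check under discussion does not require TensorFlow to understand. A toy stand-in for the get_config/from_config contract (the class `ToyCell` is invented here; only the pattern matches the suggested test) shows how a constructor argument missing from get_config would be caught:

```python
class ToyCell:
    """Minimal stand-in for a Keras-style configurable cell."""

    def __init__(self, units, zoneout_h=0.0, zoneout_c=0.0, seed=None):
        self.units = units
        self.zoneout_h = zoneout_h
        self.zoneout_c = zoneout_c
        self.seed = seed

    def get_config(self):
        # If a constructor argument were omitted from this dict, the
        # round-trip assertion below would fail - that is exactly the
        # mistake the suggested test guards against.
        return {
            "units": self.units,
            "zoneout_h": self.zoneout_h,
            "zoneout_c": self.zoneout_c,
            "seed": self.seed,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)

cell = ToyCell(3, zoneout_h=0.5, zoneout_c=0.5, seed=42)
config = cell.get_config()
restored = ToyCell.from_config(config)
assert config == restored.get_config()
```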
config = cell.get_config()

our_config = {
    "units": 3,  # just to ensure that we do get the base class's config
Let me explain. As opposed to the common approach, I do not test the base class's config, only our child class's values. We are not responsible for the parent class's values. However, since we override get_config(), we still have to be sure that the parent class's values come back too. Testing for units should be enough.
Now it's clear for me, thanks! :)
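The subset-style check described above amounts to a dict-subset assertion, sketched here without TensorFlow (all config values are illustrative, not the PR's actual defaults):

```python
config = {
    # values a full get_config() might return; most keys come from the
    # parent LSTM class and are not the child class's responsibility
    "units": 3,
    "activation": "tanh",
    "zoneout_h": 0.5,
    "zoneout_c": 0.5,
    "seed": 42,
}

# Only assert the child class's own keys, plus "units" as a sentinel
# showing the parent's config made it into the merged dict.
our_config = {"units": 3, "zoneout_h": 0.5, "zoneout_c": 0.5, "seed": 42}
assert our_config.items() <= config.items()
```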
Answered some comments, please take a look at them
def call(self, inputs, states, training=None):
    if training is None:
        training = keras.backend.learning_phase()
    output, new_states = super().call(inputs, states, training)
Since the output is not used here, probably replace it with _.
@@ -153,8 +153,8 @@
/tensorflow_addons/optimizers/yogi.py @manzilz
/tensorflow_addons/optimizers/tests/yogi_test.py @manzilz

-/tensorflow_addons/rnn/cell.py @qlzh727 @pedrolarben
-/tensorflow_addons/rnn/tests/cell_test.py @qlzh727 @pedrolarben
+/tensorflow_addons/rnn/cell.py @qlzh727 @pedrolarben @failure-to-thrive
I think we probably want to split the file in the future, so that PRs/issues will be forwarded to the correct person.
Yes, I think it would be better
> I think we probably want to split the file in the future, so that PRs/issues will be forwarded to the correct person.

@gabrieldemarmiesse, could you please take a look?
I agree, not all cell classes should be in the same file. We can move the code of ZoneoutLSTM into another file.
To simplify the process, let's merge this PR and split cell.py and cell_test.py afterwards. Those files contain lots of different cells already.
Any reason why this was never merged?
There are some comments that need to be addressed, and this is now outdated, as cell.py has been refactored and split into different files: https://github.com/tensorflow/addons/tree/master/tensorflow_addons/rnn
@failure-to-thrive Any chance this could be picked up again? Or do you reckon the interest of the community at large is insufficient for this to be a worthwhile pursuit? I, for one, would love to play around with it, to see if I can find improvements over the vanilla LSTM implementation.
Yes, while it's technically possible to remaster, there are inclusion criteria outlined at the link above.
Thank you for your contribution. We sincerely apologize for any delay in reviewing, but TensorFlow Addons is transitioning to a minimal maintenance and release mode. New features will not be added to this repository. For more information, please see our public messaging on this decision. Please consider sending feature requests / contributions to other repositories in the TF community with charters similar to TFA:
Closes #1867