
Zoneout LSTM Cell #1877

Closed
wants to merge 1 commit

Conversation


@failure-to-thrive failure-to-thrive commented May 25, 2020

Closes #1867

@bot-of-gabrieldemarmiesse

@pedrolarben

You are the owner of some files modified in this pull request.
Would you kindly review the changes whenever you have the time?
Thank you very much.

@pedrolarben
Contributor

Sure, I will do it ASAP.

@pedrolarben pedrolarben left a comment

Thank you very much for your contribution; please check the suggested changes :)

"""LSTM cell with recurrent zoneout.

https://arxiv.org/abs/1606.01305
"""
Contributor

Please add descriptions of the class arguments (units, zoneout_h, zoneout_c, seed) to the docstring.
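A possible sketch of such a docstring, following the Keras argument-list convention (the descriptions are assumptions for illustration, not the PR's wording):

```python
class ZoneoutLSTMCell:
    """LSTM cell with recurrent zoneout.

    https://arxiv.org/abs/1606.01305

    Args:
        units: Positive integer, dimensionality of the output space.
        zoneout_h: Float between 0 and 1, zoneout rate for the hidden state.
        zoneout_c: Float between 0 and 1, zoneout rate for the cell state.
        seed: Optional integer, random seed for the zoneout masks.
    """
```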

**kwargs
):
"""
"""
Contributor

Remove this empty docstring

    )
    return dt * t + (1 - dt) * tm1

def call(self, inputs, states, training=None):
Contributor

To resolve the training value when it is None, first try base_layer_utils.call_context().training, which is a more specific approach. If that is not set, fall back to the Keras backend as you are doing now.
You can use this as a reference.
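The fallback order being discussed can be sketched as a plain-Python helper (base_layer_utils.call_context is a private Keras utility, so the parameter names here are stand-ins, not the PR's code):

```python
def resolve_training(training, call_context_training=None, learning_phase=False):
    """Pick the effective training flag, most specific source first."""
    if training is not None:                # 1. explicit argument to call()
        return training
    if call_context_training is not None:   # 2. surrounding layer's call context
        return call_context_training
    return learning_phase                   # 3. global Keras learning phase
```

The explicit argument always wins; the global learning phase is only the last resort.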

Contributor Author

How does your approach outperform the current one? It's just a fallback, and your fallback still requires this fallback. :) I'm not even sure we need it here; I just picked it up from TensorFlow's Dropout implementation.

Contributor

The call_context function seemed to be more specific. The approach I suggested is the default one; it is implemented here: in_train_phase.
However, I see your reference to the Dropout implementation. Seems OK to me, thanks again :).


def _zoneout(self, t, tm1, rate, training):
    dt = tf.cast(
        tf.random.uniform(t.shape, seed=self.seed) >= rate * training, t.dtype
Contributor

Instead of multiplying by training, this could be wrapped in an if/else statement so that we avoid unnecessary computation (the random matrix).

Contributor Author

If I understand things correctly, 99% of the execution time is spent in training. No harm at all, and it looks much nicer.

Contributor

I totally agree with you that it looks much nicer with the multiplication. But I am concerned that it could harm prediction performance. We sometimes forget about the importance of prediction time because most trained models never end up in production, but the main goal of training models is to use them in production, and therefore prediction time matters.

Contributor Author

Performance of production code is a much broader topic. If someone reports that this redundant random op dramatically degrades their model inference to an unacceptable level, I will definitely try to find the best possible solution. Otherwise, I see no reason to change it now and risk breaking something accidentally.
Also, just for the record, branching is bad for performance, especially in the GPU world.

Contributor

It's OK, we can leave it in this more elegant way and revisit it if someone reports a performance issue.
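The masking arithmetic under discussion can be sketched in NumPy (shapes and the helper name are illustrative; the PR's version uses tf.random.uniform):

```python
import numpy as np

def zoneout(t, tm1, rate, training, rng):
    # Mask entry is 1 where the new state t is kept, 0 where the previous
    # state tm1 "zones out" and is carried over instead.
    dt = (rng.uniform(size=t.shape) >= rate * training).astype(t.dtype)
    return dt * t + (1 - dt) * tm1

rng = np.random.default_rng(0)
t = np.ones(4, dtype=np.float32)     # candidate new state
tm1 = np.zeros(4, dtype=np.float32)  # previous state

# At inference (training == 0) the threshold is 0, the mask is all ones,
# and the new state passes through unchanged with no branch needed.
inference_out = zoneout(t, tm1, rate=0.5, training=0, rng=rng)
```

This is why multiplying the rate by training gives the correct identity behavior at inference time, at the cost of still drawing the random matrix.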

Member

Please note that training can be a tensor, which means an if check will raise an error. If you want to disable zoneout in inference mode, please use tf.cond. See keras.backend.dropout as an example.
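The tensor-vs-Python-bool concern can be sketched without TensorFlow (smart_cond here is a hypothetical stand-in for the dispatch that tf.cond and the Keras utilities perform, not a real TF API):

```python
def smart_cond(pred, true_fn, false_fn):
    # A static Python bool can be branched on eagerly.
    if isinstance(pred, bool):
        return true_fn() if pred else false_fn()
    # A symbolic tensor has no truth value, so `if pred:` would raise; this
    # is why a cond-style op must trace both branches into the graph instead.
    raise TypeError("symbolic predicate: use a graph-level cond op")

kept = smart_cond(True, lambda: "apply zoneout", lambda: "identity")
```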

Contributor Author

tf.cond is being phased out now: #1794

@@ -558,3 +558,64 @@ def test_esn_config():
restored_cell = rnn_cell.ESNCell.from_config(config)
restored_config = restored_cell.get_config()
assert config == restored_config


def test_zoneout():
Contributor

Please add another test case: an end-to-end test with a Keras model and compile/fit. You can take test_esn_keras_rnn_e2e as an example.
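A hedged sketch of the shape such an end-to-end test could take, using the stock tf.keras.layers.LSTMCell as a stand-in (the real test would use the ZoneoutLSTMCell from this PR; shapes and values are illustrative):

```python
import numpy as np
import tensorflow as tf

# Tiny synthetic batch: 2 sequences of length 10 with 7 features each.
inputs = np.random.random((2, 10, 7)).astype(np.float32)
targets = np.random.random((2, 3)).astype(np.float32)

# Wrap the cell in an RNN layer and train for one epoch end to end.
model = tf.keras.Sequential([tf.keras.layers.RNN(tf.keras.layers.LSTMCell(3))])
model.compile(optimizer="adam", loss="mse")
model.fit(inputs, targets, epochs=1, verbose=0)

preds = model.predict(inputs, verbose=0)
```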

np.testing.assert_allclose(lstm_output, zoneout_output)


def test_zoneout_config():
Contributor

You should test that the cell can be created using the from_config function.
It could be something like this:
restored_cell = rnn_cell.ZoneoutLSTMCell.from_config(config)
restored_config = restored_cell.get_config()
assert config == restored_config

Contributor Author

We do not override from_config() or any of its helpers, so why do we have to test it? When could it fail, and how could we help if it did? IMHO it's not a child class author's responsibility.

Contributor

I don't think the inherited from_config() method is going to fail; it just runs cls(**config), where cls is ZoneoutLSTMCell, so it is calling the constructor. In this case, I'm sure it will not fail, but it is better to have good tests in place so that we avoid future mistakes while modifying this class (adding new arguments or changing get_config).

A possible case where testing the from_config method can catch a bug is when a constructor argument is not properly included in the get_config method.
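That failure mode can be illustrated with a minimal, self-contained example (class and attribute names are hypothetical, not the PR's code):

```python
class BrokenCell:
    """A cell whose get_config forgets one constructor argument."""

    def __init__(self, units, seed=None):
        self.units = units
        self.seed = seed

    def get_config(self):
        return {"units": self.units}  # bug: `seed` is not serialized

    @classmethod
    def from_config(cls, config):
        return cls(**config)  # what the inherited Keras method effectively does

cell = BrokenCell(3, seed=42)
restored = BrokenCell.from_config(cell.get_config())
# The round trip silently dropped the seed -- exactly the kind of mistake
# a get_config/from_config test would catch.
```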

Contributor Author

Remember, the only change that may happen to a public class is the addition of an optional argument. Any other update breaks backward compatibility and must be prohibited in the first place. At the end of the day, it's all about maintaining appropriate tests. Sorry, I personally see no reason to complicate things.

@pedrolarben pedrolarben May 26, 2020

I was not talking about modifications of the LSTM layer, but of the child ZoneoutLSTM.
I personally see no complication in adding those three lines, which ensure the complete get_config/from_config round trip works fine.
Apart from its utility, which I really believe in, it has been tested for every other recurrent cell in tensorflow_addons, so you can take consistency as another reason.

config = cell.get_config()

our_config = {
"units": 3,  # just to ensure that we do get the base class's config

This comment was marked as resolved.

Contributor Author

Let me explain. As opposed to the common approach, I do not test the base class's config, only our child class's values. We are not responsible for the parent class's values. However, since we override get_config(), we still have to be sure that the parent class's values come back too. Testing for units should be enough.
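The testing strategy described here (check the child's keys, plus one base key to confirm the parent config is merged in) can be sketched with stand-in classes (names and values are illustrative, not the PR's code):

```python
class BaseCell:
    def __init__(self, units):
        self.units = units

    def get_config(self):
        return {"units": self.units}

class ZoneoutCell(BaseCell):
    def __init__(self, units, zoneout_h=0.0, zoneout_c=0.0, seed=None):
        super().__init__(units)
        self.zoneout_h = zoneout_h
        self.zoneout_c = zoneout_c
        self.seed = seed

    def get_config(self):
        config = {
            "zoneout_h": self.zoneout_h,
            "zoneout_c": self.zoneout_c,
            "seed": self.seed,
        }
        config.update(super().get_config())  # merge in the base class's config
        return config

cell = ZoneoutCell(3, zoneout_h=0.25, zoneout_c=0.5, seed=42)
config = cell.get_config()

# Check our child keys, plus "units" just to ensure the base config came back.
expected = {"units": 3, "zoneout_h": 0.25, "zoneout_c": 0.5, "seed": 42}
assert all(config[k] == v for k, v in expected.items())
```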

Contributor

Now it's clear for me, thanks! :)

@pedrolarben pedrolarben left a comment

I answered some comments; please take a look at them.



def call(self, inputs, states, training=None):
    if training is None:
        training = keras.backend.learning_phase()
    output, new_states = super().call(inputs, states, training)
Member

Since the output is not used here, probably replace it with _.

@@ -153,8 +153,8 @@
/tensorflow_addons/optimizers/yogi.py @manzilz
/tensorflow_addons/optimizers/tests/yogi_test.py @manzilz

/tensorflow_addons/rnn/cell.py @qlzh727 @pedrolarben
/tensorflow_addons/rnn/tests/cell_test.py @qlzh727 @pedrolarben
/tensorflow_addons/rnn/cell.py @qlzh727 @pedrolarben @failure-to-thrive
Member

I think we probably want to split the file in the future, so that PRs/issues will be forwarded to the correct person.

Contributor

Yes, I think it would be better

Contributor Author

I think we probably want to split the file in the future, so that PRs/issues will be forwarded to the correct person.

@gabrieldemarmiesse, could you please take a look?

Member

I agree, not all cell classes should be in the same file. We can move the code of ZoneoutLSTM into another file.

Contributor Author

To simplify the process, let's merge this PR and split cell.py and cell_test.py afterwards. Those files contain lots of different cells already.

@Mushoz

Mushoz commented Jan 4, 2023

Any reason why this was never merged?

@pedrolarben
Contributor

Any reason why this was never merged?

There are some comments that need to be addressed and this is now outdated, as cell.py has been refactored and split into different files: https://github.com/tensorflow/addons/tree/master/tensorflow_addons/rnn

@Mushoz

Mushoz commented Jan 5, 2023

@failure-to-thrive Any chance this could be picked up again? Or do you reckon the interest of the community at large is insufficient for this to be a worthwhile pursuit? I, for one, would love to play around with it, to see if I can find improvements over the vanilla LSTM implementation.

@failure-to-thrive
Contributor Author

Or do you reckon the interest of the community at large is insufficient for this to be a worthwhile pursuit?

Yes; while it's technically possible to remaster it, there are inclusion criteria outlined at the link above.

@seanpmorgan
Member

Thank you for your contribution. We sincerely apologize for any delay in reviewing, but TensorFlow Addons is transitioning to a minimal maintenance and release mode. New features will not be added to this repository. For more information, please see our public messaging on this decision:
TensorFlow Addons Wind Down

Please consider sending feature requests / contributions to other repositories in the TF community with similar charters to TFA:
Keras
Keras-CV
Keras-NLP

@seanpmorgan seanpmorgan closed this Mar 1, 2023

Successfully merging this pull request may close these issues.

want zoneout lstm supported
8 participants