First attempt at Yarin Gal dropout for LSTM #182

Merged
neubig merged 7 commits from gal_dropout into master on Jan 23, 2017

Conversation

@yoavg (Contributor) commented Nov 21, 2016

https://arxiv.org/pdf/1512.05287v5.pdf

I'm not 100% sure it's correct, and it has some ugliness -- LSTMBuilder now keeps a pointer to ComputationGraph -- but Gal's dropout seems to be the preferred way to do dropout for LSTMs.

I'd appreciate another pair of eyes.
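
For readers who haven't read the paper: the core idea is to sample one dropout mask per layer when a sequence starts and reuse that same mask at every timestep, rather than resampling per expression. A minimal sketch of that shape, assuming the DyNet expression API (random_bernoulli, cmult); the helper name and arguments are illustrative, not the actual PR code:

#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <vector>

using namespace dynet;

// Sketch only: one mask per layer, sampled at the start of a sequence with
// retention probability `keep`, then reused at every timestep of that sequence.
std::vector<Expression> sample_sequence_masks(ComputationGraph& cg,
                                              unsigned layers,
                                              unsigned hidden_dim,
                                              float keep) {
  std::vector<Expression> masks;
  for (unsigned i = 0; i < layers; ++i)
    masks.push_back(random_bernoulli(cg, Dim({hidden_dim}), keep));
  return masks;
}

// Per timestep, the *same* mask multiplies the recurrent state, e.g.:
//   Expression h_in = cmult(masks[layer], h_prev);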

@pmichel31415 (Collaborator):

If you want to avoid keeping a link to the cg, maybe you can do something like using a flag dropout_initialized, which is set to false by start_new_sequence_impl.

Then in add_input_impl you can check the flag and initialize the mask.

Not sure if that's considered less ugly, but at least you don't need the additional pointer.
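
To make the suggested control flow concrete, here is a toy mock-up of the lazy-initialization pattern. The names (masks_initialized, LazyMaskDemo) are hypothetical, and it assumes the ComputationGraph can be recovered from the input Expression's pg member rather than being stored on the builder:

#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <vector>

using namespace dynet;

// Toy mock-up only, not the actual LSTMBuilder: starting a sequence just
// clears a flag, and the masks are sampled on the first add_input call,
// where the graph is available anyway, so no extra pointer is kept.
struct LazyMaskDemo {
  unsigned layers = 2, hidden_dim = 8;
  float dropout_rate = 0.5f;
  bool masks_initialized = false;
  std::vector<Expression> masks;

  void start_new_sequence() { masks_initialized = false; }

  Expression add_input(const Expression& x) {
    if (!masks_initialized && dropout_rate > 0.f) {
      ComputationGraph& cg = *x.pg;   // graph recovered from the input expression
      masks.clear();
      for (unsigned i = 0; i < layers; ++i)
        masks.push_back(random_bernoulli(cg, Dim({hidden_dim}), 1.f - dropout_rate));
      masks_initialized = true;
    }
    return cmult(masks[0], x);        // stand-in for the real LSTM update
  }
};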

@yoavg (Contributor, Author) commented Nov 21, 2016

I need the CG to create the dropout masks.

@yoavg (Contributor, Author) commented Nov 21, 2016

Ah! Got it now. Not sure it's less ugly..

@pmichel31415 (Collaborator):

Another solution would be to overload the start_new_sequence interface to accept cg as a parameter. This could be useful for other things as well.

@neubig (Contributor) left a review comment:

Looks good in general, modulo a few comments. Actually, I don't think saving the CG is all that ugly; it looks like a reasonable solution.

@@ -45,12 +45,20 @@ struct LSTMBuilder : public RNNBuilder {
Expression set_h_impl(int prev, const std::vector<Expression>& h_new) override;
Expression set_s_impl(int prev, const std::vector<Expression>& s_new) override;
public:
bool gal_dropout;
Contributor:

Maybe I missed this, but it looks like this isn't initialized to a default value anywhere?

Contributor (Author):

Line 85 of lstm.cc -- but yeah, this is a bit hacky as it disables the previous dropout behavior.
Which leads me to the next question: should we disable the previous dropout behavior (and remove it from the code) or keep supporting both?

Contributor:

Hmm, this does seem a bit hacky. I think we need to decide whether we're going to disable the previous dropout behavior or not. If we disable it, we can simplify the code; if we need both, it'll have to be a flag that switches intuitively between the two dropout methods. I'd prefer the simplicity of the former, but maybe we need empirical results. I can try to run some experiments testing the new method if that would be helpful.

Contributor (Author):

If you have some ready-to-run benchmarks you trust, that'd be great!

// YG init dropout masks for each layer
masks.clear();
for (unsigned i = 0; i < layers; ++i) {
std::vector<Expression> masks_i;
Contributor:

Indents should be two spaces (throughout)

@neubig (Contributor) commented Nov 22, 2016

So I tested this and the results look OK on training, but they're horrible on test. During test time we don't apply dropout, and unlike the original "dropout" node (which compensates for this by multiplying the scale of the remaining nodes by 1/dropout_rate), this is just using random_bernoulli masks without this scaling. I haven't read the original paper closely enough to know how they fixed this discrepancy, but I imagine that we'll have to do this here as well.
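
To see the train/test mismatch concretely, here is a toy calculation (plain C++, not DyNet) comparing a unit's expected training-time output under an unscaled Bernoulli mask versus a mask scaled by 1/keep, relative to what the unit outputs at test time when no mask is applied; the numbers are illustrative only:

#include <cstdio>

int main() {
  const double keep = 0.5;   // retention probability
  const double h    = 1.0;   // some activation value

  // Unscaled mask: E[mask * h] = keep * h, but at test time (no mask)
  // the unit outputs h, so train and test magnitudes disagree.
  double train_unscaled = keep * h;

  // Scaled ("inverted") mask: E[(mask / keep) * h] = h, which already
  // matches test time, so nothing extra is needed there.
  double train_scaled = (keep * h) / keep;

  std::printf("unscaled train expectation: %.2f (test sees %.2f)\n", train_unscaled, h);
  std::printf("scaled   train expectation: %.2f (test sees %.2f)\n", train_scaled, h);
  return 0;
}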

@yoavg (Contributor, Author) commented Nov 22, 2016

Right! Good catch.

I now follow Gal's implementation: if we are in "test_mode", we scale by the dropout rate (which here is actually the rate of units we keep). This means you should set builder.test_mode = true before testing.
Let me know if this works for you. And we can discuss a cleaner API later :)
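
A minimal sketch of the behavior described above, assuming the DyNet cmult expression and treating the mask, test_mode flag, and retention rate as given (names are illustrative, not the actual PR code):

#include <dynet/dynet.h>
#include <dynet/expr.h>

using namespace dynet;

// Sketch: at training time the per-sequence mask is applied; in test_mode
// the mask is skipped and the activation is scaled by the retention rate
// (the mask's expected value) instead.
Expression apply_gal_dropout(const Expression& h, const Expression& mask,
                             bool test_mode, float retention) {
  return test_mode ? (h * retention) : cmult(mask, h);
}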

@redpony (Member) commented Nov 22, 2016

Two thoughts:

  1. Regarding telling the thing whether it's test time or not: we already handle this with regular dropout, so we should make sure to use a unified strategy.
  2. While this is still fresh in your mind, any chance we could get zoneout (https://arxiv.org/pdf/1606.01305v2.pdf) too? That's the cool new thing, it seems.


@yoavg (Contributor, Author) commented Nov 22, 2016

What do we do at test time for dropout in RNNs? Or do we just scale the weights already at train time so that nothing is needed at test time? (If it's the latter, I'd rather do the testing on the current version first, and only then switch to that solution and verify they are indeed equivalent.)

BTW, are the exact equations that the DyNet LSTM implements written down somewhere (not in code)?

Re zoneout: sure, I can give it a shot, but we really need to figure out how to support multiple regularization strategies in the code, at least at the API level. This really becomes messy once we have more than one or two. (Perhaps duplicate the builder and have a different builder for each dropout variant?)

@neubig (Contributor) commented Nov 22, 2016

API problems and zoneout notwithstanding, I'll take a look at this now with the current API.

@neubig (Contributor) commented Nov 22, 2016

P.S. Thanks to @pmichel31415, the LSTM equations are "documented" here: #154
(which should be added to the official docs eventually).

@neubig (Contributor) commented Nov 23, 2016

FYI, my initial tests on an attentional model over the "train-big" set in my nmt-tips repository (https://github.com/neubig/nmt-tips) showed results that look promising for Gal dropout, but perhaps suspiciously bad for traditional dropout. I'll investigate this a bit more.

Standard 0.0: dev=7.57288
Standard 0.2: dev=7.79786
Standard 0.5: dev=8.76879
Standard 0.8: ???

Gal 0.0: dev=7.96906
Gal 0.2: ???
Gal 0.5: dev=7.12243
Gal 0.8: dev=7.43567

Also, one thing that caught me here was that in the new implementation the "dropout_rate" variable actually seems to have the semantics of "retention rate," which is pretty confusing (the current implementation in master has the standard semantics). We should probably do a global search-and-replace everywhere dropout_rate is used and replace it with 1-dropout_rate, right?

@yoavg (Contributor, Author) commented Nov 23, 2016

Yes, I agree that the "retention rate" semantics is confusing -- the place to fix this would be in start_new_sequence_impl. Setting 1-dropout_rate everywhere in that function should work.

Note that if you are using this branch as is, then you effectively have only Gal dropout, so you need to test regular dropout from the master branch if you aren't doing so already.

@neubig (Contributor) commented Nov 23, 2016

Yes, my "standard" numbers reflect the master branch. I'll tell you when I've taken a closer look at why the standard numbers are getting counter-intuitive results.

@neubig (Contributor) commented Nov 25, 2016

OK, here are the results for two runs (the dropout rate listed in the results is the actual dropout rate, not retention rate). Looks like Gal dropout is convincingly better almost always, so I'm in favor of switching the default to Gal dropout. I do really like the ability to add dropout+scaling at training time, and just use the parameters at testing time though, so hopefully we can make that work here too.

** Run 1
Standard 0.0: dev=7.57288, train=4.40813, epoch=3
Standard 0.2: dev=7.79786, train=5.4612, epoch=3
Standard 0.5: dev=8.76879, train=5.8941, epoch=7
Standard 0.8: dev=11.0083, train=10.8731, epoch=18
Gal 0.0: dev=7.96906, train=4.4113, epoch=3
Gal 0.2: dev=7.33492, train=4.59449, epoch=3
Gal 0.5: dev=7.12243, train=4.5729, epoch=4
Gal 0.8: dev=7.30811, train=3.90383, epoch=12

** Run 2
Standard 0.0: dev=8.56846, train=6.90802, epoch=2
Standard 0.2: dev=7.95257, train=4.71503, epoch=4
Standard 0.5: dev=8.37596, train=6.35831, epoch=5
Standard 0.8: dev=11.102, train=12.2691, epoch=9
Gal 0.0: dev=8.19088, train=4.74689, epoch=3
Gal 0.2: dev=7.21125, train=4.59289, epoch=3
Gal 0.5: dev=7.02573, train=4.58412, epoch=4
Gal 0.8: dev=7.49874, train=4.75404, epoch=8

@@ -80,8 +82,44 @@ void LSTMBuilder::new_graph_impl(ComputationGraph& cg) {
// layout: 0..layers = c
// layers+1..2*layers = h
void LSTMBuilder::start_new_sequence_impl(const vector<Expression>& hinit) {
if (dropout_rate) { gal_dropout = true; } else { gal_dropout = false; }
@vene commented Nov 28, 2016:

I'm not 100% clear on this but I think this interacts counterintuitively with the inverted dropout semantics.

My experimental observation is that if I do rnn.set_dropout(0.1) my training loss doesn't go down at all (I assume because this means a 10% retention rate), but if I do rnn.set_dropout(0.0) it converges (dropout is completely off), I think because this if statement doesn't trigger.

I imagine this won't be an issue when this is merged, if the old behaviour is not kept, but I figured this comment might be useful to others trying to use the code in this PR.

@yoavg (Contributor, Author) commented Nov 28, 2016

Ok. Unless there are objections, I will proceed with an update that:

  • drops the previous version of dropout
  • does dropout+scaling at train time (and nothing at test time)

@neubig (Contributor) commented Nov 29, 2016

Yes, that'd be great. And also the trivial reversal of the polarity of the dropout rate, to make sure it's not the retention rate.

* Remove previous dropout behavior.
* set_dropout now indicates dropout_rate, not retention rate.
* weights are scaled when dropout is applied, so no scaling is needed at test time (just set dropout_rate = 0.0)
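
Read concretely, those commits mean each per-layer mask is now sampled once per sequence with retention probability 1 - dropout_rate and pre-scaled by 1/(1 - dropout_rate). A minimal sketch, again assuming the DyNet random_bernoulli expression and using illustrative names rather than the actual patch:

#include <dynet/dynet.h>
#include <dynet/expr.h>

using namespace dynet;

// Sketch: one inverted-dropout mask for a layer, sampled once per sequence.
// Kept units are scaled up by 1/(1-p) at training time, so at test time no
// mask and no scaling are needed (dropout_rate is simply set to 0.0).
Expression make_inverted_mask(ComputationGraph& cg, unsigned hidden_dim,
                              float dropout_rate) {
  float retention = 1.f - dropout_rate;
  Expression mask = random_bernoulli(cg, Dim({hidden_dim}), retention);
  return mask * (1.f / retention);  // applied per timestep as cmult(mask, h_prev)
}
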
@yoavg (Contributor, Author) commented Nov 30, 2016

@neubig LSTMBuilder now has Gal dropout with scaling at train time and nothing at test time. I did some minimal testing of this locally, but it would be good if you could re-run your previous benchmark to see whether we obtain the same results.

Something that bothers me about the scaling-at-train-time behavior: perhaps we should not scale according to the Bernoulli parameter, but rather according to the number of 1s that actually appear in the Bernoulli vector? This seems more correct.

@neubig (Contributor) commented Nov 30, 2016

OK, tests are started.

@neubig (Contributor) commented Nov 30, 2016

@yoavg Also, perhaps instead of calculating random_bernoulli() and then doing the scaling, we could just use the dropout() expression? I think this is basically the same thing, so I'll do another test doing this. Regarding how to scale, I agree that the actual number of 1s might be better, but this complicates the code a bit by requiring two passes (one to calculate the random variables and count them, then another to do the scaling), and in the limit of large vector size the two will be the same, so I'm not sure whether testing this extensively is worth the effort.

@yoavg (Contributor, Author) commented Nov 30, 2016

@neubig Wait! There's a stupid bug. Let me fix and push.

@yoavg (Contributor, Author) commented Nov 30, 2016

The problem with using the dropout expression is that the same mask should be used for the entire sequence.

Re the scaling according to the number of ones: let me try to implement this as well, and we can compare what works better.

@yoavg (Contributor, Author) commented Nov 30, 2016

@neubig (fixed the stupid scaling bug)

@neubig (Contributor) commented Nov 30, 2016

@yoavg OK, rerunning this. Also, I added the ability to scale random bernoulli variables: 71fc893

This allows us to do scaling in a single operation without adding the extra multiplication. I'm testing a version of this commit with this change as well, and will contribute it if it's doing the same thing.
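
If that scale argument works the way it is described (I'm going off this comment, so treat the exact signature as an assumption), the pre-scaled mask from the earlier sketch collapses into a single call:

#include <dynet/dynet.h>
#include <dynet/expr.h>

using namespace dynet;

// Sketch: mask sampling and 1/(1-p) scaling fused into one expression via
// the new scale argument on random_bernoulli (illustrative names).
Expression make_inverted_mask_fused(ComputationGraph& cg, unsigned hidden_dim,
                                    float dropout_rate) {
  float retention = 1.f - dropout_rate;
  return random_bernoulli(cg, Dim({hidden_dim}), retention, 1.f / retention);
}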

@yoavg (Contributor, Author) commented Nov 30, 2016

@neubig Great. I also added a version with scaling according to the number of 1s, under the bernoulli_scaling branch.

@neubig (Contributor) commented Dec 2, 2016

Hmm, so my tests are almost finished, but inverted dropout doesn't seem to be working as well as regular dropout. I need to look into this a little more, and also test out the idea of using the actual number of 1s in the Bernoulli samples.

*** Run 1
Reg. Dropout 0.0: dev=7.96906, train=4.4113, epoch=3, rate=0.001
Reg. Dropout 0.2: dev=7.33492, train=4.59449, epoch=3, rate=0.001
Reg. Dropout 0.5: dev=7.12243, train=4.5729, epoch=4, rate=0.001
Reg. Dropout 0.8: dev=7.30811, train=3.90383, epoch=12, rate=0.001
Inv. Dropout 0.0: dev=7.57288, train=4.40813, epoch=3, rate=0.001
Inv. Dropout 0.2: dev=7.64628, train=4.9021, epoch=3, rate=0.001
Inv. Dropout 0.5: dev=7.68190, train=5.27479, epoch=4, rate=0.001
Inv. Dropout 0.8: dev=8.83439, train=5.57314, epoch=10, rate=0.001

*** Run 2
Reg. Dropout 0.0: dev=8.19088, train=4.74689, epoch=3, rate=0.001
Reg. Dropout 0.2: dev=7.21125, train=4.59289, epoch=3, rate=0.001
Reg. Dropout 0.5: dev=7.02573, train=4.58412, epoch=4, rate=0.001
Reg. Dropout 0.8: dev=7.36944, train=4.47184, epoch=9, rate=0.001
Inv. Dropout 0.0: dev=8.56846, train=6.90802, epoch=2, rate=0.001
Inv. Dropout 0.2: dev=7.97863, train=4.98161, epoch=3, rate=0.001
Inv. Dropout 0.5: dev=7.72465, train=4.67207, epoch=5, rate=0.001
Inv. Dropout 0.8: dev=9.47496, train=7.77828, epoch=5, rate=0.001 (Not Converged)

@yoavg (Contributor, Author) commented Dec 2, 2016

The bernoulli_scaling branch has an implementation that scales by the actual number of 1s in the Bernoulli samples.

@neubig (Contributor) commented Dec 5, 2016

So I ran more experiments with the bernoulli_scaling branch, and there's basically no perceivable difference from the inverted dropout in this branch. So we're still stuck with inverted dropout being significantly worse. I'm going to go back and work through the equations to see if I can find an explanation. (Anyone else watching: it'd be great if you could do so as well...)

@pmichel31415 (Collaborator):

Sorry to chime in; I have three things to say about the inverted/normal dropout question:

  • It doesn't make a difference for simple feed-forward nets; see these results obtained with my MNIST example PR (#217):

[Figure: test-set accuracy, standard vs. inverted dropout]

(Accuracy on the test set over 50 epochs for a 3-layer feed-forward net with dropout on the first two layers, p=0.2.)

  • I discussed it with @neubig, but maybe the problem with inverted dropout in RNNs is that the 1/(1-p) factor accumulates in the gradient, so over t timesteps the gradient is 1/(1-p)^t times the gradient for standard dropout (see the toy calculation after this list).

  • Also, I double-checked, and for reference, Keras and TensorFlow use inverted dropout as their standard dropout, so it shouldn't make that big of a difference.
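
As referenced in the second bullet, here is a toy calculation of that accumulation factor; it simply evaluates the expression in the hypothesis and is illustrative only:

#include <cmath>
#include <cstdio>

int main() {
  // If every timestep contributes a 1/(1-p) factor on the recurrent path,
  // then after t steps the accumulated factor is (1/(1-p))^t.
  const double p = 0.5;
  for (int t : {1, 5, 10, 20})
    std::printf("t = %2d -> (1/(1-p))^t = %.0f\n", t, std::pow(1.0 / (1.0 - p), t));
  return 0;
}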

@yoavg (Contributor, Author) commented Dec 15, 2016

Interesting! Thanks.

Maybe inverted dropout is particularly bad for Gal's form of dropout because the mask stays on the same nodes for the entire sequence?

Anyhow, I will try to add Gal's version (both inverted and regular) to the Vanilla LSTM, and let's try the benchmarks again.

@neubig neubig merged commit 5a08a09 into master Jan 23, 2017
@neubig neubig deleted the gal_dropout branch February 7, 2017 16:37