2

TensorFlow's DropoutWrapper has three different keep probabilities that can be set: `input_keep_prob`, `output_keep_prob`, and `state_keep_prob`.

I want to use variational dropout for my LSTM units by setting the variational_recurrent argument to True. However, I don't know which of the three keep probabilities I have to set for variational dropout to work correctly.
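For reference, this is roughly the setup I have in mind (a sketch on TF 1.x; the exact module path varies between releases, and the wrapper may also live under tf.contrib.rnn — the 0.8 values are just placeholders):

```python
import tensorflow as tf

cell = tf.nn.rnn_cell.LSTMCell(num_units=128)

# Which of these keep probabilities matters once variational_recurrent=True?
cell = tf.nn.rnn_cell.DropoutWrapper(
    cell,
    input_keep_prob=0.8,
    output_keep_prob=0.8,
    state_keep_prob=0.8,
)
```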

Can someone provide help?

Lemon

1 Answer

5

According to the paper https://arxiv.org/abs/1512.05287, on which the variational_recurrent implementation is based, you can think of the three parameters as follows:

  • input_keep_prob - keep probability for the input connections (dropout applied to the inputs).

  • output_keep_prob - keep probability for the output connections (dropout applied to the outputs).

  • state_keep_prob - keep probability for the recurrent connections (dropout applied to the recurrent state).

See the diagram below:

[Diagram from the paper: naive dropout RNN (left) vs. variational RNN (right)]

If you set variational_recurrent to True, you get an RNN similar to the model on the right; otherwise, you get the model on the left.

The basic differences between the two models are:

  • The variational RNN repeats the same dropout mask at each time step for the inputs, outputs, and recurrent connections (the same network units are dropped at each time step; see the sketch below).

  • The naive dropout RNN uses different dropout masks at each time step for the inputs and outputs alone (no dropout is used on the recurrent connections, since using different masks there leads to deteriorated performance).

In the above diagram, coloured connections represent the dropped-out connections, with different colours corresponding to different dropout masks. Dashed lines correspond to standard connections with no dropout.
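To make the mask-reuse difference concrete, here is a minimal illustrative sketch (plain NumPy, not the actual TensorFlow implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
keep_prob, hidden_size, time_steps = 0.8, 5, 3

# Variational dropout: sample ONE binary mask and reuse it at every time step,
# so the same hidden units are dropped throughout the sequence.
variational_mask = (rng.random(hidden_size) < keep_prob).astype(np.float32)
variational_masks = [variational_mask] * time_steps

# Naive dropout: sample a NEW mask at every time step,
# so different units are dropped at different steps.
naive_masks = [(rng.random(hidden_size) < keep_prob).astype(np.float32)
               for _ in range(time_steps)]

print(variational_masks)  # identical masks at every step
print(naive_masks)        # (usually) different masks at every step
```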

Therefore, if you use a variational RNN, you can set all three probability parameters according to your requirements.
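For example, with the TF 1.x DropoutWrapper (a sketch; the 0.8 keep probabilities, sequence length, and feature depth are just placeholders, and note that variational_recurrent=True additionally requires dtype, plus input_size whenever input_keep_prob < 1):

```python
import tensorflow as tf

# [batch, time, features]; 50 and 300 are illustrative values
inputs = tf.placeholder(tf.float32, [None, 50, 300])

cell = tf.nn.rnn_cell.LSTMCell(num_units=128)
cell = tf.nn.rnn_cell.DropoutWrapper(
    cell,
    input_keep_prob=0.8,         # keep prob for input connections
    output_keep_prob=0.8,        # keep prob for output connections
    state_keep_prob=0.8,         # keep prob for recurrent (state) connections
    variational_recurrent=True,  # same dropout mask reused at every time step
    dtype=tf.float32,            # required when variational_recurrent=True
    input_size=tf.TensorShape([300]),  # required here because input_keep_prob < 1
)

outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
```

Remember to set all three keep probabilities back to 1.0 (e.g., by feeding them through placeholders) at evaluation time, so no dropout is applied during inference.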

Hope this helps.

Nipun Wijerathne
  • Thanks for your answer. There are still some points that are unclear to me: 1. what is `state_dropout` in LSTM units? Which part of the LSTM unit does it refer to? 2. when stacking multiple LSTM units (or RNN units in general) wouldn't using all three kinds of dropout probability be too much? When the first cell uses all three kinds of dropout, and the second cell as well, we would have "output dropout" and "input dropout" right after each other. In this case the input dropout would drop even more cells from the already dropped outputs of the previous cell – Lemon Nov 22 '17 at 07:12
  • 3. in the paper, "variational dropout" refers to dropping input, recurrent and output connections at the same time in a specific manner. However, in more recent papers it sometimes seems as if "variational dropout" simply means that the same mask is used at each time step, independent of where dropout is applied. So what does "variational dropout" exactly mean? Simply that the same mask is repeated at each time step? Or in addition that it's applied to inputs AND recurrent connections AND outputs? – Lemon Nov 22 '17 at 07:17
  • Answers to your question, 1) state_dropout - dropping out recurrent connections (horizontal arrows in the diagram). 2) Depends on your application, data etc. For some applications it will work and maybe for some application, it won't. 3) variational dropout means drop the same network units at each time step and can be applied to input, output and recurrent connections. – Nipun Wijerathne Nov 22 '17 at 10:04
  • Thanks, but you misunderstood question 1). LSTM units have different gates. My question refers to which gates and parts of the LSTM unit recurrent dropout will affect. Also, could you please specify your answer for question 2)? When stacking two LSTM units right after each other, using output dropout in the first LSTM should be equivalent to using input dropout in the second LSTM. How is that dependent on the application? And regarding 3) where does this definition come from? In the paper it seems to be defined differently – Lemon Nov 22 '17 at 10:14
  • 1) Yes, the gates are used to calculate the hidden units of the RNN, and thus those recurrent connections (horizontal lines); therefore, dropout applies to those connections. 2) It's hard to name a specific application; maybe your argument is correct. I meant that for some models it will work, depending on the data. 3) The definition is the same as in the paper (same words) – Nipun Wijerathne Nov 23 '17 at 13:10
  • There is something I am confused about. You say that no dropout is used on the recurrent connections because of the bad performance. However, it is obviously being used, since the horizontal arrows on the diagram show so. Am I correct? You can still set the `state_keep_prob` argument with the naive (non-variational) RNN. – ARAT Feb 09 '19 at 16:45