
Let me first summarize what I understand about the cuDNN 5.1 RNN functions:

Tensor dimensions

x = [seq_length, batch_size, vocab_size] # input
y = [seq_length, batch_size, hiddenSize] # output

dx = [seq_length, batch_size, vocab_size] # input gradient
dy = [seq_length, batch_size, hiddenSize] # output gradient

hx = [num_layer, batch_size, hiddenSize] # input hidden state
hy = [num_layer, batch_size, hiddenSize] # output hidden state
cx = [num_layer, batch_size, hiddenSize] # input cell state
cy = [num_layer, batch_size, hiddenSize] # output cell state

dhx = [num_layer, batch_size, hiddenSize] # input hidden state gradient
dhy = [num_layer, batch_size, hiddenSize] # output hidden state gradient
dcx = [num_layer, batch_size, hiddenSize] # input cell state gradient
dcy = [num_layer, batch_size, hiddenSize] # output cell state gradient

w = [param size] # parameters (weights & bias)
dw = [param size] # parameter gradients
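The shape list above can be written as a small helper, which makes it easy to check buffer sizes before allocating them. This is a sketch, not cuDNN code. Two details are assumptions added here, based on how the cuDNN descriptors are documented: `num_directions` (1 for unidirectional, 2 for bidirectional) multiplies the first dimension of the state tensors and the last dimension of `y`, and the last dimension of `x` is called `input_size` (usually the embedding size, as the comment below points out). The function name `rnn_tensor_shapes` is hypothetical.

```python
from typing import Dict, Tuple


def rnn_tensor_shapes(seq_length: int, batch_size: int, input_size: int,
                      hidden_size: int, num_layers: int,
                      num_directions: int = 1) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper returning the tensor shapes listed above.

    input_size: per-step feature size of x (called vocab_size in the question).
    num_directions: assumption, 1 = unidirectional, 2 = bidirectional.
    """
    x = (seq_length, batch_size, input_size)
    # For a bidirectional RNN the two directions are concatenated in y.
    y = (seq_length, batch_size, hidden_size * num_directions)
    # Hidden/cell states: one slice per layer (and per direction).
    state = (num_layers * num_directions, batch_size, hidden_size)

    shapes = {"x": x, "dx": x, "y": y, "dy": y}
    for name in ("hx", "hy", "cx", "cy", "dhx", "dhy", "dcx", "dcy"):
        shapes[name] = state
    return shapes


shapes = rnn_tensor_shapes(seq_length=35, batch_size=20, input_size=128,
                           hidden_size=256, num_layers=2)
print(shapes["y"])   # (35, 20, 256)
print(shapes["hx"])  # (2, 20, 256)
```

The size of `w`/`dw` is deliberately left out: it depends on the cell type and is obtained from cuDNN itself rather than computed by hand.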

cudnnRNNForwardTraining / cudnnRNNForwardInference

input: x, hx, cx, w
output: y, hy, cy

cudnnRNNBackwardData

input: y, dy, dhy, dcy, w, hx, cx
output: dx, dhx, dcx

cudnnRNNBackwardWeights

input: x, hx, y, dw
output: dw
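dw appears above as both an input and an output of cudnnRNNBackwardWeights. As I read the cuDNN documentation, this is because the call is additive: it adds the new gradients to whatever dw already holds, which is why dw has to be reset between iterations. The toy below is a sketch of that pattern on a one-parameter model (y = w * x with a mean-squared-error loss), not cuDNN code; `backward_weights` and `train` are hypothetical names. It also uses the usual gradient-descent sign, w -= lr * dw: w += dw would only reduce the loss if dw already held the negated, scaled gradient.

```python
from typing import List


def backward_weights(w: float, xs: List[float], ys: List[float],
                     dw: List[float]) -> None:
    """Toy stand-in for an accumulating weight-gradient call.

    Model: y = w * x, loss = mean squared error. The gradient is ADDED
    to dw[0] instead of overwriting it (additive semantics).
    """
    grad = sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
    dw[0] += grad


def train(steps: int, lr: float, zero_grad: bool = True) -> List[float]:
    """Plain gradient descent; returns the value of w after every step."""
    xs = [1.0, 2.0, 3.0]
    ys = [2.0, 4.0, 6.0]      # generated by the target w = 2
    w = 0.0
    dw = [0.0]                # one-element list so it can be updated in place
    history = []
    for _ in range(steps):
        backward_weights(w, xs, ys, dw)
        w -= lr * dw[0]       # descend: subtract the learning-rate-scaled gradient
        if zero_grad:
            dw[0] = 0.0       # reset the accumulator before the next iteration
        history.append(w)
    return history


print(round(train(100, 0.05)[-1], 4))  # 2.0
```

With `zero_grad=False` the accumulator keeps the sum of all past gradients, and in this example w keeps oscillating around 2 instead of settling.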

Questions:

  1. Is the following training workflow for a multi-layer RNN (num_layer > 1) correct?
       1. init hx, cx, dhy, dcy to NULL
       2. init w (weights: small random values, bias: 1)
       3. forward
       4. backward data
       5. backward weights
       6. update weights: w += dw
       7. dw = 0
       8. goto 3.
  2. Can you confirm that cuDNN already implements stacked RNNs when num_layer > 1 (so there is no need to call the forward/backward methods num_layer times)?
  3. Should I feed the hidden state and cell state back into the network for the next batch?
  4. The output in the LSTM/GRU formulas is hy. Should I use hy or y as the output?
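The relation between y and hy in a stacked RNN can be illustrated with a toy model. This is a sketch under stated assumptions, not cuDNN's implementation. It uses a plain tanh RNN (h = tanh(w_in * input + w_rec * h)) with scalar states, batch size 1, and one direction. The function `stacked_rnn_forward` and the weight values are hypothetical. It assumes cuDNN's stacked semantics match this toy: each layer's output feeds the next layer, y collects the top layer's state at every time step, and hy collects every layer's state at the last time step. Under that assumption y[-1] equals hy[-1], and the other entries of hy are the lower layers' final states.

```python
import math
from typing import List, Tuple


def stacked_rnn_forward(xs: List[float], weights: List[Tuple[float, float]],
                        hx: List[float]) -> Tuple[List[float], List[float]]:
    """Toy stacked tanh RNN with scalar states (batch_size = hidden_size = 1).

    weights[l] = (w_in, w_rec) for layer l; hx[l] is layer l's initial state.
    Returns (y, hy): y = top layer's state at every time step,
    hy = every layer's state at the last time step.
    """
    h = list(hx)                      # copy so the caller's hx is not modified
    y = []
    for x_t in xs:
        inp = x_t
        for l, (w_in, w_rec) in enumerate(weights):
            h[l] = math.tanh(w_in * inp + w_rec * h[l])
            inp = h[l]                # this layer's output feeds the next layer
        y.append(h[-1])               # only the top layer is written to y
    return y, h


xs = [0.5, -1.0, 0.25, 0.8]
weights = [(0.9, 0.4), (1.1, -0.3)]   # num_layers = 2, arbitrary values
y, hy = stacked_rnn_forward(xs, weights, hx=[0.0, 0.0])
print(len(y), len(hy))   # 4 2
print(y[-1] == hy[-1])   # True

# Carrying the state: feeding hy of one chunk as hx of the next chunk
# gives the same result as one pass over the concatenated sequence.
y1, h1 = stacked_rnn_forward(xs[:2], weights, hx=[0.0, 0.0])
y2, h2 = stacked_rnn_forward(xs[2:], weights, hx=h1)
```

Passing zeros as hx (what NULL means for hx in cuDNN) starts each batch from a fresh state, while passing the previous hy continues the sequence. Which one is appropriate depends on whether consecutive batches are continuations of the same sequences.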

The same question is also posted here (I will keep the answers synchronized).

  • Note that the inner-most dimension of the input tensor usually doesn't have anything to do with the vocabulary size (unless you're using one-hot tensors or something); usually, it's the size of the embedding. – Edward Z. Yang Jan 31 '18 at 20:58
