Recently, I have been trying to implement an RNNLM based on this article. There is an implementation with some LSTM factorization tricks, but it is otherwise similar to the author's original implementation.
Preamble
1) The dataset is split into files; the lines of each file are shuffled at train time and fed sequentially at test time (link):
# deterministic at test time, non deterministic at train time
if not self._deterministic:
    random.shuffle(lines)
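To make sure I read that right, here is a minimal sketch of the behaviour (the function and its names are mine, not the repository's):

import random

def read_lines(path, deterministic):
    # Deterministic order at test time, shuffled order at train time.
    with open(path) as f:
        lines = [line.strip() for line in f]
    if not deterministic:
        random.shuffle(lines)
    for line in lines:
        yield line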
2) The batches are formed continuously.
The * symbol represents the start/end of a sentence. Each matrix represents one batch. The code link. So:
If a sentence is longer than num_steps, it continues in the next batch on the same line.
If a sentence is shorter, the batch line is filled with other sentences.
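Here is a toy sketch of how I understand this continuous batching (my own simplification, not the repository's code): each batch row is a running stream of concatenated sentences, and every batch takes the next num_steps-wide window of each stream:

import numpy as np

def continuous_batches(sentences, batch_size, num_steps, eos_id=0):
    # One running token stream per batch row; eos_id plays the role of the * marker.
    streams = [[] for _ in range(batch_size)]
    sentence_iter = iter(sentences)
    while True:
        for stream in streams:
            # Top up the stream until it can fill one window plus the shifted target.
            while len(stream) <= num_steps:
                try:
                    sent = next(sentence_iter)
                except StopIteration:
                    return
                stream.extend(sent + [eos_id])
        x = np.array([s[:num_steps] for s in streams])       # inputs
        y = np.array([s[1:num_steps + 1] for s in streams])  # next-word targets
        for stream in streams:
            del stream[:num_steps]  # leftover tokens continue in the next batch
        yield x, y

With this layout a long sentence spills over into the next batch on the same row, and short sentences get packed together, which matches the two cases above.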
3) They calculate a batch-mean loss; num_steps is the number of unrolled steps (the memory of the LSTM). Code.
# loss has shape [batch_size * num_steps]: a 1-D tensor,
# i.e. the reshaped 2-D tensor of shape [batch_size, num_steps]
loss = tf.reduce_mean(loss)
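In other words, as far as I can tell, the scalar is the mean per-token negative log-likelihood over everything in the batch, ignoring sentence boundaries. A toy numpy check with made-up values:

import numpy as np

batch_size, num_steps = 2, 3
per_token_loss = np.array([[2.0, 1.0, 3.0],
                           [4.0, 2.0, 6.0]])              # [batch_size, num_steps]
loss = per_token_loss.reshape(batch_size * num_steps).mean()
print(loss, np.exp(loss))                                 # 3.0 and its perplexity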
4) The LSTM state is updated after each training iteration and zeroed out at evaluation time.
They declare the state as local variables.
It is then updated at train time and zeroed out at eval time.
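My mental model of that state handling, as a plain-Python sketch (not the repository's TensorFlow code; the class and method names are mine):

import numpy as np

class StatefulLM:
    def __init__(self, batch_size, state_size):
        # Persistent state, playing the role of the non-trainable state variables.
        self.state = np.zeros((batch_size, state_size), dtype=np.float32)

    def train_step(self, x):
        # Start from the state left over by the previous batch, then store the new one.
        new_state = self._lstm_update(x, self.state)
        self.state = new_state
        return new_state

    def reset_state(self):
        # Zeroed out before evaluation.
        self.state[:] = 0.0

    def _lstm_update(self, x, state):
        # Stand-in for the real LSTM recurrence.
        return np.tanh(0.5 * state + 0.5 * x)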
5) At eval time the authors calculate the perplexity this way (the link):
loss_nom = 0.0
loss_den = 0.0
for i, (x, y) in enumerate(data_iterator):
    # get a batch and run the model's mean loss on it
    loss = sess.run(model.loss, {model.x: x, model.y: y})
    loss_nom += loss
    loss_den += 1
    loss = loss_nom / loss_den
    sys.stdout.write("%d: %.3f (%.3f) ... " % (i, loss, np.exp(loss)))
    sys.stdout.flush()
sys.stdout.write("\n")
That is, they measure batch-average perplexity: the exponential of the running mean of per-batch mean losses.
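A toy numpy check of what that loop ends up reporting (made-up loss values): every batch contributes equally to the average, no matter how many sentences or tokens it contains.

import numpy as np

batch_mean_losses = [4.2, 3.9, 4.5]      # model.loss values for three batches
log_ppl = sum(batch_mean_losses) / len(batch_mean_losses)
print(np.exp(log_ppl))                   # the reported perplexity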
That being said, I have 2 main questions.
Questions
- Considering preamble 1), 2) and 4).
Why are the batches formed that way?
The LSTM state is not zeroed out after each sentence, so it keeps memory of the previous sentence.
In the example at the top, when the network processes the word "Half" in line №2 of batch №1, it still remembers the context of the word "Music" and the start/end tokens. This could make sense if the sentences were not shuffled and formed real continuous text, but they are shuffled and not connected to each other.
I implemented both methods, and the continuous ("infinite") batches gave much better performance.
- Considering preamble 3) and 5).
Why do we estimate batch-average perplexity?
Taking the first question into consideration, it's not clear to me that measuring perplexity this way really tells us how good the model is; sentence-average perplexity seems like a more meaningful estimate to me (see the sketch below).
If there is a flaw in my logic, I'd be grateful if you could point it out.
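For reference, this is the kind of alternative I have in mind (a rough sketch; the function names are mine and it assumes per-token losses are available for each sentence):

import numpy as np

def corpus_perplexity(sentence_token_losses):
    # sentence_token_losses: list of 1-D arrays, one per sentence, holding
    # the per-token cross-entropy of that sentence.
    total_nll = sum(float(np.sum(losses)) for losses in sentence_token_losses)
    total_tokens = sum(len(losses) for losses in sentence_token_losses)
    return np.exp(total_nll / total_tokens)   # each token weighted equally

def mean_sentence_perplexity(sentence_token_losses):
    # Each sentence contributes equally, regardless of its length.
    return float(np.mean([np.exp(np.mean(losses)) for losses in sentence_token_losses]))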