I have a fully working seq2seq attention model with beam search, and it does give improved results. But inference takes over a minute for a batch of 1024 with k = 5 (k is the number of hypotheses, i.e. the beam width), because none of the beam search is parallelised: everything happens one sample at a time.
Task (simplified)
The goal is sentence translation: 15-word sentences from language A to 15-word sentences in language B.
- The encoder is an RNN that takes a 15-word sentence and encodes a representation of it, producing a [timestep, 512] matrix along with its final hidden state.
- The decoder is another RNN. It takes the encoder's final hidden state as its initial state, uses the [timestep, 512] matrix for attention, and outputs the translated words (for the whole batch) one timestep at a time. Up to this point there is naturally some parallelisation across the batch.
- During inference, beam search is used. At each decoder timestep, rather than taking the single predicted word with the highest probability, I take the k best words and provide each of them as input to the next timestep so it can predict the following word (the rest of the algorithm is given below). This makes decoding less greedy, anticipating results with higher total probability over the succeeding timesteps.
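For concreteness, the attention part of the decoder step above can be sketched in NumPy. This assumes dot-product scoring; the function name `attention_step` and the scoring choice are my own, not something fixed by the setup:

```python
import numpy as np

def attention_step(enc_outputs, dec_hidden):
    """One dot-product attention step (hypothetical helper).

    enc_outputs: [T, 512] encoder output matrix (the [timestep, 512] matrix above)
    dec_hidden:  [512] current decoder hidden state
    Returns the context vector [512] and the attention weights [T].
    """
    scores = enc_outputs @ dec_hidden                # [T] alignment scores
    scores -= scores.max()                           # numerical stability
    weights = np.exp(scores) / np.exp(scores).sum()  # softmax over timesteps
    context = weights @ enc_outputs                  # [512] weighted sum
    return context, weights
```

The same computation batches trivially (scores become [B, T] via one matmul), which is why the encoder/decoder part parallelises while the beam bookkeeping does not.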
    for each element in the test-set:
        calculate initial k (decoder-encoder step)
        for range(timesteps - 1):
            for each prev k:
                get hidden state
                obtain its best k
                save hidden state
            find new k from the k*k possible ones
            ## update hypotheses based on the newly found k
            for element in k:
                copy hidden state
                change hypotheses if necessary
                append new k to hypotheses
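To make the pseudocode concrete, here is a minimal runnable version for a single sample. `step_fn` is a hypothetical stand-in for one decoder step (its signature is my assumption, not the actual model API):

```python
import numpy as np

def beam_search_single(step_fn, init_hidden, bos_id, k, timesteps):
    """Per-sample beam search matching the pseudocode above.

    step_fn(token, hidden) -> (log_probs [vocab], new_hidden) stands in
    for one attention-decoder step.
    """
    # calculate initial k (decoder-encoder step)
    log_probs, hidden = step_fn(bos_id, init_hidden)
    top = np.argsort(log_probs)[-k:]
    beams = [([int(t)], log_probs[t], hidden) for t in top]

    for _ in range(timesteps - 1):
        candidates = []
        for seq, score, h in beams:            # for each prev k
            lp, new_h = step_fn(seq[-1], h)    # get hidden state
            for t in np.argsort(lp)[-k:]:      # obtain its best k
                candidates.append((seq + [int(t)], score + lp[t], new_h))
        # find new k from the k*k possible ones
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:k]                 # update hypotheses
    return beams
```

The outer "for each element in the test-set" loop and the inner "for each prev k" loop are the two places where everything is sequential, and both can in principle be folded into a batch dimension.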
There are 6 tensors and 2 lists to keep track of and handle state changes for. Is there any room for speedup or parallelisation here? Perhaps each of the k hypotheses can go through encode-decode simultaneously? Any help is much appreciated.
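To illustrate the "each k simultaneously" idea: the usual trick is to reshape the k live hypotheses of all B samples into one [B*k, ...] batch, run a single decoder step, then select the new k per sample with one top-k over a [B, k*V] score matrix instead of Python loops. A NumPy sketch of just that selection step (shapes and names are my assumptions):

```python
import numpy as np

def batched_beam_step(scores, step_log_probs, k):
    """One vectorised beam-search step over the whole batch.

    scores:         [B, k] cumulative log-prob of each live hypothesis
    step_log_probs: [B, k, V] decoder output for all B*k hypotheses,
                    produced by ONE decoder call on inputs of shape [B*k, ...]
    Returns new scores [B, k], the previous beam each winner extends [B, k],
    and the token each winner appends [B, k].
    """
    B, k_, V = step_log_probs.shape
    # total score of every k*V candidate per sample, flattened to [B, k*V]
    total = (scores[:, :, None] + step_log_probs).reshape(B, k_ * V)
    # one top-k over the flat axis replaces the per-sample "find new k" loop
    top = np.argpartition(total, -k, axis=1)[:, -k:]
    order = np.argsort(np.take_along_axis(total, top, 1), axis=1)[:, ::-1]
    top = np.take_along_axis(top, order, 1)
    new_scores = np.take_along_axis(total, top, 1)
    beam_idx = top // V   # which previous hypothesis to copy hidden state from
    token_idx = top % V   # which word it appends to that hypothesis
    return new_scores, beam_idx, token_idx
```

The `beam_idx` output is then used to gather (reorder) the saved hidden states and partial hypotheses with indexing ops, so the 6 tensors can be updated without per-element copies; only finished-hypothesis bookkeeping needs any Python-side logic.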