I am using a model for sequence prediction starting from a latent representation of an encoded input, which forms the initial state of the decoder. It could be a feature vector from an image (for captioning) or the result of a seq2seq encoder.
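
Schematically, the decoder side looks something like the sketch below. The sizes and layer names are toy placeholders just for illustration, not my real model:

import tensorflow as tf

# Toy sizes for illustration only
state_size = 256
output_features = 64

# The encoded input (e.g. image feature vector) becomes the decoder's initial state
latent = tf.keras.Input(shape=(state_size,), name="latent")
step_input = tf.keras.Input(shape=(None, output_features), name="step_input")

lstm = tf.keras.layers.LSTM(state_size, return_sequences=True, return_state=True)
seq_output, state_h, state_c = lstm(step_input, initial_state=[latent, latent])

model = tf.keras.Model([step_input, latent], [seq_output, state_h, state_c])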

My model is trained with teacher forcing, and that goes quite fast. Inference, however, is brutally slow, because I extend the sequence step by step, roughly like this (pseudocode):

import numpy as np

sequence_terminated = False
sequence = np.zeros((0, output_features))    # grows by one step per iteration
seq_input = initial_input                    # seeded from the encoder's latent state
while not sequence_terminated:
    seq_output, seq_states = model.predict(seq_input)
    # f() picks the next input and decides whether the sequence is finished
    next_input, sequence_terminated = f(seq_output)
    sequence = np.concatenate((sequence, seq_output), axis=0)
    seq_input = next_input

I have done a lot of optimization at this stage, so I can predict sequences for hundreds of queries in parallel, but 1) on CPU, runtime scales linearly beyond ~32 or so parallel sequences, and 2) on GPU it is actually slower than on CPU, presumably because data has to be moved back and forth after every step, so the GPU's speed never pays off.

I'm additionally using a non-greedy sequence search that isn't Beam Search but can backtrack in a way similar to A*, more or less like this (pseudocode):

sequence = np.zeros((0, output_features))    # all outputs predicted so far
states = np.zeros((0, state_size))           # decoder states at each step
pos = []                                     # which position each step extended from
from_position = 0
remaining_sequences = 5                      # stop after the top-5 terminated sequences
seq_input = initial_input
while remaining_sequences > 0:
    seq_output, seq_states = model.predict(seq_input)
    sequence = np.concatenate((sequence, seq_output), axis=0)
    states = np.concatenate((states, seq_states), axis=0)
    pos.append(from_position)
    # Based on the outputs so far, f() decides which sequence stub to continue,
    # returning the next input and the states at that position:
    next_input, next_states, from_position, terminated = f(sequence, states)
    seq_input = next_input
    if terminated:
        remaining_sequences -= 1

which gives the top-n sequences, backtracking from the last predicted position. Again, this is more or less optimized on the CPU side for parallel prediction.

I think that to get faster I need to run prediction entirely on the GPU, without moving data back after every step, but I don't see how to write that in TensorFlow. There is tfa.seq2seq (formerly tf.contrib.seq2seq), which has an infrastructure of decoders that presumably run efficiently as part of a model, but I can't find much documentation for it.
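
What I imagine is something like the following: a single tf.function whose loop AutoGraph turns into a tf.while_loop, so the whole decode stays on the device. Here step_model, initial_input and initial_state are placeholders, not my actual model, and I've left out the termination test:

import tensorflow as tf

@tf.function
def greedy_decode(initial_input, initial_state, max_len=100):
    # AutoGraph converts this loop to a tf.while_loop, so nothing leaves the device
    outputs = tf.TensorArray(tf.float32, size=max_len)
    seq_input, state = initial_input, initial_state
    for t in tf.range(max_len):
        # step_model stands for one decode step of my model (placeholder name)
        seq_output, state = step_model([seq_input, state])
        outputs = outputs.write(t, seq_output)
        seq_input = seq_output    # feed the prediction back in as the next input
        # (a real version would also check a termination condition here)
    return outputs.stack()

But I don't know how to reconcile that with my existing trained Keras model, or whether tfa.seq2seq already does this for me.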

Note that my model (Keras functional API; it can also be called as model() instead of model.predict(), or I can wire its output tensors elsewhere) is not simply three LSTM layers but includes some inline feature engineering that is stateful, so it has to happen inside the model. tfa.seq2seq.Decoder seems to expect a single cell to wrap itself around?
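
If I understand the cell contract right, maybe I would have to wrap the per-step part of my model in a custom cell, something like the sketch below, and hand that to the tfa.seq2seq machinery. step_layers here is a placeholder for my LSTM stack plus the stateful feature engineering; I have no idea whether this is the intended way to do it:

import tensorflow as tf

class ModelStepCell(tf.keras.layers.AbstractRNNCell):
    """Wraps one decode step of an existing model in the RNN-cell interface."""

    def __init__(self, step_layers, state_size, output_size, **kwargs):
        super().__init__(**kwargs)
        self.step_layers = step_layers    # placeholder for my per-step sub-model
        self._state_size = state_size
        self._output_size = output_size

    @property
    def state_size(self):
        return self._state_size

    @property
    def output_size(self):
        return self._output_size

    def call(self, inputs, states):
        # One decode step: return (output, new_states), as a cell is supposed to
        output, new_states = self.step_layers(inputs, states)
        return output, new_states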

Questions:

1) Can I use the tfa.seq2seq decoder architecture with a black-box model built and trained independently of the tfa.seq2seq architecture? If yes, where can I find information about that?

2) Are there any pointers on how to implement greedy and non-greedy sequence search directly in TensorFlow, without falling back to Python code like mine above? I understand that I will probably have to give up my non-greedy approach and just use beam search, which will probably perform about the same.
