
I tried to write my own estimator model_fn() for a GCP ML Engine package. I decoded a sequence of outputs using embedding_rnn_decoder as shown below:

outputs, state = tf.contrib.legacy_seq2seq.embedding_rnn_decoder(
    decoder_inputs = decoder_inputs,
    initial_state = curr_layer,
    cell = tf.contrib.rnn.GRUCell(hidden_units),
    num_symbols = n_classes, 
    embedding_size = embedding_dims,
    feed_previous = False)

I know that outputs is "A list of the same length as decoder_inputs of 2D Tensors", but I am wondering how I can use this list to calculate the loss for the entire sequence.

I know that if I grab outputs[0] (i.e. grab only the first sequence output) then I could compute the loss as follows:

logits = tf.layers.dense(
    outputs[0],
    n_classes)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels))

Is it appropriate to generate a loss value for each of the items in outputs and then pass these all to tf.reduce_mean? This feels inefficient, especially for long sequences -- are there any other ways to calculate the softmax at each step of the sequence that would be more efficient?

reese0106
  • Is this information useful? https://www.tensorflow.org/tutorials/seq2seq#sampled_softmax_and_output_projection – rhaertel80 Aug 16 '17 at 06:46
  • No, that does not help. My issue is not solved by sampled softmax. My issue is that I need to perform multiple iterations of softmax (on each output of the RNN). I would need to calculate the loss (or sampled softmax loss) at each step in the sequence and that is where I am not sure how to do this properly. – reese0106 Aug 16 '17 at 12:05

2 Answers


I think you're looking for sequence_loss (currently in contrib/, a form of incubation).

rhaertel80

It looks like the solution to my problem is to use sequence_loss_by_example.

reese0106
  • That is in a package called "legacy". I've added a link to the non-legacy version, but note that it, too, is in contrib/ (a type of incubation). – rhaertel80 Aug 17 '17 at 04:43
  • I was using the "legacy" package to generate my outputs, so the legacy package seems like the more appropriate metric to rely on for this question – reese0106 Aug 20 '17 at 15:30
  • Staying within the same package is probably a good idea. Consider moving off of legacy at some point. I suspect they don't update it with bug fixes or performance enhancements. – rhaertel80 Aug 22 '17 at 05:37