
In the official documentation of tf.nn.raw_rnn, emit_structure is the emit_output value (the fourth element of the returned tuple) that loop_fn produces the first time it is called.

Later on, emit_structure is used to write tf.zeros_like(emit_structure) into the minibatch entries that have finished, via emit = tf.where(finished, tf.zeros_like(emit_structure), emit).
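To make the masking step concrete, here is a minimal pure-Python sketch of what that tf.where line does, with plain lists standing in for tensors. The function name mask_finished and the list representation are illustrative assumptions, not TensorFlow API. The point is that finished entries need some zero template to copy in, which is the role emit_structure plays:

```python
# Pure-Python sketch of: emit = tf.where(finished, zeros, emit)
# Assumptions: emit is a batch of rows (lists of floats), finished holds one
# boolean per batch entry, and zero_row is the zero template for one entry.

def mask_finished(finished, zero_row, emit):
    """Return a copy of emit where rows of finished entries are replaced by zero_row."""
    # Like tf.where with a [batch] condition: finished entries take the zero
    # template, unfinished entries keep their emitted row.
    return [list(zero_row) if done else list(row) for done, row in zip(finished, emit)]

finished = [False, True, False]
emit = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(mask_finished(finished, [0.0, 0.0], emit))  # [[1.0, 2.0], [0.0, 0.0], [5.0, 6.0]]
```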

What I don't understand (or what Google's documentation leaves out) is this: in the documentation's example the first call to loop_fn returns None as the emit output, so emit_structure is None, and tf.where(finished, tf.zeros_like(emit_structure), emit) should throw a ValueError, since tf.zeros_like(None) does. Can somebody please fill in what I am missing here?

Maxim
figs_and_nuts


Yep, the doc is rather confusing here. The key phrase in it is "in pseudo-code": the snippet describes the behavior of tf.nn.raw_rnn rather than its actual implementation, so it isn't accurate in this detail. You can see this by looking at the internals of tf.nn.raw_rnn.

The actual source code looks like this (it may differ depending on your TensorFlow version):

if emit_structure is not None:
  flat_emit_structure = nest.flatten(emit_structure)
  flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                    array_ops.shape(emit) for emit in flat_emit_structure]
  flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
else:
  emit_structure = cell.output_size
  flat_emit_size = nest.flatten(emit_structure)
  flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

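So the implementation explicitly handles the case where emit_structure is None: it falls back to cell.output_size for the emit sizes and to the dtype of the first state tensor for the emit dtypes. The zeros for finished entries are then built from these sizes and dtypes, so tf.zeros_like(None) is never called. That's why nothing breaks.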
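A minimal pure-Python sketch of that fallback logic, under simplifying assumptions: Spec and FakeCell are hypothetical stand-ins for a tensor and an RNN cell, every size is a plain int (the fully defined, rank-1 case), flatten is a toy version of nest.flatten, and the state dtype is passed in directly instead of being read from flat_state[0]. It mirrors the if/else quoted above and shows that the zeros come from sizes, not from zeros_like(None):

```python
from dataclasses import dataclass


@dataclass
class Spec:
    """Hypothetical stand-in for an emit tensor: only its size and dtype matter here."""
    size: int
    dtype: str


class FakeCell:
    """Hypothetical cell exposing only output_size, like an RNN cell."""
    def __init__(self, output_size):
        self.output_size = output_size


def flatten(structure):
    """Toy stand-in for nest.flatten on ints, Specs, lists and tuples."""
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]


def emit_sizes_and_dtypes(emit_structure, cell, state_dtype):
    """Mirror the quoted if/else; returns (emit_structure, sizes, dtypes)."""
    if emit_structure is not None:
        # Explicit structure: read size and dtype from each flattened entry.
        flat = flatten(emit_structure)
        sizes = [spec.size for spec in flat]
        dtypes = [spec.dtype for spec in flat]
    else:
        # Fallback: use the cell's output size and the state dtype.
        emit_structure = cell.output_size
        sizes = flatten(emit_structure)
        dtypes = [state_dtype] * len(sizes)
    return emit_structure, sizes, dtypes


def zero_emit(batch_size, sizes):
    """Build one [batch_size, size] block of zeros per size, without any zeros_like(None)."""
    return [[[0.0] * size for _ in range(batch_size)] for size in sizes]


structure, sizes, dtypes = emit_sizes_and_dtypes(None, FakeCell(4), "float32")
print(structure, sizes, dtypes)  # 4 [4] ['float32']
print(zero_emit(2, sizes))       # [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]
```

Passing an explicit structure such as [Spec(3, "float16")] takes the first branch instead, so the cell's output_size is ignored in that case.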

Maxim