
Continuing from this question and the discussion here - I am trying to use the Dataset API to take a dataset of variable-length tensors and cut them into slices (segments) of equal length. Something like:

from functools import partial

import tensorflow as tf

# filenames, num_segments and record_type are defined elsewhere in my pipeline
Dataset = tf.contrib.data.Dataset
segment_len = 6
batch_size = 16

with tf.Graph().as_default() as g:
    # get the tfrecords dataset
    dataset = tf.contrib.data.TFRecordDataset(filenames).map(
        partial(record_type.parse_single_example, graph=g)).batch(batch_size)
    # zip it with the number of segments we need to slice each tensor
    dataset2 = Dataset.zip((dataset, Dataset.from_tensor_slices(
        tf.constant(num_segments, dtype=tf.int64))))
    it2 = dataset2.make_initializable_iterator()
    def _dataset_generator():
        with g.as_default():
            while True:
                try:
                    (im, length), count = sess.run(it2.get_next())
                    dataset3 = Dataset.zip((
                        # repeat each tensor, then use map to take a strided slice
                        Dataset.from_tensors((im, length)).repeat(count),
                        Dataset.range(count))).map(lambda x, c: (
                            x[0][:, c: c + segment_len],
                            x[0][:, c + 1: (c + 1) + segment_len],
                    ))
                    it = dataset3.make_initializable_iterator()
                    it_init = it.initializer
                    try:
                        yield it_init
                        while True:
                            yield sess.run(it.get_next())
                    except tf.errors.OutOfRangeError:
                        continue
                except tf.errors.OutOfRangeError:
                    return
    # Dataset.from_generator needs tensorflow > 1.3!
    das_dataset = Dataset.from_generator(
        _dataset_generator,
        (tf.float32, tf.float32),
        # (tf.TensorShape([]), tf.TensorShape([]))
    )
    das_dataset_it = das_dataset.make_one_shot_iterator()


with tf.Session(graph=g) as sess:
    while True:
        print(sess.run(it2.initializer))
        print(sess.run(das_dataset_it.get_next()))

Of course I do not want to pass the session into the generator, but this can be worked around with the trick given in the link (create a dummy dataset and map the iterator of the other). The code above fails with the biblical:

tensorflow.python.framework.errors_impl.InvalidArgumentError: TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class 'tensorflow.python.framework.ops.Operation'>.
         [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_FLOAT, DT_FLOAT], token="pyfunc_1"](arg0)]]
         [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[<unknown>, <unknown>], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]

which I guess is because I try to yield the initializer of the iterator, but my question is basically whether I can achieve what I am trying to do at all using the Dataset API.
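
For concreteness, here is a hedged NumPy sketch of the slicing I am after (the sequence and its length are hypothetical):

import numpy as np

# Hypothetical example: one sequence of length 10, segment_len = 6.
seq = np.arange(10)
segment_len = 6
num_segments = len(seq) - segment_len  # 4 shifted (input, target) pairs
for c in range(num_segments):
    inputs = seq[c: c + segment_len]           # e.g. [0 1 2 3 4 5] for c = 0
    targets = seq[c + 1: c + 1 + segment_len]  # e.g. [1 2 3 4 5 6] for c = 0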

Mr_and_Mrs_D

1 Answer

The easiest way to build a Dataset from a nested Dataset is to use the Dataset.flat_map() transformation. This transformation applies a function to each element of the input dataset (dataset2 in your example); that function returns a nested Dataset (most likely dataset3 in your example), and then the transformation flattens all the nested datasets into a single Dataset.

dataset2 = ...  # As above.

def get_slices(im_and_length, count):
  im, length = im_and_length
  # Repeat each tensor then use map to take a strided slice.
  return Dataset.zip((
      Dataset.from_tensors((im, length)).repeat(count),
      Dataset.range(count))).map(lambda x, c: (
          x[0][:, c: c + segment_len],
          x[0][:, c + 1: (c + 1) + segment_len],
  ))

das_dataset = dataset2.flat_map(get_slices)
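
For completeness, a minimal sketch of how the flattened dataset could be consumed, assuming everything is built in the default graph (TF 1.x API; the variable names are illustrative):

das_it = das_dataset.make_one_shot_iterator()
next_element = das_it.get_next()  # build the get_next op once, outside the run loop

with tf.Session() as sess:
    try:
        while True:
            segment, shifted_segment = sess.run(next_element)
    except tf.errors.OutOfRangeError:
        pass  # every slice of every record has been produced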
mrry
  • Excellent, thanks - it hadn't occurred to me that flat_map is the tool for the job – Mr_and_Mrs_D Oct 07 '17 at 01:29
  • FYI this may not play well with MonitoredTrainingSession - the iterator is sometimes unexpectedly advanced (perhaps because it is bound to summaries like model.cost?) - or I may be quite wrong and it's my fault. I will have to investigate more, but meanwhile, since MonitoredTrainingSession and Dataset integration is being discussed over at GitHub, I just note it so you keep it in mind too - namely, that we should at least warn people to be careful about advancing the iterator in operations hidden inside the MonitoredTrainingSession. – Mr_and_Mrs_D Oct 11 '17 at 14:35