
Suppose I have three TFRecord files: neg.tfrecord, pos1.tfrecord, and pos2.tfrecord.

My batch size is 500, and each batch should contain 300 neg examples, 100 pos1 examples, and 100 pos2 examples. How can I build a TFRecordDataset that produces such batches?

I will pass this dataset to Keras model.fit() (with eager execution).

My TensorFlow version is 1.13.1. I found APIs in tf.data.Dataset such as interleave, concatenate, and zip, but none of them seems to solve my problem.

Previously, I created a separate iterator for each dataset and concatenated the fetched data manually, but this was inefficient and GPU utilization was low.

Here is what I tried with interleave:

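import tensorflow as tf

tfrecord_files = ['neg.tfrecord', 'pos1.tfrecord', 'pos2.tfrecord']
dataset = tf.data.Dataset.from_tensor_slices(tfrecord_files)

def _parse(x):
    # Open one TFRecord file as a dataset of serialized records
    return tf.data.TFRecordDataset(x)

# Cycle through the files, taking one record from each file per cycle
dataset = dataset.interleave(_parse, cycle_length=4, block_length=1)
# _parse_image_function is my own record parser (not shown here)
dataset = dataset.apply(tf.data.experimental.map_and_batch(_parse_image_function, 500))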

and I got batches ordered like this:

neg pos1 pos2 neg pos1 pos2 ...
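This ordering follows from how interleave works with block_length=1: it takes one element from each open input in turn, and block_length applies equally to every input, so there is no way to ask it for three neg records per cycle. The sketch below is a simplified pure-Python model of that round-robin behaviour, not TensorFlow code; interleave_round_robin is a hypothetical helper, and the model assumes cycle_length is at least the number of input files (here 4 >= 3), so all files are open at once.

```python
from itertools import islice


def interleave_round_robin(sources, block_length=1):
    """Simplified model of Dataset.interleave when every input fits in the cycle.

    Assumption: cycle_length >= len(sources), so all inputs are open at once.
    Each round takes up to block_length elements from every still-active source,
    in order, and drops a source once it is exhausted.
    """
    iterators = [iter(s) for s in sources]
    out = []
    while iterators:
        still_active = []
        for it in iterators:
            block = list(islice(it, block_length))
            out.extend(block)
            # A short block means this source ran out; a full block may have more.
            if len(block) == block_length:
                still_active.append(it)
        iterators = still_active
    return out


# Stand-ins for the three files, with the question's 3:1:1 size ratio.
neg = ['neg'] * 6
pos1 = ['pos1'] * 2
pos2 = ['pos2'] * 2
print(interleave_round_robin([neg, pos1, pos2]))
```

The model yields neg pos1 pos2 neg pos1 pos2 and then only neg once the smaller sources are used up, which matches the pattern above. Raising block_length would take more records from every file, not just from neg.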

But what I want is this:

neg neg neg pos1 pos2 neg neg neg pos1 pos2 ...

What should I do?

Looking forward to your answers.

Gary

1 Answer


I reproduced something similar to your setup, using string data in place of the TFRecord files:

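import tensorflow as tf

def string_data(s):
    # Split a space-separated string into a 1-D string tensor
    return tf.sparse.to_dense(tf.strings.split([s]), default_value='')[0]

# Stand-ins for the three files: 30 "neg", 10 "pos1" and 10 "pos2" records
data = [' '.join(['neg'] * 30), ' '.join(['pos1'] * 10), ' '.join(['pos2'] * 10)]
# Number of consecutive elements to take from each source per cycle (3:1:1)
step_sizes = tf.constant([3, 1, 1], dtype=tf.int64)
ds = (tf.data.Dataset.from_tensor_slices((data, step_sizes))
      # Batch each source by its own step size, then interleave those small batches
      .interleave(lambda d, s: (tf.data.Dataset.from_tensor_slices(string_data(d))
                                .batch(s)),
                  cycle_length=len(data))
      # Split the small batches back into individual elements
      .flat_map(tf.data.Dataset.from_tensor_slices))
next_element = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(next_element).decode(), end=', ')
        except tf.errors.OutOfRangeError:
            break
    print()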

Output:

neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, 

In the real use case you would replace data with the list of file names, and tf.data.Dataset.from_tensor_slices(string_data(d)) with tf.data.TFRecordDataset(d); otherwise it should work the same way.

EDIT: I just realised that you actually want whole batches ordered this way, not one element at a time, so you would need to add one more batch call at the end of the pipeline (e.g. .batch(500), after mapping your parse function).
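To check that the extra batch call gives the requested 300/100/100 split, here is a framework-free Python model of the pipeline, with plain lists standing in for the datasets. chunks and ratio_batches are hypothetical helpers, not TensorFlow APIs: chunks plays the role of batch, and the loop mimics interleave with block_length=1 followed by flat_map(from_tensor_slices).

```python
from collections import Counter
from itertools import islice


def chunks(seq, n):
    """Yield consecutive lists of length n (the last may be shorter), like batch(n)."""
    it = iter(seq)
    while True:
        chunk = list(islice(it, n))
        if not chunk:
            return
        yield chunk


def ratio_batches(sources, step_sizes, batch_size):
    """Model of the pipeline: batch each source by its own step size,
    interleave the small batches round-robin, flatten them back into
    single elements, then batch the flattened stream by batch_size."""
    chunk_iters = [chunks(src, step) for src, step in zip(sources, step_sizes)]
    flat = []
    while chunk_iters:
        still_active = []
        for it in chunk_iters:
            chunk = next(it, None)
            if chunk is not None:
                flat.extend(chunk)  # corresponds to flat_map(from_tensor_slices)
                still_active.append(it)
        chunk_iters = still_active
    return list(chunks(flat, batch_size))


neg = ['neg'] * 3000
pos1 = ['pos1'] * 1000
pos2 = ['pos2'] * 1000
batches = ratio_batches([neg, pos1, pos2], [3, 1, 1], 500)
print(len(batches), Counter(batches[0]))
```

Because 500 is a multiple of 3 + 1 + 1 = 5 and the stream starts at a group boundary, every batch of 500 holds exactly 100 groups. The ratio only holds while all three sources still have data: in this model, once pos1 and pos2 run out, the remaining batches are all neg. In the real pipeline, one possible workaround (an assumption, not part of the answer above) is to call .repeat() on each file dataset inside the interleave lambda and pass steps_per_epoch to fit(), since the dataset then never ends.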

jdehesa