Suppose I have 3 tfrecord files, namely neg.tfrecord, pos1.tfrecord, and pos2.tfrecord.
My batch size is 500, and each batch should contain 300 neg examples, 100 pos1 examples, and 100 pos2 examples. How can I build a TFRecordDataset that yields batches with this composition?
I will use this TFRecordDataset object in keras.fit() (Eager Execution).
My TensorFlow version is 1.13.1. I looked at the tf.data.Dataset API, including interleave, concatenate, and zip, but none of them seems to solve my problem.
Previously, I created an iterator for each dataset and concatenated the results manually after fetching the data, but this was inefficient and GPU utilization was low.
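To make the manual approach concrete, here is a rough sketch of what it amounts to, with plain Python iterators standing in for the per-file dataset iterators. The function name manual_batches and the sizes argument are illustrative assumptions, not the original code. Because the assembly happens in Python for every batch, it is easy to see why the input pipeline could leave the GPU waiting.

```python
from itertools import islice

def manual_batches(neg_iter, pos1_iter, pos2_iter, sizes=(300, 100, 100)):
    """Yield batches built by pulling a fixed number of items from each iterator."""
    # iter() makes sure islice continues where the previous batch stopped.
    iterators = tuple(iter(it) for it in (neg_iter, pos1_iter, pos2_iter))
    while True:
        parts = [list(islice(it, n)) for it, n in zip(iterators, sizes)]
        if any(len(part) < n for part, n in zip(parts, sizes)):
            return  # drop the final incomplete batch
        # Concatenate the three parts into one batch of sum(sizes) items.
        yield parts[0] + parts[1] + parts[2]
```

In this sketch each batch is laid out as all neg items, then all pos1 items, then all pos2 items; only the per-batch counts are fixed.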
For this question, I tried interleave as shown below:
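import tensorflow as tf

tfrecord_files = ['neg.tfrecord', 'pos1.tfrecord', 'pos2.tfrecord']
dataset = tf.data.Dataset.from_tensor_slices(tfrecord_files)

def _parse(x):
    # Open one TFRecord file as a dataset of serialized records.
    return tf.data.TFRecordDataset(x)

# Take one record from each file in turn (block_length=1).
dataset = dataset.interleave(_parse, cycle_length=4, block_length=1)
# _parse_image_function is my per-record parsing function (not shown).
dataset = dataset.apply(tf.data.experimental.map_and_batch(_parse_image_function, 500))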
With this, the records in each batch are ordered like this:
neg pos1 pos2 neg pos1 pos2 ...............
But what I want is this:
neg neg neg pos1 pos2 neg neg neg pos1 pos2 .................
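To pin down the target ordering, here is a plain-Python sketch (no TensorFlow involved) of the pattern above: repeating blocks of 3 neg, 1 pos1, and 1 pos2, so that 100 such blocks make up one batch of 500. The function name desired_order and the pattern argument are hypothetical and only serve to illustrate the ordering; this is not a tf.data solution.

```python
from itertools import cycle, islice

def desired_order(neg, pos1, pos2, pattern=(3, 1, 1)):
    """Yield items in repeating blocks: 3 from neg, 1 from pos1, 1 from pos2."""
    sources = [iter(neg), iter(pos1), iter(pos2)]
    while True:
        for source, count in zip(sources, pattern):
            block = list(islice(source, count))
            if len(block) < count:
                return  # stop once any source runs out
            yield from block

# One batch of 500 drawn from endless stand-in sources.
batch = list(islice(desired_order(cycle(["neg"]), cycle(["pos1"]), cycle(["pos2"])), 500))
```

With the default 3:1:1 pattern, the 500 items in batch come from 100 full blocks, which gives the 300/100/100 split I am after.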
What should I do?
I look forward to your answers.