Suppose I have 3 TFRecord files, namely neg.tfrecord, pos1.tfrecord, and pos2.tfrecord.
For each file I use
dataset = tf.data.TFRecordDataset(tfrecord_file)
so this code creates 3 separate Dataset objects.
My batch size is 400, and each batch should contain 200 neg examples, 100 pos1 examples, and 100 pos2 examples. How can I build a dataset that produces batches like this?
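One way to get exactly this composition entirely inside tf.data, sketched here under assumptions: batch each source separately to its own size, zip the three batched datasets, and concatenate the sub-batches along the batch axis in a map(). The function name `mixed_batches`, the `(features, label)` element structure, and the toy in-memory data are hypothetical stand-ins for the parsed TFRecord datasets.

```python
import numpy as np
import tensorflow as tf

# TF 1.x must switch eager mode on at startup; TF 2.x is already eager.
if not tf.executing_eagerly():
    tf.compat.v1.enable_eager_execution()


def mixed_batches(neg_ds, pos1_ds, pos2_ds, sizes=(200, 100, 100)):
    """Combine three datasets of (features, label) pairs into fixed-ratio batches.

    Every output batch holds exactly sizes[0] neg, sizes[1] pos1 and
    sizes[2] pos2 examples, in that order along the batch axis.
    """
    n_neg, n_pos1, n_pos2 = sizes
    # repeat() lets the smaller sources cycle so zip() does not stop early;
    # drop_remainder=True gives every sub-batch a static batch dimension.
    neg_b = neg_ds.repeat().batch(n_neg, drop_remainder=True)
    pos1_b = pos1_ds.repeat().batch(n_pos1, drop_remainder=True)
    pos2_b = pos2_ds.repeat().batch(n_pos2, drop_remainder=True)

    def concat(neg, pos1, pos2):
        # Each argument is a (features, labels) pair of sub-batches.
        features = tf.concat([neg[0], pos1[0], pos2[0]], axis=0)
        labels = tf.concat([neg[1], pos1[1], pos2[1]], axis=0)
        return features, labels

    zipped = tf.data.Dataset.zip((neg_b, pos1_b, pos2_b))
    # prefetch overlaps building the next batch with consuming the current one.
    return zipped.map(concat).prefetch(1)


# Hypothetical stand-in data: label 0 = neg, 1 = pos1, 2 = pos2.
def toy(value, label, count):
    x = np.full((count, 4), value, dtype=np.float32)
    y = np.full((count,), label, dtype=np.int64)
    return tf.data.Dataset.from_tensor_slices((x, y))


mixed = mixed_batches(toy(0.0, 0, 500), toy(1.0, 1, 150), toy(2.0, 2, 120))
features, labels = next(iter(mixed))
print(features.shape)               # (400, 4)
print(np.bincount(labels.numpy()))  # [200 100 100]
```

Because each source repeats independently, the smaller files are cycled more often than the larger one, which is what a fixed 2:1:1 ratio implies. The resulting dataset is infinite, so Keras would need an explicit steps_per_epoch.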
I will pass this dataset to Keras model.fit() with eager execution enabled.
My TensorFlow version is 1.13.1.
Previously, I tried getting an iterator for each dataset and concatenating the fetched data manually, but this was inefficient and GPU utilization stayed low.
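A sketch of the in-pipeline approach applied to actual TFRecord files and fed to Keras, assuming a hypothetical record schema (a 4-float feature "x" and an int64 label "y", with pos1 and pos2 both labeled positive). Parsing runs inside tf.data.map with num_parallel_calls (4 is an arbitrary choice) and prefetch(1) overlaps input preparation with training, so batch assembly stays out of Python-side iterator code. The toy files are written to a temporary directory only so the sketch runs. It uses the tf.io names; on 1.13, if any of them is missing, the older aliases tf.python_io.TFRecordWriter, tf.parse_single_example and tf.FixedLenFeature should do the same job.

```python
import os
import tempfile

import tensorflow as tf

# TF 1.x must switch eager mode on at startup; TF 2.x is already eager.
if not tf.executing_eagerly():
    tf.compat.v1.enable_eager_execution()

# Hypothetical record schema: a 4-float feature "x" and an int64 label "y".
FEATURES = {
    "x": tf.io.FixedLenFeature([4], tf.float32),
    "y": tf.io.FixedLenFeature([], tf.int64),
}


def write_toy_tfrecord(path, value, label, count):
    """Write `count` identical toy records so the sketch runs without real data."""
    with tf.io.TFRecordWriter(path) as writer:
        for _ in range(count):
            example = tf.train.Example(features=tf.train.Features(feature={
                "x": tf.train.Feature(float_list=tf.train.FloatList(value=[value] * 4)),
                "y": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            }))
            writer.write(example.SerializeToString())


def parse(serialized):
    parsed = tf.io.parse_single_example(serialized, FEATURES)
    # Float label of shape [1] to match a single sigmoid output unit.
    label = tf.reshape(tf.cast(parsed["y"], tf.float32), [1])
    return parsed["x"], label


def load(path, batch_size):
    # Parse records in parallel inside tf.data, then batch this source alone.
    return (tf.data.TFRecordDataset(path)
            .map(parse, num_parallel_calls=4)
            .repeat()
            .batch(batch_size, drop_remainder=True))


def concat(neg, pos1, pos2):
    # Stack the three (features, labels) sub-batches along the batch axis.
    return (tf.concat([neg[0], pos1[0], pos2[0]], axis=0),
            tf.concat([neg[1], pos1[1], pos2[1]], axis=0))


tmp_dir = tempfile.mkdtemp()
paths = {}
for name, value, label, count in [("neg", 0.0, 0, 500),
                                  ("pos1", 1.0, 1, 150),
                                  ("pos2", 2.0, 1, 120)]:
    paths[name] = os.path.join(tmp_dir, name + ".tfrecord")
    write_toy_tfrecord(paths[name], value, label, count)

train_ds = (tf.data.Dataset.zip((load(paths["neg"], 200),
                                 load(paths["pos1"], 100),
                                 load(paths["pos2"], 100)))
            .map(concat)
            .prefetch(1))

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, activation="sigmoid", input_shape=(4,))])
model.compile(optimizer="sgd", loss="binary_crossentropy")
# The dataset is infinite because of repeat(), so fit() needs steps_per_epoch.
history = model.fit(train_ds, epochs=1, steps_per_epoch=3, verbose=0)
```

The steps_per_epoch value here is arbitrary; with real data it would be chosen to reflect how many mixed batches count as one epoch.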