I have looked at other questions about this problem but could not find an exact answer, so I am asking from scratch:
The problem
I have multiple .npy files (X_train files), each containing an array of shape (n, 99, 2); only the first dimension differs between files, while the other two are the same. From the name of each .npy file I can also find the corresponding label file (y_train files).
Each such pair of files can easily be loaded into memory (as can several pairs), but not all of them at once.
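To reproduce the setup without the real data, here is a small sketch that writes synthetic file pairs with the layout described above. The function `make_fake_files` and the file names (`example0.npy`, ...) are hypothetical. The label shape `(n, 99, 2)` is only a guess, because the question does not state it:

```python
import os
import numpy as np

def make_fake_files(x_dir, y_dir, n_files=12, seed=0):
    """Write n_files pairs of .npy files, with the same name in x_dir and y_dir.

    X arrays have shape (n, 99, 2), and n varies per file.
    ASSUMPTION: labels also have shape (n, 99, 2); adjust to the real layout.
    """
    rng = np.random.default_rng(seed)
    os.makedirs(x_dir, exist_ok=True)
    os.makedirs(y_dir, exist_ok=True)
    names = []
    for k in range(n_files):
        n = int(rng.integers(50, 200))  # only the first dimension differs
        name = f"example{k}.npy"
        np.save(os.path.join(x_dir, name), rng.random((n, 99, 2), dtype=np.float32))
        np.save(os.path.join(y_dir, name), rng.random((n, 99, 2), dtype=np.float32))
        names.append(name)
    return names
```

For example, `filelist = make_fake_files('./data/x_train_m/', './data/y_train_m/')` gives a file list that both snippets below can consume.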
I built a generator that goes through the file list and concatenates a given number of files into each training batch:
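```python
import os
import numpy as np

def tf_data_generator(filelist, directory, batch_size=5):
    # directory = (x_dir, y_dir); an X file and its label file share a name
    x_dir, y_dir = directory
    n_chunks = -(-len(filelist) // batch_size)  # ceil division
    i = 0
    while True:
        file_chunk = filelist[i * batch_size:(i + 1) * batch_size]
        X_a = []
        Y_a = []
        for fname in file_chunk:
            X_a.append(np.load(os.path.join(x_dir, fname)))
            Y_a.append(np.load(os.path.join(y_dir, fname)))
        # Stack the files along the first (sample) axis
        yield np.concatenate(X_a), np.concatenate(Y_a)
        # Start over at the first file after the last chunk
        i = (i + 1) % n_chunks
```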
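Independently of the GPU error, consider what happens if `i` keeps growing and is never wrapped back to 0. Once `i * batch_size` reaches `len(filelist)`, the slice is empty and `np.concatenate` raises a `ValueError`. This shows up as soon as training asks for more batches than the file list provides. The small illustration below uses a hypothetical list of seven file names and `batch_size = 5`:

```python
import numpy as np

filelist = [f"example{k}.npy" for k in range(7)]  # hypothetical file names
batch_size = 5

# The slices the loop takes for i = 0, 1, 2 when i is never reset
chunks = [filelist[i * batch_size:(i + 1) * batch_size] for i in range(3)]
for i, chunk in enumerate(chunks):
    print(i, len(chunk))  # 5 files, then the 2 leftovers, then nothing

# An empty chunk means np.concatenate receives an empty list
try:
    np.concatenate([])
    concat_failed = False
except ValueError as err:
    concat_failed = True
    print("ValueError:", err)
```

Wrapping `i` back to 0 after the last chunk avoids this. With an infinite generator, Keras also needs `steps_per_epoch` (for example `math.ceil(len(filelist) / batch_size)`) to know where an epoch ends.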
This works fine on CPU, but it crashes when I try to train on a GPU with CUDA, failing with a `Failed to call ThenRnnForward with model config:` error (see: link).
So I am trying a different approach that uses the tf.data API to feed the data. However, I am stuck:
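```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

def parse_file(name):
    # Load one feature file and its matching label file (same file name)
    x = np.load('./data/x_train_m/' + name)
    y = np.load('./data/y_train_m/' + name)
    # Each of the n rows becomes one element of the dataset
    train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
    return train_dataset

train_dataset = parse_file('example1.npy')
train_dataset = train_dataset.shuffle(100).batch(64)

model = wtte_rnn()  # my model-building function, defined elsewhere
model.summary()
K.set_value(model.optimizer.lr, 0.01)
model.fit(train_dataset, epochs=10)
```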
This works well for a single file. However, I could not find a way to:
- mix samples from several files (up to a certain number, say five)
- iterate over the whole list of files
I have read up on flat_map and interleave, but I haven't been able to get any further, and every attempt to use them failed. How can I build a generator similar to the one in the first code block, but using the tf.data API?
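Here is a sketch of one way to get the first generator's behavior with `tf.data` alone. It rests on several assumptions: TensorFlow 2.4 or later (for `tf.data.AUTOTUNE`), that every X file and the y file with the same name have the same number of rows, and that casting both to float32 is acceptable. `make_dataset`, `load_pair`, `file_to_samples` and `files_per_mix` are hypothetical names. The idea is to turn the file list into a dataset of names, load each file with `tf.numpy_function`, and let `interleave` keep several files open at once so that their samples are mixed:

```python
import os
import numpy as np
import tensorflow as tf

def make_dataset(filelist, x_dir='./data/x_train_m/', y_dir='./data/y_train_m/',
                 files_per_mix=5, batch_size=64, shuffle_buffer=10000):
    # Read one pair up front only to learn the per-sample shapes of X and y
    x0 = np.load(os.path.join(x_dir, filelist[0]))
    y0 = np.load(os.path.join(y_dir, filelist[0]))
    x_shape, y_shape = (None,) + x0.shape[1:], (None,) + y0.shape[1:]

    def load_pair(name):
        # Runs as plain Python/NumPy; the file name arrives as bytes
        name = name.decode()
        # ASSUMPTION: float32 is fine for both features and labels
        x = np.load(os.path.join(x_dir, name)).astype(np.float32)
        y = np.load(os.path.join(y_dir, name)).astype(np.float32)
        return x, y

    def file_to_samples(name):
        x, y = tf.numpy_function(load_pair, [name], [tf.float32, tf.float32])
        # numpy_function drops static shapes; restore them so slicing works
        x.set_shape(x_shape)
        y.set_shape(y_shape)
        # One file becomes a dataset of its n individual samples
        return tf.data.Dataset.from_tensor_slices((x, y))

    files = tf.data.Dataset.from_tensor_slices(filelist)
    files = files.shuffle(len(filelist), reshuffle_each_iteration=True)
    # Keep `files_per_mix` files open and alternate samples between them;
    # when one file is exhausted, the next name in `files` takes its place
    ds = files.interleave(file_to_samples, cycle_length=files_per_mix, block_length=1)
    return ds.shuffle(shuffle_buffer).batch(batch_size).prefetch(tf.data.AUTOTUNE)
```

Usage would be `model.fit(make_dataset(filelist), epochs=10)`. Each epoch then passes once over every file, so `steps_per_epoch` is not needed. Replacing `interleave` with `flat_map` would read the files strictly one after another, and only the shuffle buffer would mix samples across file boundaries. `interleave` also accepts `num_parallel_calls` if file loading turns out to be the bottleneck.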