I'm trying to train a sequence-to-sequence model for machine translation using Keras on a Google Colab TPU. The dataset fits in memory, but I have to preprocess it before feeding it to the model. In particular, I need to convert the target words to one-hot vectors, and with this many examples the converted targets don't fit in memory, so I need to generate the data in batches.
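To make the memory problem concrete, here is a rough back-of-the-envelope estimate. The sample count, sequence length, and vocabulary size are hypothetical values chosen only for illustration (they are not my real numbers), and the 4 bytes per value assumes the float32 output that `to_categorical` produces by default:

```python
def one_hot_bytes(num_samples, target_len, vocab_size, bytes_per_value=4):
    """Bytes needed for a dense one-hot array of shape
    (num_samples, target_len, vocab_size)."""
    return num_samples * target_len * vocab_size * bytes_per_value


# Hypothetical sizes, for illustration only.
full = one_hot_bytes(num_samples=100_000, target_len=20, vocab_size=20_000)
batch = one_hot_bytes(num_samples=1024, target_len=20, vocab_size=20_000)

print(f"whole dataset: {full / 1e9:.1f} GB")
print(f"one batch:     {batch / 1e9:.2f} GB")
```

The dense array grows with the product of all three dimensions, which is why converting one batch at a time is so much cheaper than converting everything up front.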
I'm using this function as a batch generator:
def generate_batch_bert(X_ids, X_masks, y, batch_size=1024):
    ''' Generate a batch of data '''
    while True:
        for j in range(0, len(X_ids), batch_size):
            # batch of encoder and decoder data
            encoder_input_data_ids = X_ids[j:j+batch_size]
            encoder_input_data_masks = X_masks[j:j+batch_size]
            y_decoder = y[j:j+batch_size]

            # decoder target and input for teacher forcing
            decoder_input_data = y_decoder[:, :-1]
            decoder_target_seq = y_decoder[:, 1:]

            # batch of decoder target data
            decoder_target_data = to_categorical(decoder_target_seq, vocab_size_fr)

            # keep only batches with the right number of instances for training on TPU
            if encoder_input_data_ids.shape[0] == batch_size:
                yield ([encoder_input_data_ids, encoder_input_data_masks, decoder_input_data], decoder_target_data)
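For anyone checking the slicing logic, this is a minimal NumPy-only sketch of what the generator does to the targets of a single batch. The token ids and the vocabulary size are made up, and indexing an identity matrix is used here as a stand-in for `to_categorical`, so it only illustrates the shapes, not my actual pipeline:

```python
import numpy as np

vocab_size = 6  # hypothetical tiny vocabulary

# Two target sequences of token ids (made-up values).
y_batch = np.array([[1, 4, 5, 2],
                    [1, 3, 2, 0]])

# Teacher forcing: the decoder input is the sequence without its last token,
# and the decoder target is the same sequence shifted left by one position.
decoder_input = y_batch[:, :-1]
decoder_target = y_batch[:, 1:]

# One-hot encode the targets by indexing rows of an identity matrix
# (a NumPy stand-in for to_categorical, not the Keras call itself).
decoder_target_onehot = np.eye(vocab_size, dtype="float32")[decoder_target]

print(decoder_input.shape)          # (batch, seq_len - 1)
print(decoder_target_onehot.shape)  # (batch, seq_len - 1, vocab_size)
```

So for each batch the model receives integer ids as decoder input and a dense 3-D one-hot array as the target.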
The problem is that whenever I try to run the fit function as follows:
model.fit(x=generate_batch_bert(X_train_ids, X_train_masks, y_train, batch_size=batch_size),
          steps_per_epoch=train_samples//batch_size,
          epochs=epochs,
          callbacks=callbacks,
          validation_data=generate_batch_bert(X_val_ids, X_val_masks, y_val, batch_size=batch_size),
          validation_steps=val_samples//batch_size)
I get the following error:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_util.py:445 make_tensor_proto
    raise ValueError("None values not supported.")

ValueError: None values not supported.
I'm not sure what's wrong or how to solve this problem.
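As a sanity check on the step counts, the floor division used for `steps_per_epoch` and `validation_steps` should match the number of full batches the generator yields per pass, since the short final batch is skipped. A small sketch with a hypothetical sample count (`batch_counts` is just an illustrative helper, not part of my code):

```python
def batch_counts(num_samples, batch_size):
    """Return (number of full batches, samples left over in the short last batch)."""
    full_batches = num_samples // batch_size
    leftover = num_samples % batch_size
    return full_batches, leftover


# Hypothetical dataset size, for illustration only.
full_batches, leftover = batch_counts(num_samples=10_000, batch_size=1024)
print(full_batches, leftover)
```

With these numbers the leftover samples are never used for training, but the step count and the generator agree, so I don't think a mismatch there explains the error.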
EDIT
I tried loading a smaller amount of data into memory, so that converting all of the target words to one-hot vectors at once doesn't crash the kernel, and training then actually works. So there is evidently something wrong with how I generate the batches.