I am currently trying to train a convolutional neural network (CNN) using Keras and the Google Colab GPU.
I found this article, which discusses an option to reduce the time needed to train the model. Since the current training on the GPU is very slow, I tried to implement the method from the article. I have the following code:
sgd = optimizers.SGD(lr=0.02)
model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
def create_train_subsets():
    X_train = []
    y_train = []
    for i in range(80):
        cat = i + 1
        path = 'train_set/by_cat/{}'.format(cat)
        for img in os.listdir(path):
            actual_image = Image.open(os.path.join(path, img))
            X_train.append(actual_image)
            y_train.append(cat)
    return X_train, y_train
x_train, y_train = create_train_subsets()
# This address identifies the TPU we'll use when configuring TensorFlow.
TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']
tf.logging.set_verbosity(tf.logging.INFO)
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)))
history = tpu_model.fit(x_train, y_train,
                        epochs=20,
                        batch_size=128 * 8,
                        validation_split=0.2)
tpu_model.save_weights('./tpu_model.h5', overwrite=True)
# tpu_model.evaluate(x_test, y_test, batch_size=128 * 8)
However, this code raises the following error:
InvalidArgumentError: No OpKernel was registered to support Op 'ConfigureDistributedTPU' used by node ConfigureDistributedTPU (defined at /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:1748) with these attrs: [tpu_embedding_config="", is_global_init=false, embedding_config=""]
Registered devices: [CPU, XLA_CPU]
Registered kernels:
<no registered kernels>
[[ConfigureDistributedTPU]]
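Since the log lists only `CPU` and `XLA_CPU` as registered devices, one basic thing to rule out is that the runtime has no TPU attached. Below is a minimal sketch of such a check. It assumes that Colab sets the `COLAB_TPU_ADDR` environment variable only when a TPU runtime is selected, and the helper name `tpu_worker_address` is hypothetical (it does not come from the article). It only inspects the environment and does not touch TensorFlow, so it may not explain the error itself.

```python
import os


def tpu_worker_address(environ=os.environ):
    """Return the gRPC address of the Colab TPU, or None if none is exposed.

    Assumption: Colab sets COLAB_TPU_ADDR only on TPU runtimes.
    """
    addr = environ.get('COLAB_TPU_ADDR')
    if not addr:
        # Variable missing or empty: no TPU address is available.
        return None
    return 'grpc://' + addr


if __name__ == '__main__':
    address = tpu_worker_address()
    if address is None:
        print('No TPU address found; is the runtime type set to TPU?')
    else:
        print('TPU worker:', address)
```

If this prints a `grpc://` address, the TPU address is at least visible to the notebook, and the problem is more likely in how TensorFlow connects to it.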
I searched extensively online but could not find any explanation of what this error means. I also do not understand the process well enough to work out the meaning of the error myself.
Can anybody help me understand what is wrong, and perhaps suggest how to solve it?
Thank you in advance!