I have built a CNN model using the principle of "Model Subclassing" in Keras. Here is the class that represents my model:
class ConvNet(tf.keras.Model):
    """Small convolutional image classifier built with Keras model subclassing.

    The network is three Conv2D + MaxPooling2D stages (32, 64 and 128
    filters) followed by Flatten, a 512-unit ReLU dense layer and a
    ``classes``-unit softmax output layer.

    Note that the output layer applies softmax, so ``call`` returns class
    probabilities rather than raw logits; pick a loss that expects
    probabilities (e.g. one configured with ``from_logits=False``).

    Args:
        data_format: ``"channels_first"`` or ``"channels_last"``. Forwarded
            to every layer that takes a ``data_format`` argument so the
            model actually honours the requested layout.
        classes: Number of units in the final softmax layer.
    """

    def __init__(self, data_format, classes):
        super(ConvNet, self).__init__()

        # The original computed an unused ``axis`` from ``data_format`` and
        # never handed the layout to any layer, so the argument had no
        # effect. Forward it to each layout-sensitive layer instead.
        self.conv_layer1 = tf.keras.layers.Conv2D(
            filters=32, kernel_size=3, strides=(1, 1), padding="same",
            activation="relu", data_format=data_format)
        self.pool_layer1 = tf.keras.layers.MaxPooling2D(
            pool_size=(2, 2), strides=(2, 2), data_format=data_format)

        self.conv_layer2 = tf.keras.layers.Conv2D(
            filters=64, kernel_size=3, strides=(1, 1), padding="same",
            activation="relu", data_format=data_format)
        self.pool_layer2 = tf.keras.layers.MaxPooling2D(
            pool_size=(2, 2), strides=(2, 2), data_format=data_format)

        self.conv_layer3 = tf.keras.layers.Conv2D(
            filters=128, kernel_size=5, strides=(1, 1), padding="same",
            activation="relu", data_format=data_format)
        # Stride 1 with "same" padding: this pooling stage does not
        # downsample the feature map.
        self.pool_layer3 = tf.keras.layers.MaxPooling2D(
            pool_size=(2, 2), strides=(1, 1), padding="same",
            data_format=data_format)

        self.flatten = tf.keras.layers.Flatten(data_format=data_format)
        self.dense_layer1 = tf.keras.layers.Dense(units=512, activation="relu")
        self.dense_layer2 = tf.keras.layers.Dense(units=classes,
                                                  activation="softmax")

    def call(self, inputs, training=True):
        """Run the forward pass and return per-class softmax probabilities.

        Args:
            inputs: Batch of images laid out according to the
                ``data_format`` given at construction (presumably a rank-4
                tensor -- confirm against the caller).
            training: Accepted for Keras API compatibility; no layer in
                this model currently consumes it.

        Returns:
            The output of the final softmax dense layer.
        """
        output_tensor = self.conv_layer1(inputs)
        output_tensor = self.pool_layer1(output_tensor)

        output_tensor = self.conv_layer2(output_tensor)
        output_tensor = self.pool_layer2(output_tensor)

        output_tensor = self.conv_layer3(output_tensor)
        output_tensor = self.pool_layer3(output_tensor)

        output_tensor = self.flatten(output_tensor)
        output_tensor = self.dense_layer1(output_tensor)
        return self.dense_layer2(output_tensor)
I would like to know how to train it "eagerly", by which I mean without using the `compile` and `fit` methods.
I am not sure exactly how to construct the training loop. I understand that I must call `tf.GradientTape.gradient()` to calculate the gradients and then use the optimizer's `apply_gradients()` method to update my model parameters.
What I do not understand is how I can make predictions with my model to get the logits and then use them to calculate the loss. If someone could help me with how to construct the training loop, I would really appreciate it.