I'm building an Android app that can identify different types of fruits. For the dataset, I'm using the Fruits-360 dataset, with the white background replaced by images from the Open Images Dataset via this Python script, so that the dataset generalizes to fruits on different backgrounds (a rough sketch of the idea is at the end of this post). For the model, I'm using MobileNetV2 (the tf2-preview feature-vector module from TF Hub, version 4) with the following configuration:
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

# IMG_SIZE, num_classes, BATCH_SIZE and the data generators are defined earlier.
do_fine_tuning = True
URL = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4'
mobilenet = hub.KerasLayer(handle=URL, input_shape=(IMG_SIZE, IMG_SIZE, 3),
                           trainable=do_fine_tuning)
model = tf.keras.Sequential([
    mobilenet,
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(units=num_classes)
])
model.summary()
NUM_LAYERS = 10
if do_fine_tuning:
    mobilenet.trainable = True
    for layer in model.layers[-NUM_LAYERS:]:
        layer.trainable = True
else:
    mobilenet.trainable = False
if do_fine_tuning:
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.002, momentum=0.9),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'])
else:
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'])
epochs = 5
history = model.fit(tr_data_gen,
                    steps_per_epoch=int(np.ceil(total_tr_img / float(BATCH_SIZE))),
                    epochs=epochs,
                    validation_data=val_data_gen,
                    validation_steps=int(np.ceil(total_val_img / float(BATCH_SIZE))))
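After training, the model is saved to disk so the prediction script below can reload it. That save step isn't shown above, but it is essentially:

model.save(r'saved_model\better_modelv2.0')   # SavedModel directory later loaded by the prediction script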
Python script used for prediction:
import tensorflow as tf
import numpy as np
import tensorflow.keras.preprocessing.image as image
import matplotlib.pyplot as plt
model = tf.keras.models.load_model(r'saved_model\better_modelv2.0')   # SavedModel exported after training
class_names = ["Apple", "Apricot", "Banana", "Blueberry", "Cherry", "Fig", "Grapes", "Guava", "Kiwi", "Lemon",
"Lime", "Lychee", "Mango", "Melon", "Orange", "Papaya", "Pear", "Pineapple", "Raspberry",
"Strawberry", "Tomato", "Watermelon"]
img_path = r'test_images/banana-single.jpg'
img = tf.keras.utils.load_img(img_path, target_size=(224, 224))
input_img = image.img_to_array(img)
input_img = input_img / 255                           # scale pixel values to [0, 1]
result = model.predict(input_img[np.newaxis, ...])    # add a batch dimension
print(np.max(result))                                 # highest raw score (logits, since the model has no softmax)
predicted_class = np.argmax(result[0], axis=-1)
predicted_class_name = class_names[predicted_class]
print(predicted_class)
print(predicted_class_name)
plt.imshow(img)
plt.axis('off')
_ = plt.title("Prediction: " + predicted_class_name.title())
plt.show()
The model performed really well on the validation set, with an accuracy of 0.999, but it makes wrong predictions on simple test images. What could be causing this? Is there a minimum number of epochs I must train for when using pre-trained models? Or did I mess something up in the code or in the dataset-preparation phase?
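For reference, the background replacement mentioned at the top works roughly like this (a minimal sketch of the idea, not my exact script; the function name, paths, and white-pixel threshold are illustrative):

import os
import random
import numpy as np
from PIL import Image

def replace_white_background(fruit_path, background_dir, threshold=240):
    # Pixels close to white in the Fruits-360 image are treated as background
    # and replaced with a randomly chosen Open Images photo.
    fruit = np.array(Image.open(fruit_path).convert('RGB'))
    bg_file = random.choice(os.listdir(background_dir))
    background = Image.open(os.path.join(background_dir, bg_file)).convert('RGB')
    background = np.array(background.resize(fruit.shape[1::-1]))   # resize to (width, height) of the fruit image
    mask = np.all(fruit >= threshold, axis=-1)                     # True where the pixel is near-white
    composited = fruit.copy()
    composited[mask] = background[mask]
    return Image.fromarray(composited)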