I am learning TensorFlow/Keras for image classification and I feel like I'm missing a critical part of the theory.
The task I am currently working on uses a pretrained model (ResNet50 in this case) to classify a small data set with limited training time.
The data set is 1,600 150x150 color photos of fruit, falling into 12 classes. I am using a generator for the images:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# one generator handles rescaling and augmentation, and reserves 25% of the
# images for validation via validation_split
datagen = ImageDataGenerator(
    validation_split=0.25,
    rescale=1/255,
    horizontal_flip=True,
    vertical_flip=True,
    width_shift_range=0.2,
    height_shift_range=0.2,
    rotation_range=90)

# training subset
train_datagen_flow = datagen.flow_from_directory(
    '/datasets/fruits_small/',
    target_size=(150, 150),
    batch_size=32,
    class_mode='sparse',
    subset='training',
    seed=12345)

# validation subset
val_datagen_flow = datagen.flow_from_directory(
    '/datasets/fruits_small/',
    target_size=(150, 150),
    batch_size=32,
    class_mode='sparse',
    subset='validation',
    seed=12345)

features, target = next(train_datagen_flow)
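(Pulling one batch like that is just a sanity check; with these settings the shapes come out as expected:)

# one training batch: 32 images of 150x150x3 rescaled to [0, 1], plus 32
# integer class labels because class_mode='sparse'
print(features.shape, target.shape)   # (32, 150, 150, 3) (32,)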
Here are the layers I am using:
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalMaxPooling2D, BatchNormalization
from tensorflow.keras.optimizers import Adam

# frozen ResNet50 backbone: ImageNet weights, classification head removed
backbone = ResNet50(input_shape=(150, 150, 3), weights='imagenet', include_top=False)
backbone.trainable = False

model = Sequential()
optimizer = Adam(lr=0.001)
model.add(backbone)
model.add(GlobalMaxPooling2D())
model.add(Dense(2048, activation='relu'))
model.add(BatchNormalization())
model.add(Dense(512, activation='relu'))
model.add(BatchNormalization())
model.add(Dense(12, activation='softmax'))
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['acc'])
Now, this is my first attempt at using GlobalMaxPooling2D and ResNet50, and I am experiencing massive overfitting, I presume because of the small data set.
I've done some reading on the subject and have tried a few normalization efforts, with limited success.
In conversation with my tutor, he suggested that I think more critically about the output of the ResNet model when selecting the parameters for my dense layers.
That comment made me realize that I have basically been picking the widths of the dense layers arbitrarily. It sounds like I should be reasoning from the output of the previous layer when building a new one, but I'm not sure how, and I feel like I am missing something critical.
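For reference, this is what the frozen backbone actually hands to the head (just printing the shape, nothing clever):

# ResNet50 with include_top=False on 150x150 inputs ends in a 5x5x2048 feature
# map, so GlobalMaxPooling2D turns each image into a single 2048-dim vector
print(backbone.output_shape)   # (None, 5, 5, 2048)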
This is what my current model summary looks like:
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
resnet50 (Model)             (None, 5, 5, 2048)        23587712
_________________________________________________________________
global_max_pooling2d_3 (Glob (None, 2048)              0
_________________________________________________________________
dense_7 (Dense)              (None, 2048)              4196352
_________________________________________________________________
batch_normalization_2 (Batch (None, 2048)              8192
_________________________________________________________________
dense_8 (Dense)              (None, 512)               1049088
_________________________________________________________________
batch_normalization_3 (Batch (None, 512)               2048
_________________________________________________________________
dense_9 (Dense)              (None, 12)                6156
=================================================================
Total params: 28,849,548
Trainable params: 5,256,716
Non-trainable params: 23,592,832
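(Sanity-checking those counts: dense_7 is 2048 inputs * 2048 units + 2048 biases = 4,196,352 parameters, dense_8 is 2048 * 512 + 512 = 1,049,088, and dense_9 is 512 * 12 + 12 = 6,156, so the trainable parameters are dominated by that first dense layer.)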
Here is what my current training output looks like:
Epoch 1/3
40/40 [==============================] - 363s 9s/step - loss: 0.5553 - acc: 0.8373 - val_loss: 3.8422 - val_acc: 0.1295
Epoch 2/3
40/40 [==============================] - 354s 9s/step - loss: 0.1621 - acc: 0.9423 - val_loss: 6.3961 - val_acc: 0.1295
Epoch 3/3
40/40 [==============================] - 357s 9s/step - loss: 0.1028 - acc: 0.9716 - val_loss: 4.8895 - val_acc: 0.1295
So I've read about freezing the ResNet layers during training to help with overfitting, and about regularization (which is what I am attempting with the batch normalization? Though that seems to be considered questionable by a lot of people). I've also tried adding dropout after the first and second dense layers (see the sketch below), as well as effectively increasing the data set size with augmentation (I've got rotations and such).
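The dropout attempt looked roughly like this, reusing the backbone and imports from above (the 0.5 rate shown here is only illustrative):

from tensorflow.keras.layers import Dropout

# same frozen backbone and head as before, with dropout after the first and
# second dense layers to fight the overfitting
model = Sequential()
model.add(backbone)
model.add(GlobalMaxPooling2D())
model.add(Dense(2048, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(512, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(12, activation='softmax'))
model.compile(optimizer=Adam(lr=0.001),
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])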
Any input would be appreciated!