I'm working on an image classification model for a multiclass problem. The model trains and runs, but when I predict on test images it only ever outputs 1 of the 4 image types (always the same class, no matter how I change the model). I have few images per class, but I use Keras' ImageDataGenerator to augment the data. The model should be able to recognise the images even with some noise added to the picture.

My challenges can be boiled down to this:

  1. Small amount of data. I have < 100 images per class.
  2. My model is not supposed to find specific figures, but more overall patterns in the picture (areas with a certain colour and such).
  3. Many of the pictures contain a lot of white and text. Do I need any image preprocessing to help the model?

My model looks like this:

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense

s1, s2 = 250, 250  # image height and width

model = Sequential()
# Note: data_format="channels_first" expects input_shape=(3, s1, s2);
# (s1, s2, 3) is the channels_last layout.
model.add(Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=(s1, s2, 3), data_format="channels_first"))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(50, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(4, activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='Adam',
              metrics=['accuracy'])

The input image size is 250x250 and the batch size is 16.

See the accuracy and loss curves:

[Figure: accuracy curve]

[Figure: loss curve]

Do you guys have any advice?

Thanks in advance!

HrMussa

1 Answer

This is classic overfitting. You need to severely constrain your model and/or use transfer learning to combat it. To constrain the model, you can increase dropout and add L2 regularization. In my experience, L2 regularization makes it much harder for a neural network to simply memorize the training data. For transfer learning, you can take a pretrained ResNet and train only the last 2-3 layers.

That said, nothing beats having more data points.
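
To make the transfer learning suggestion concrete, here is a minimal sketch, not a definitive implementation. It assumes `tensorflow.keras`, a ResNet50 base pretrained on ImageNet, and it reads "the last 2-3 layers" as a small new head trained on top of a fully frozen base. The function name `build_transfer_model`, the head size of 64 units and the L2 strength of 1e-3 are hypothetical choices for illustration, not values taken from the question.

```python
from tensorflow import keras
from tensorflow.keras import layers, regularizers
from tensorflow.keras.applications import ResNet50


def build_transfer_model(num_classes=4, input_shape=(250, 250, 3),
                         weights="imagenet", l2_strength=1e-3,
                         dropout_rate=0.5):
    """Frozen pretrained ResNet50 plus a small regularized head.

    All defaults are illustrative assumptions; tune them for your data.
    """
    # Convolutional base without the ImageNet classifier; global average
    # pooling turns the final feature map into one vector per image.
    base = ResNet50(include_top=False, weights=weights,
                    input_shape=input_shape, pooling="avg")
    base.trainable = False  # only the new head below gets trained

    inputs = keras.Input(shape=input_shape)
    # training=False keeps the BatchNorm layers of the base in inference mode.
    x = base(inputs, training=False)
    x = layers.Dense(64, activation="relu",
                     kernel_regularizer=regularizers.l2(l2_strength))(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation="softmax",
                           kernel_regularizer=regularizers.l2(l2_strength))(x)

    model = keras.Model(inputs, outputs)
    model.compile(loss="categorical_crossentropy", optimizer="adam",
                  metrics=["accuracy"])
    return model
```

Calling `build_transfer_model()` downloads the ImageNet weights on first use. The input images should go through `tensorflow.keras.applications.resnet50.preprocess_input` rather than a plain 1/255 rescale, since the pretrained weights expect that preprocessing. With fewer than 100 images per class, keeping the base frozen leaves only the two Dense layers to fit, which is the point of the exercise.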

Thomas Pinetz
  • Hi there! Thanks a lot for your answer. I tried l2 regularization, but without any improvement. Transfer learning was, on the other hand, the way to go! I used the VGG16 application and my model improved dramatically. Thanks a lot! :-) – HrMussa Aug 10 '18 at 13:48