
I'm trying to learn image classification and have got some models working on datasets I generated myself. I've used ResNet50, EfficientNetB3 and EfficientNetB6 for transfer learning, and I have a working model at the end, but I would like to improve its accuracy by fine-tuning it.

This is where I'm running into memory problems. Looking at other people with similar problems, I don't think I should be hitting them: I'm using Colab Pro, have even briefly paid for a 40 GB GPU, and it still reports OOM (out of memory) with ResNet50 or EfficientNetB3.

I built my model like this:

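import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Set image dataset location, image size and batch size
data_root = "/content/FOD/DATA/Train/"
TRAINING_DATA_DIR = str(data_root)
image_size = (400, 400)
batch_size = 16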

#Split the dataset into train and validation sets
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    TRAINING_DATA_DIR,
    validation_split=0.2,
    subset="training",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    TRAINING_DATA_DIR,
    validation_split=0.2,
    subset="validation",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)

base_model = keras.applications.ResNet50(
    weights = 'imagenet',  # Load weights pre-trained on ImageNet.
    input_shape = (400, 400, 3),
    pooling = 'max',
    include_top = False)  # Do not include the ImageNet classifier at the top.

base_model.trainable = False

inputs = keras.Input(shape=(400, 400, 3))
# Run base_model in inference mode by passing `training=False`, so its
# BatchNormalization layers keep using their stored statistics. This matters
# later, when the base model is unfrozen for fine-tuning.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
# x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dropout(0.2)(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)

model.compile(optimizer=keras.optimizers.Adam(),
              loss = keras.losses.binary_crossentropy,
              metrics=[keras.metrics.BinaryAccuracy()])

keras_callbacks   = [
      EarlyStopping(monitor='val_loss', patience=2, mode='min', min_delta=0.0001),
      ModelCheckpoint(filepath="/content/gdrive/MyDrive/MODELS/ResNet50-4Items-CMH/", save_weights_only=False, monitor='val_loss', save_best_only=True, mode='min')
]

hist = model.fit(train_ds, epochs=30, callbacks=keras_callbacks, validation_data=val_ds)

This all seems to run OK and gives me a model at the end. My code to unfreeze layers and fine-tune is:

###############################Fine-Tuning##################################################
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Load the model you wish to fine-tune
tf.keras.backend.clear_session()
model = tf.keras.models.load_model('*link to drive location of my model*')
# View the model layers
model.summary()

# Save the model from the epoch with the highest validation accuracy
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="*link to save location of my model*",
    save_weights_only=False,
    monitor='val_binary_accuracy',
    mode='max',
    save_best_only=True)
# Earlier attempt: train for only 2 epochs with a smaller batch size to decrease memory usage
#model.fit(train_ds, epochs=2, batch_size = 8, callbacks = [model_checkpoint_callback], validation_data=val_ds)

def unfreeze_model(model):
    # Unfreeze the last 10 layers while leaving BatchNorm layers frozen
    for layer in model.layers[-10:]:
        if not isinstance(layer, layers.BatchNormalization):
            layer.trainable = True

unfreeze_model(model)

# Use a low learning rate so that only small changes are made to an already good model
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
model.compile(optimizer=optimizer, loss=keras.losses.binary_crossentropy, metrics=[keras.metrics.BinaryAccuracy()])
model.summary()

hist = model.fit(train_ds, epochs=10, batch_size=2, callbacks = [model_checkpoint_callback], validation_data=val_ds)
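A likely reason the loop above unfreezes everything: in the loaded model, ResNet50 is a single top-level layer, so touching it flips the whole backbone at once. Below is a minimal sketch, assuming the backbone layer is named `resnet50` as in the summary output, that instead unfreezes only the last `num_layers` layers inside the backbone and keeps its BatchNormalization layers frozen. The helper names `unfreeze_top_of_base` and `count_trainable_params` and the default of 20 layers are hypothetical choices for illustration, not part of the original code.

```python
import numpy as np
from tensorflow import keras


def unfreeze_top_of_base(model, base_name="resnet50", num_layers=20):
    """Unfreeze the last `num_layers` layers *inside* a nested backbone.

    Hypothetical helper: the outer model wraps the whole backbone as one
    layer, so the loop runs over the backbone's own sublayers instead.
    BatchNormalization layers stay frozen.
    """
    base = model.get_layer(base_name)
    # The backbone itself must be trainable: while it is False, none of its
    # sublayers contribute trainable weights. This call also flips every
    # sublayer to trainable, so each one is explicitly reset below.
    base.trainable = True
    first_unfrozen = len(base.layers) - num_layers
    for i, layer in enumerate(base.layers):
        is_bn = isinstance(layer, keras.layers.BatchNormalization)
        layer.trainable = i >= first_unfrozen and not is_bn
    return model


def count_trainable_params(model):
    """Total number of scalar trainable parameters in `model`."""
    return sum(int(np.prod(tuple(w.shape))) for w in model.trainable_weights)
```

After calling `unfreeze_top_of_base(model)` on the loaded model, the model still has to be recompiled (for example with the low learning rate already used above) before `fit`, and `count_trainable_params(model)` should then report a number between the two extremes seen in the summaries.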

I've tried reducing the batch size, among other things, and have also moved to what are described online as smaller models, but I always run into OOM when fine-tuning.
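One thing worth checking: the Keras `Model.fit` documentation says not to pass `batch_size` when the input is a dataset, because a `tf.data.Dataset` already yields batches. So `batch_size=2` in the fine-tuning call most likely has no effect, and each step still processes the 16 images per batch set by `image_dataset_from_directory`. A minimal sketch of actually shrinking the batches, assuming `train_ds` and `val_ds` are the datasets created earlier; the helper name `rebatch` is hypothetical:

```python
import tensorflow as tf


def rebatch(ds, new_batch_size):
    """Split an already-batched dataset into smaller batches."""
    # unbatch() yields single (image, label) pairs; batch() regroups them.
    return ds.unbatch().batch(new_batch_size).prefetch(tf.data.AUTOTUNE)
```

Usage would be `train_small = rebatch(train_ds, 4)` and `val_small = rebatch(val_ds, 4)`, then passing these to `model.fit` without a `batch_size` argument. A simpler alternative is to pass a smaller `batch_size` directly to `image_dataset_from_directory` when creating the datasets.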

The output of the two model.summary() calls in the fine-tuning code (before and after unfreezing) is:

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 400, 400, 3)]     0         
                                                                 
 resnet50 (Functional)       (None, 2048)              23587712  
                                                                 
 flatten (Flatten)           (None, 2048)              0         
                                                                 
 dropout (Dropout)           (None, 2048)              0         
                                                                 
 dense (Dense)               (None, 1)                 2049      
                                                                 
=================================================================
Total params: 23,589,761
Trainable params: 2,049
Non-trainable params: 23,587,712
_________________________________________________________________
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 400, 400, 3)]     0         
                                                                 
 resnet50 (Functional)       (None, 2048)              23587712  
                                                                 
 flatten (Flatten)           (None, 2048)              0         
                                                                 
 dropout (Dropout)           (None, 2048)              0         
                                                                 
 dense (Dense)               (None, 1)                 2049      
                                                                 
=================================================================
Total params: 23,589,761
Trainable params: 23,536,641
Non-trainable params: 53,120

I have tried to reduce the number of trainable parameters, but when following tutorials online it always ends up either at about 23 million or at about 2 thousand, with nothing in between.
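The two summaries hint at why there is no middle ground: the outer model has only five top-level layers (`input_2`, `resnet50`, `flatten`, `dropout`, `dense`), so `model.layers[-10:]` selects all of them and the entire ResNet50 backbone is toggled as one layer. The `isinstance(layer, layers.BatchNormalization)` check never fires either, because the BatchNormalization layers live inside `resnet50`, not at the top level; the 53,120 parameters that stay non-trainable are most likely just the BatchNormalization moving means and variances, which are never trained by gradient descent. A small, hypothetical helper like the one below (`trainable_param_report` is an illustrative name, not a Keras API) lists the trainable parameter count per top-level layer, which makes this visible:

```python
import numpy as np
from tensorflow import keras


def trainable_param_report(model):
    """List (layer name, trainable parameter count) for each top-level layer."""
    report = []
    for layer in model.layers:
        # Nested models such as `resnet50` report all their sublayers' weights.
        count = sum(int(np.prod(tuple(w.shape))) for w in layer.trainable_weights)
        report.append((layer.name, count))
    return report
```

For example, `for name, count in trainable_param_report(model): print(f"{name:12s} {count:>12,}")` prints one line per top-level layer; on the model above, the `resnet50` row should account for almost all of the trainable parameters after the loop runs.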

desertnaut
Chris
