
I'm trying to replace the swish activation with the relu activation in the pretrained TF model EfficientNetB0. EfficientNetB0 uses the swish activation in its Conv2D and Activation layers. This SO post is very similar to what I'm looking for. I also found an answer that works for models without skip connections. Below is the code:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import ReLU

def replace_swish_with_relu(model):
    '''
    Modify passed model by replacing swish activation with relu
    '''
    for layer in tuple(model.layers):
        layer_type = type(layer).__name__
        if hasattr(layer, 'activation') and layer.activation.__name__ == 'swish':
            print(layer_type, layer.activation.__name__)
            if layer_type == "Conv2D":
                # conv layer with swish activation.
                # Do something
                layer.activation = ReLU() # This didn't work
            else:
                # activation layer
                # Do something
                layer = tf.keras.layers.Activation('relu', name=layer.name + "_relu") # This didn't work
    return model

# load pretrained efficientNet
model = tf.keras.applications.EfficientNetB0(
    include_top=True, weights='imagenet', input_tensor=None,
    input_shape=(224, 224, 3), pooling=None, classes=1000,
    classifier_activation='softmax')

# convert swish activation to relu activation
model = replace_swish_with_relu(model)
model.save("efficientNet-relu")

How should replace_swish_with_relu be modified so that it replaces the swish activations with relu in the passed model?

Thank you for any pointers/help.

mrtpk

2 Answers


layer.activation holds a reference to the tf.keras.activations.swish function. We can reassign it to point to tf.keras.activations.relu instead. Below is the modified replace_swish_with_relu:

def replace_swish_with_relu(model):
    '''
    Modify passed model by replacing swish activation with relu
    '''
    for layer in tuple(model.layers):
        # Both Conv2D layers and standalone Activation layers expose an
        # `activation` attribute, so one reassignment handles both cases.
        if hasattr(layer, 'activation') and layer.activation.__name__ == 'swish':
            print(type(layer).__name__, layer.activation.__name__)
            layer.activation = tf.keras.activations.relu
    return model
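
A quick way to sanity-check the replacement (a minimal sketch; the layer scan simply mirrors the loop above, and the all-zeros dummy input is only there to confirm the modified model still runs end to end):

import numpy as np
import tensorflow as tf

model = tf.keras.applications.EfficientNetB0(
    include_top=True, weights='imagenet', input_shape=(224, 224, 3))
model = replace_swish_with_relu(model)

# After the replacement, no layer should still report a swish activation.
remaining = [layer.name for layer in model.layers
             if hasattr(layer, 'activation')
             and getattr(layer.activation, '__name__', None) == 'swish']
print("layers still using swish:", remaining)  # expected: []

# The modified model should still produce ImageNet-shaped predictions.
dummy = np.zeros((1, 224, 224, 3), dtype=np.float32)
print(model.predict(dummy).shape)  # (1, 1000)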

Note: If you modify the activation function, you need to retrain the model to work with the new activation. Related.
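
For example, a short fine-tuning pass is usually needed because the ImageNet weights were learned with swish. A minimal sketch, assuming a hypothetical tf.data pipeline named train_ds that yields batches of 224x224 images with one-hot labels:

# train_ds is a hypothetical tf.data.Dataset of (image, one_hot_label) batches.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy'])
model.fit(train_ds, epochs=3)    # a few epochs so the weights adapt to relu
model.save("efficientNet-relu")  # save once the fine-tuned model converges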

mrtpk

Try this:

def replace_swish_with_relu(model):
    '''
    Modify passed model by replacing swish activation with relu
    '''
    for i, layer in enumerate(tuple(model.layers)):
        layer_type = type(layer).__name__
        if hasattr(layer, 'activation') and layer.activation.__name__ == 'swish':
            print(layer_type, layer.activation.__name__)
            if layer_type == "Conv2D":
                # conv layer with swish activation: replace the whole layer with a ReLU layer
                model.layers[i] = ReLU()
            else:
                # standalone activation layer: replace it with a relu Activation layer
                model.layers[i] = tf.keras.layers.Activation('relu', name=layer.name + "_relu")
    return model

DachuanZhao