I'm trying to create an activation function in Keras that takes a parameter beta, like so:

from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
from keras.layers import Activation

class Swish(Activation):

    def __init__(self, activation, beta, **kwargs):
        super(Swish, self).__init__(activation, **kwargs)
        self.__name__ = 'swish'
        self.beta = beta


def swish(x):
    return (K.sigmoid(beta*x) * x)

get_custom_objects().update({'swish': Swish(swish, beta=1.)})

It runs fine without the beta parameter, but how can I include the parameter in the activation definition? I also want this value to be saved when I call model.to_json(), as it is for the ELU activation.


Update: I wrote the following code based on @today's answer:

from keras.layers import Layer
from keras import backend as K

class Swish(Layer):
    def __init__(self, beta, **kwargs):
        super(Swish, self).__init__(**kwargs)
        self.beta = K.cast_to_floatx(beta)
        self.__name__ = 'swish'

    def call(self, inputs):
        return K.sigmoid(self.beta * inputs) * inputs

    def get_config(self):
        config = {'beta': float(self.beta)}
        base_config = super(Swish, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

from keras.utils.generic_utils import get_custom_objects
get_custom_objects().update({'swish': Swish(beta=1.)})
gnn = keras.models.load_model("Model.h5")
arch = gnn.to_json()
with open(directory + 'architecture.json', 'w') as arch_file:
    arch_file.write(arch)

However, it does not currently save the beta value in the .json file. How can I make it save the value?

user7867665

1 Answer

Since you want the activation function's parameters to be saved when the model is serialized, I think it is better to define the activation function as a layer, like the advanced activations already defined in Keras. You can do it like this:

from keras.layers import Layer
from keras import backend as K

class Swish(Layer):
    def __init__(self, beta, **kwargs):
        super(Swish, self).__init__(**kwargs)
        # Store beta as a fixed (non-trainable) value in Keras' float type
        self.beta = K.cast_to_floatx(beta)

    def call(self, inputs):
        # swish(x) = sigmoid(beta * x) * x
        return K.sigmoid(self.beta * inputs) * inputs

    def get_config(self):
        # Add beta to the base layer config so it is serialized with the model
        config = {'beta': float(self.beta)}
        base_config = super(Swish, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        # Element-wise activation: the output shape equals the input shape
        return input_shape

Then you can use it the same way you use a Keras layer:

# ...
model.add(Swish(beta=0.3))

Since the get_config() method is implemented in the layer definition, the beta parameter is saved when you use methods such as to_json() or save().
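To check this on your side, here is a minimal sketch that builds a tiny model, serializes it with to_json(), and looks for beta in the result. Several things in it are assumptions made for illustration only: the `Dense(8)` layer, the input shape `(4,)`, and `beta=0.3`. It also uses `keras.activations.sigmoid` and a plain Python float instead of the backend helpers from the answer, only to avoid depending on a specific backend API. The exact JSON layout can differ between Keras versions, so treat the dictionary path as a guess that matches recent releases.

```python
import json

import keras
from keras import activations, layers


class Swish(layers.Layer):
    """Swish activation with a fixed (non-trainable) beta."""

    def __init__(self, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        self.beta = float(beta)

    def call(self, inputs):
        return activations.sigmoid(self.beta * inputs) * inputs

    def get_config(self):
        # Extend the base config so beta is written out with the model.
        config = super().get_config()
        config["beta"] = self.beta
        return config

    def compute_output_shape(self, input_shape):
        return input_shape


# Hypothetical toy model, only there to have something to serialize.
model = keras.Sequential([
    keras.Input(shape=(4,)),
    layers.Dense(8),
    Swish(beta=0.3),
])

arch = model.to_json()

# Collect the config of every Swish layer found in the JSON architecture.
swish_configs = [
    layer["config"]
    for layer in json.loads(arch)["config"]["layers"]
    if layer["class_name"] == "Swish"
]
print(swish_configs)

# When rebuilding, pass the class itself (not an instance) as a custom object.
restored = keras.models.model_from_json(arch, custom_objects={"Swish": Swish})
print(restored.layers[-1].beta)
```

Note that `custom_objects` maps the class name to the `Swish` class rather than to an instance such as `Swish(beta=1.)`; Keras then calls the class with the saved config, which is how the stored beta comes back.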

today
  • This is what I do but the parameter value is not saved in the json file – user7867665 Oct 29 '18 at 23:37
  • @user7867665 Are you sure you have implemented the `get_config()` method and included the `beta` parameter in it? – today Oct 29 '18 at 23:38
  • I did it differently, I am testing your implementation now – user7867665 Oct 29 '18 at 23:41
  • it doesn't save the beta value in the .json file, I used exactly your code – user7867665 Oct 29 '18 at 23:50
  • @user7867665 Strange indeed! It works for me. Could you please put your code in a [github gist](https://gist.github.com/) (or any other online note sharing website which gives shareable link) and give the link to me? – today Oct 29 '18 at 23:53
  • How do I give you the link, is there direct messages? – user7867665 Oct 30 '18 at 00:15
  • @user7867665 Just paste it in the comment here. – today Oct 30 '18 at 00:31
  • Let us [continue this discussion in chat](https://chat.stackoverflow.com/rooms/182785/discussion-between-user7867665-and-today). – user7867665 Oct 30 '18 at 10:53
  • @user7867665 You are using `generator.to_json()` but your model is stored in `gnn`?! Plus, it is not a good idea to edit your question by removing the original question. Add any more info at the end instead. Hence, I rolled back your edit and modified as such. – today Oct 30 '18 at 15:05
  • It's `gnn.to_json()`, I pasted the wrong thing. Fixed. Problem is still the same. Thanks for the edit – user7867665 Oct 30 '18 at 17:03
  • @user7867665 Run [this code](https://www.pastiebin.com/5bd89a20d1489) on your machine and confirm that you see the `beta` parameter in the printed config. Further, look at how the custom object is passed to `load_model` function. – today Oct 30 '18 at 17:55
  • It works! So why doesn't it work for my model? Maybe because it was not trained using this definition of swish? – user7867665 Oct 31 '18 at 11:07
  • So it didn't work on my model because it was not trained using this definition of swish. Thanks a lot! – user7867665 Oct 31 '18 at 11:14