
I want to initialize the convolution layer with a specific kernel that is not predefined in Keras. For instance, I define the following function to initialize the kernel:

def init_f(shape):
    ker = np.zeros((shape, shape))
    ker[int(np.floor(shape/2)), int(np.floor(shape/2))] = 1
    return ker

The convolution layer is defined as follows:

model.add(Conv2D(filters=32, kernel_size=(3,3),
                      kernel_initializer=init_f(3)))

I get the error:

Could not interpret initializer identifier

I followed a similar discussion at https://groups.google.com/forum/#!topic/keras-users/J46pplO64-8 but could not adapt it to my code. Could you please help me define an arbitrary kernel in Keras?

nbro


There are a few things to fix. Let's start with the kernel initializer. From the Keras documentation:

If passing a custom callable, then it must take the argument shape (shape of the variable to initialize) and dtype (dtype of generated values)

So the signature should become:

def init_f(shape, dtype=None):

The function will work without dtype, but it's good practice to keep it: that way you can pass the dtype on to calls inside your function, e.g.:

np.zeros(shape, dtype=dtype)

This also fixes your second issue: the shape argument is already a tuple, so you can pass it straight to np.zeros instead of building another tuple from it.
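A quick NumPy-only check of why the tuple matters (no Keras needed). The shape (3, 3) below is just an illustrative stand-in for whatever shape tuple Keras actually passes.

```python
import numpy as np

# Keras hands the initializer a shape *tuple*; (3, 3) is an illustrative value.
shape = (3, 3)

# Passing the tuple straight through works.
ker = np.zeros(shape, dtype="float32")
print(ker.shape)  # (3, 3)

# Wrapping it again, as np.zeros((shape, shape)) does, produces a nested tuple
# that NumPy cannot read as a list of dimension sizes.
try:
    np.zeros((shape, shape))
    print("unexpectedly succeeded")
except (TypeError, ValueError) as err:
    print(type(err).__name__, err)
```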

I'm guessing you're trying to initialize the kernel with a 1 in the middle, so you could also generalize your function to work with whatever shape it receives:

ker[tuple(s // 2 for s in ker.shape)] = 1
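It helps to look at the index this line actually produces. In current Keras versions a Conv2D kernel has shape (kernel_height, kernel_width, input_channels, filters), so with one input channel and 32 filters the shape is (3, 3, 1, 32). The helper center_index below is hypothetical, written only for this illustration.

```python
import numpy as np

def center_index(shape):
    # Hypothetical helper: the center position along every axis
    # (floor of half the size, which is what s // 2 gives for positive ints).
    return tuple(s // 2 for s in shape)

print(center_index((3, 3)))         # (1, 1)
print(center_index((3, 3, 1, 32)))  # (1, 1, 0, 16)

# For a 4-D Conv2D kernel that single index touches exactly one weight:
# input channel 0 of filter 16, not the center of every filter.
ker = np.zeros((3, 3, 1, 32))
ker[center_index(ker.shape)] = 1
print(int(ker.sum()))  # 1
```

So for a 2-D shape this is the "1 in the middle" kernel, but for the 4-D tensor a Conv2D layer actually requests, only one of the 32 filters gets a nonzero weight.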

Putting it all together:

import numpy as np

def init_f(shape, dtype=None):
    ker = np.zeros(shape, dtype=dtype)
    # put a single 1 at the center index of every dimension
    ker[tuple(s // 2 for s in ker.shape)] = 1
    return ker
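If the intent is a 1 at the spatial center of *every* filter (rather than a single 1 in the whole 4-D tensor), a variant could index only the two spatial axes. This is a hypothetical sketch, assuming the Keras Conv2D kernel layout (kernel_height, kernel_width, input_channels, filters); the name init_center_per_channel is made up for this example.

```python
import numpy as np

def init_center_per_channel(shape, dtype=None):
    """Hypothetical variant: a 1 at the spatial center of every
    (input channel, filter) slice of a Conv2D kernel.

    Assumes the layout (kernel_h, kernel_w, in_channels, filters).
    """
    ker = np.zeros(shape, dtype=dtype)
    kh, kw = shape[0], shape[1]
    # Index only the spatial axes; the ':' slices assign the 1
    # across all input channels and all filters.
    ker[kh // 2, kw // 2, :, :] = 1
    return ker

ker = init_center_per_channel((3, 3, 2, 4), dtype="float32")
print(ker[:, :, 0, 0])  # 3x3 slice with a 1 in the middle
print(ker.sum())        # 8.0 (2 input channels x 4 filters)
```

It would be passed the same way, as kernel_initializer=init_center_per_channel. Note that with more than one input channel each filter then sums all channels at the center pixel; a per-channel identity would instead need filters equal to input_channels and a 1 only where the input and output channel indices match.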

One last problem: you need to pass the function itself to the layer, not the result of calling it:

model.add(Conv2D(filters=32, kernel_size=(3,3),
                  kernel_initializer=init_f))

Keras will then call init_f itself, passing in the kernel's shape and dtype when the layer builds its weights.
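Roughly speaking, Keras resolves kernel_initializer from a string name, a config dict, or a callable; the array returned by init_f(3) is none of these, which is what triggers "Could not interpret initializer identifier". The NumPy-only sketch below shows the difference and mimics the later call with illustrative shape and dtype values (not taken from a real Keras run).

```python
import numpy as np

def init_f(shape, dtype=None):
    ker = np.zeros(shape, dtype=dtype)
    # single 1 at the center index of every dimension
    ker[tuple(s // 2 for s in ker.shape)] = 1
    return ker

# What the question passed: the *result* of a call, i.e. an ndarray.
print(callable(init_f((3, 3))))  # False
# What the answer passes: the function object itself.
print(callable(init_f))          # True

# Roughly what happens later, once the layer knows its kernel shape.
kernel = init_f((3, 3, 1, 32), dtype="float32")
print(kernel.shape)  # (3, 3, 1, 32)
```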

ggallo