15

I am trying to write a Lambda layer in Keras that calls a function, connection(x, k), which runs a loop for i in range(0, k), where k is passed in as an argument. To call it through the Functional API, I tried:

k = 5
y = Lambda(connection)(x)

Also,

y = Lambda(connection)(x,k)

But neither of those approaches worked. How can I feed in the value of k without assigning it as a global parameter?

Prabaha
  • Is "k" a constant? Or is it calculated somewhere in the model? Is it an input to the model, as part of the input data? – Daniel Möller Jul 05 '17 at 20:42
  • `k` updates through the model, so its value differs between calls to the `Lambda` layer. But I found the solution [here](https://github.com/fchollet/keras/pull/1911), in a Keras GitHub pull request. Using `y = Lambda(connection, arguments={'k':k})(x)` worked! – Prabaha Jul 05 '17 at 22:39

3 Answers

21

Just use

y = Lambda(connection)((x,k)) 

and then read the two values as var[0] and var[1] inside connection, where var is the function's single argument.
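A minimal sketch of this approach, assuming k is itself a tensor (here a second model input) and that the inputs are passed as a list. The body of connection, the names x_in and k_in, and the shapes are hypothetical and chosen only for illustration:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

def connection(var):
    # var is the single argument the Lambda layer receives: a list of tensors
    x, k = var[0], var[1]
    return x * k  # hypothetical body: scale x by k

x_in = layers.Input(shape=(3,))
k_in = layers.Input(shape=(1,))  # assumption: k arrives as a tensor input
y = layers.Lambda(connection)([x_in, k_in])
model = tf.keras.Model(inputs=[x_in, k_in], outputs=y)

# Two samples of ones, each scaled by k = 5
x_data = np.ones((2, 3), dtype="float32")
k_data = np.full((2, 1), 5.0, dtype="float32")
result = model.predict([x_data, k_data], verbose=0)
print(result)
```

This form suits a k that is computed inside the model. If k is a plain Python value, such as the bound of a range() loop, it likely cannot be passed this way.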

Andrey Nikishaev
15

I found the solution in this GitHub pull request (https://github.com/fchollet/keras/pull/1911). Using

y = Lambda(connection, arguments={'k':k})(x)

worked!
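For readers who want to try it end to end, here is a minimal sketch. The body of connection is hypothetical (the question does not show it), as are the input shape and the name inp. The point is only that k reaches the function as an ordinary Python keyword argument, so it can drive range():

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

def connection(x, k):
    # Hypothetical body: add x to itself k times, giving (k + 1) * x
    out = x
    for i in range(0, k):
        out = out + x
    return out

k = 5  # a plain Python int, fixed when the layer is built
inp = layers.Input(shape=(3,))
y = layers.Lambda(connection, arguments={'k': k})(inp)
model = tf.keras.Model(inputs=inp, outputs=y)

result = model.predict(np.ones((2, 3), dtype="float32"), verbose=0)
print(result)  # every entry is (5 + 1) * 1 = 6
```

Each Lambda layer keeps the value of k it was given, so building several layers with different values of k does not require a global.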

Prabaha
  • Note: by default `model.save_weights()` will not recognize `k` as part of the model. – wrongu Nov 02 '17 at 19:34
  • Yes, this worked indeed. For other readers: note that the variable k here is already defined as a constant. – zwep Oct 30 '18 at 09:11
4
import tensorflow as tf
from tensorflow.keras import layers

x = layers.Input(shape=[1,])   # Lambda on a single input
out1 = layers.Lambda(lambda t: t ** 2)(x)

y = layers.Input(shape=[1,])   # Lambda on multiple inputs
z = layers.Input(shape=[1,])
def conn(IP):
    # IP is the list of tensors passed to the layer
    return IP[0] + IP[1]
out2 = layers.Lambda(conn)([y, z])

Tmodel = tf.keras.Model(inputs=[x, y, z], outputs=[out1, out2], name='Tmodel')  # Define Model
Tmodel.summary()

# output (each input is a batch of one sample with one feature)
O1, O2 = Tmodel([tf.constant([[2.0]]), tf.constant([[15.0]]), tf.constant([[10.0]])])
print(O1)   # 2 ** 2 = 4
print(O2)   # 15 + 10 = 25