
I am aware that numba can be used with Keras. However, in my case I'm trying to subclass a Layer, so that solution doesn't work for me.

import numpy as np
import numba
import tensorflow as tf

# Numba-compiled forward function and its gradient with respect to param
@numba.jit(nopython=True)
def func(param, input):
    return param * input**2

@numba.jit(nopython=True)
def gradfunc(param, input):
    return input**2

# TensorFlow op with a custom gradient, wrapping the Numba functions
@tf.custom_gradient
def func_tf(param, input):
    p = param.numpy()  # .numpy() only works in eager mode
    i = input.numpy()
    def grad(dy):
        return tf.numpy_function(gradfunc, (p, i), tf.float32), 2 * p * i
    return tf.numpy_function(func, (p, i), tf.float32), grad

class myLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        
    def build(self, input_shape):
        self.param = self.add_weight("param")
        
    def call(self, input):
        return func_tf(self.param, input)
    
class myModel(tf.keras.Model):
    def __init__(self, num_layers):
        super().__init__(name='')
        self._layers = [myLayer() for _ in range(num_layers)]
        
    def call(self, input_tensor):
        for layer in self._layers:
            input_tensor = layer(input_tensor)
        return input_tensor
    
model = myModel(3)
print(model(1.5)) # <-- this works

This part is okay, because in eager mode .numpy() is allowed. However, training fails:

def loss(target, output):
    return tf.abs(tf.reduce_sum(target - output))**2

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=loss,
    metrics=[loss])

model.fit([0.1], [0.4], batch_size=None)

because model.fit uses @tf.function under the hood, so the calls to .numpy() in func_tf are not allowed (see issue #40508 on GitHub).
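
For reference, the same failure shows up without Keras if the call is wrapped in a tf.function (a minimal sketch; the exact error text may vary between TF versions):

@tf.function
def call_in_graph_mode(x):
    # inside tf.function the tensors are symbolic, so .numpy() in func_tf fails
    return model(x)

# call_in_graph_mode(tf.constant(1.5))  # raises an error like "'Tensor' object has no attribute 'numpy'"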

How can I make it work?

Ziofil

1 Answer


EDIT: Your code should work if, instead of using .numpy() in func_tf, you pass param and input directly to tf.numpy_function:

@tf.custom_gradient
def func_tf(param, input):
    param = tf.convert_to_tensor(param)
    input = tf.convert_to_tensor(input)
    def grad(dy):
        return tf.numpy_function(gradfunc, (param, input), tf.float32), 2 * param * input
    return tf.numpy_function(func, (param, input), tf.float32), grad

The tf.convert_to_tensor calls are there because tf.numpy_function strictly expects tf.Tensor objects, so if you use param directly, which is a tf.Variable passed from myLayer, it will not work as expected.
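
As a quick illustration (the variable here is just for the example), tf.convert_to_tensor reads the variable's current value as a plain tensor, which is what tf.numpy_function accepts:

v = tf.Variable(2.0)             # stands in for the weight created by add_weight
t = tf.convert_to_tensor(v)      # a tf.Tensor holding the variable's current value
print(type(v).__name__, type(t).__name__)  # e.g. ResourceVariable and an (Eager)Tensor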

For some reason, the code still gives an error about shapes after this. I got it to run properly by changing the shape of the param weight to [1, 1]:

self.param = self.add_weight("param", shape=[1, 1])

You can pass run_eagerly=True to compile to force Keras to use eager mode (i.e. without tf.function) for training:

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=loss,
    metrics=[loss],
    run_eagerly=True)
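
With either fix in place, the original training call should then run (the run_eagerly=True route will be noticeably slower, see the comments below):

model.fit([0.1], [0.4], batch_size=None)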
jdehesa
  • That's true (thanks for the suggestion), but it slows down the training by at least 10x – Ziofil Jul 30 '20 at 11:50
  • @Ziofil Yes, that doesn't really surprise me... Unfortunately there is not much of a way around it: if you want to incorporate NumPy/Numba code into the training, you will need to run eagerly... The alternative is to port the Numba code to TensorFlow (a rough sketch of such a port follows the comments), which may or may not be feasible in your case (and it could be that the TensorFlow implementation is slower than the Numba one). – jdehesa Jul 30 '20 at 12:08
  • I see. But in principle it should be implementable at the level of the internals of Keras and TF: if one defines an op as a black box with its own custom gradient, it should be "pluggable" in the computational graph as is... – Ziofil Jul 30 '20 at 12:16
  • @Ziofil Wait, I just realised, why are you using `.numpy()` at all in your code? You don't need `p` and `i` in `func_tf`, you should be using `param` and `input`. That would allow the code to be wrapped into a `tf.function`. – jdehesa Jul 30 '20 at 13:36
  • Well, `input` is passed as a value, but `param` is passed as a Tensor, so if I don't use `.numpy()`, it doesn't work... (unless I'm missing something) – Ziofil Jul 30 '20 at 13:38
  • @Ziofil Updated with a working solution for your code, hope that's clear. – jdehesa Jul 30 '20 at 13:50
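
For completeness, a minimal sketch (not part of the original answer) of what porting the Numba code to native TensorFlow could look like: since param * input**2 can be written directly with TF ops, no tf.numpy_function or custom gradient is needed, and the layer works inside tf.function:

class MyTFLayer(tf.keras.layers.Layer):
    # Hypothetical TF-native port of myLayer: autodiff supplies the gradient,
    # and the layer can be traced by tf.function / model.fit.
    def build(self, input_shape):
        self.param = self.add_weight("param", shape=[1, 1])

    def call(self, input):
        return self.param * input ** 2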