
I need to define a method with a custom gradient, as follows:

import tensorflow as tf

class CustGradClass:

    def __init__(self):
        pass

    @tf.custom_gradient
    def f(self, x):
        fx = x
        def grad(dy):
            return dy * 1
        return fx, grad

I am getting the following error:

ValueError: Attempt to convert a value (<__main__.CustGradClass object at 0x12ed91710>) with an unsupported type (<class '__main__.CustGradClass'>) to a Tensor.

The reason is that tf.custom_gradient expects a function f(*x) where x is a sequence of Tensors, but the first argument being passed here is the object itself, i.e. self.

From the documentation:

f: function f(*x) that returns a tuple (y, grad_fn) where:
  • x is a sequence of Tensor inputs to the function.
  • y is a Tensor or sequence of Tensor outputs of applying TensorFlow operations in f to x.
  • grad_fn is a function with the signature g(*grad_ys)

How do I make this work? Do I need to inherit from some TensorFlow Python class?

I am using tf version 1.12.0 and eager mode.
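For reference, this is a minimal way to trigger the error in that setup (assuming eager execution has been enabled):

import tensorflow as tf
tf.enable_eager_execution()  # TF 1.x eager mode

c = CustGradClass()
# tf.custom_gradient tries to convert every positional argument to a Tensor,
# including the implicitly passed self, which raises the ValueError above.
y = c.f(tf.constant(1.0))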


2 Answers


This is one possible simple workaround:

import tensorflow as tf

class CustGradClass:

    def __init__(self):
        self.f = tf.custom_gradient(lambda x: CustGradClass._f(self, x))

    @staticmethod
    def _f(self, x):
        fx = x * 1
        def grad(dy):
            return dy * 1
        return fx, grad

with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.constant(1.0)
    c = CustGradClass()
    y = c.f(x)
    print(tf.gradients(y, x))
    # [<tf.Tensor 'gradients/IdentityN_grad/mul:0' shape=() dtype=float32>]
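Since the question mentions eager mode, the same workaround can also be checked with a GradientTape instead of tf.gradients (a quick sketch, assuming TF 1.x with eager execution enabled and the class defined above):

import tensorflow as tf
tf.enable_eager_execution()

x = tf.constant(1.0)
c = CustGradClass()
with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant, not a Variable, so watch it explicitly
    y = c.f(x)
print(tape.gradient(y, x))
# tf.Tensor(1.0, shape=(), dtype=float32)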

EDIT:

If you want to do this a lot of times on different classes, or just want a more reusable solution, you can use some decorator like this for example:

import functools
import tensorflow as tf

def tf_custom_gradient_method(f):
    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        if not hasattr(self, '_tf_custom_gradient_wrappers'):
            self._tf_custom_gradient_wrappers = {}
        if f not in self._tf_custom_gradient_wrappers:
            self._tf_custom_gradient_wrappers[f] = tf.custom_gradient(lambda *a, **kw: f(self, *a, **kw))
        return self._tf_custom_gradient_wrappers[f](*args, **kwargs)
    return wrapped

Then you could just do:

class CustGradClass:

    def __init__(self):
        pass

    @tf_custom_gradient_method
    def f(self, x):
        fx = x * 1
        def grad(dy):
            return dy * 1
        return fx, grad

    @tf_custom_gradient_method
    def f2(self, x):
        fx = x * 2
        def grad(dy):
            return dy * 2
        return fx, grad
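A quick usage sketch of the decorated methods (assuming eager execution is enabled, e.g. via tf.enable_eager_execution() on TF 1.x):

import tensorflow as tf
tf.enable_eager_execution()

x = tf.constant(3.0)
c = CustGradClass()
with tf.GradientTape(persistent=True) as tape:  # persistent, to take two gradients
    tape.watch(x)
    y1 = c.f(x)   # forward: x * 1, custom gradient: dy * 1
    y2 = c.f2(x)  # forward: x * 2, custom gradient: dy * 2
print(tape.gradient(y1, x))  # 1.0
print(tape.gradient(y2, x))  # 2.0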
  • Great workaround! One suggestion: if you're going to wrap the method with a `lambda` anyway, there's no need to declare the method as static. Just say: `self.f = tf.custom_gradient(lambda x: self._f(x))` – Ben Price Jul 06 '20 at 02:13
  • One more note: TF 2.1+ supports using `tf.custom_gradient` with class methods out-of-the-box. (See [here](https://github.com/tensorflow/tensorflow/commit/f46392f4a51de2e4a95ce1bb1786603d7814e569).) It is not supported in TF 2.0, however. – Ben Price Jul 06 '20 at 02:34

In your example you are not using any member variables, so you could just make the method a static method. If you do use member variables, call the static method from a member function and pass the member variables as parameters.

import tensorflow as tf

class CustGradClass:

  def __init__(self):
    self.some_var = ...

  @staticmethod
  @tf.custom_gradient
  def _f(x):
    fx = x
    def grad(dy):
      return dy * 1

    return fx, grad

  def f(self):
    return CustGradClass._f(self.some_var)
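For example, a quick check of this pattern (a sketch assuming eager execution and that some_var holds a Tensor; the tf.constant below is just a hypothetical value for illustration):

import tensorflow as tf
tf.enable_eager_execution()

c = CustGradClass()
c.some_var = tf.constant(2.0)  # hypothetical member value
with tf.GradientTape() as tape:
    tape.watch(c.some_var)
    y = c.f()
print(tape.gradient(y, c.some_var))  # 1.0, from the custom grad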
  • I'd assume the OP is looking for a solution where they can use member variables within the method. – jdehesa Feb 22 '19 at 10:10