I am trying to create a Sign layer in TensorFlow. Because the gradient of Sign is zero, I want to create a function like this:
but I don't know where I should write the backward function.
In general, this can be done via tf.custom_gradient. This allows you to write a forward function along with a custom gradient function which depends on the "incoming" gradients (from the layers further down the model) as well as the forward function.
With it, you could create a function like this:
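import tensorflow as tf

@tf.custom_gradient
def sign_with_grad(x):
    # Forward pass: the ordinary sign function.
    output = tf.sign(x)

    def grad_fn(dy):
        # Backward pass: 1 where |x| <= 1, 0 elsewhere.
        check_range = tf.where(tf.less_equal(tf.abs(x), 1.), 1., 0.)
        return dy * check_range

    return output, grad_fn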
Here, we write a custom gradient that simply passes through that of the layer above -- except that it zeros out all elements where the input is outside the range [-1, 1]. We return both the result of the forward function (the sign) as well as the gradient function itself. The decorator takes care of handling the rest.
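To see what this does in practice, a quick check with tf.GradientTape can help. This is only a sketch, assuming TensorFlow 2.x with eager execution. The input values and the sum used as a stand-in loss are made up for illustration.

```python
import tensorflow as tf

@tf.custom_gradient
def sign_with_grad(x):
    output = tf.sign(x)

    def grad_fn(dy):
        # Pass the incoming gradient through only where |x| <= 1.
        check_range = tf.where(tf.less_equal(tf.abs(x), 1.), 1., 0.)
        return dy * check_range

    return output, grad_fn

# Illustrative inputs: two values outside [-1, 1], two inside.
x = tf.constant([-2.0, -0.5, 0.5, 2.0])

with tf.GradientTape() as tape:
    tape.watch(x)  # x is a constant, so it has to be watched explicitly
    y = sign_with_grad(x)
    loss = tf.reduce_sum(y)  # stand-in loss; its gradient w.r.t. y is all ones

grad = tape.gradient(loss, x)
print(y.numpy())     # forward pass: sign of each element
print(grad.numpy())  # backward pass: nonzero only where |x| <= 1
```

With plain tf.sign the gradient here would be zero everywhere; with the custom gradient, the two inputs inside [-1, 1] receive the upstream gradient unchanged.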
Please note, I did not check whether the code runs -- let me know if it doesn't! In particular, the less_equal and/or where checks might need explicit broadcasting -- e.g. use tf.ones_like(x) instead of 1. (same for 0).
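If broadcasting does turn out to be an issue, one possible variant makes every shape explicit with tf.ones_like and tf.zeros_like. The sketch below also wraps the function in a Keras layer, since the question asks for a layer. SignLayer is a hypothetical name chosen for this example, and the weighted sum is just an arbitrary loss to produce non-trivial upstream gradients. It assumes TensorFlow 2.x.

```python
import tensorflow as tf

@tf.custom_gradient
def sign_with_grad(x):
    output = tf.sign(x)

    def grad_fn(dy):
        # Every operand has the same shape as x, so nothing relies on broadcasting.
        inside = tf.less_equal(tf.abs(x), tf.ones_like(x))
        check_range = tf.where(inside, tf.ones_like(dy), tf.zeros_like(dy))
        return dy * check_range

    return output, grad_fn


class SignLayer(tf.keras.layers.Layer):
    """Hypothetical layer name: sign in the forward pass, clipped pass-through gradient."""

    def call(self, inputs):
        return sign_with_grad(inputs)


layer = SignLayer()
x = tf.constant([[-3.0, -1.0, 0.25, 1.5]])
weights = tf.constant([[1.0, 2.0, 3.0, 4.0]])  # arbitrary upstream weights

with tf.GradientTape() as tape:
    tape.watch(x)
    y = layer(x)
    loss = tf.reduce_sum(y * weights)

grad = tape.gradient(loss, x)
print(grad.numpy())  # upstream weights where |x| <= 1, zeros elsewhere
```

The layer has no trainable weights of its own, so it can be dropped between other layers wherever a sign activation is wanted.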