
I would like to implement in TensorFlow the technique of "guided back-propagation" introduced in this paper and described in this recipe.

Computationally, that means that when I compute the gradient, e.g., of the input w.r.t. the output of the NN, I will have to modify the gradients computed at every ReLU unit. Concretely, the back-propagated signal on those units must be thresholded at zero to make this technique work. In other words, the partial derivatives of the ReLUs that are negative must be ignored.
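To make the rule concrete, here is a small NumPy sketch (the array values are just made up for illustration):

import numpy as np

# Toy values: relu_out is the forward ReLU activation, grad_in the gradient
# arriving from the layer above during back-propagation.
relu_out = np.array([0.0, 2.0, 0.0, 3.0])
grad_in  = np.array([1.0, -1.0, 2.0, 0.5])

plain_grad  = grad_in * (relu_out > 0)    # ordinary ReLU backprop: [0., -1., 0., 0.5]
guided_grad = plain_grad * (grad_in > 0)  # guided: also drop negatives -> [0., 0., 0., 0.5]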

Given that I am interested in applying these gradient computations only on test examples, i.e., I don't want to update the model's parameters, how should I do it?

I tried (unsuccessfully) two things so far:

  1. Use tf.py_func to wrap my simple NumPy version of a ReLU, whose gradient operation could then be redefined via the g.gradient_override_map context manager.

  2. Gather the forward/backward values of the back-propagation and apply the thresholding on those stemming from ReLUs.

I failed with both approaches because they require some knowledge of the internals of TF that currently I don't have.

Can anyone suggest any other route, or sketch the code?

Thanks a lot.

Peter

2 Answers


The better solution is your approach 1, using ops.RegisterGradient and tf.Graph.gradient_override_map. Together they override the gradient computation for a pre-defined op, e.g. Relu, within the gradient_override_map context, using only Python code.

from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nn_ops
import tensorflow as tf

@ops.RegisterGradient("GuidedRelu")
def _GuidedReluGrad(op, grad):
    # Pass the gradient only where it is positive and the forward ReLU output was positive.
    return tf.where(0. < grad, gen_nn_ops._relu_grad(grad, op.outputs[0]), tf.zeros(grad.get_shape()))

...
with g.gradient_override_map({'Relu': 'GuidedRelu'}):
    y = tf.nn.relu(x)

Here is a full example implementation of guided ReLU: https://gist.github.com/falcondai/561d5eec7fed9ebf48751d124a77b087

Update: in TensorFlow >= 1.0, tf.select was renamed to tf.where. I updated the snippet accordingly. (Thanks @sbond for bringing this to my attention. :)
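For reference, a minimal usage sketch of how the override can be combined with tf.gradients to get guided gradients w.r.t. an input; build_model, input_images and class_idx are placeholders for your own graph, not part of the gist above:

g = tf.get_default_graph()
with g.gradient_override_map({'Relu': 'GuidedRelu'}):
    # build_model stands in for your own graph construction; every tf.nn.relu
    # created inside this context uses the GuidedRelu gradient.
    logits = build_model(input_images)

# Guided back-propagation signal of one class score w.r.t. the input images.
guided_saliency = tf.gradients(logits[:, class_idx], input_images)[0]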

Falcon
    beware that you need to wrap the graph construction involving the relu op *inside* the `gradient_override_map` context. – Falcon Aug 10 '16 at 04:26
    Thank you, @Falcon, this works well. I also had to replace tf.select by tf.where, as I am using TF version 1.2. – sbond Oct 18 '17 at 18:47
  • @sbond Thanks for the update. I edited my post to include your comment. – Falcon Jan 05 '18 at 01:46

tf.gradients has a grad_ys parameter that can be used for this purpose. Suppose your network has just one ReLU layer, as follows:

before_relu = f1(inputs, params)        # sub-network up to the ReLU
after_relu = tf.nn.relu(before_relu)
loss = f2(after_relu, params, targets)  # rest of the network plus the loss

First, compute the derivative up to after_relu.

Dafter_relu = tf.gradients(loss, after_relu)[0]

Then threshold the gradients that you send down.

Dafter_relu_thresholded = tf.where(Dafter_relu < 0.0, tf.zeros_like(Dafter_relu), Dafter_relu)

Compute the actual gradients w.r.t. params.

Dparams = tf.gradients(after_relu, params, grad_ys=Dafter_relu_thresholded)

You can easily extend this same method to a network with many ReLU layers.
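As a rough sketch of that generalization (the tensor names are assumptions; relu_outputs is ordered from the ReLU closest to the loss down to the one closest to the inputs):

# Hypothetical list of ReLU outputs, ordered from the loss toward the inputs.
relu_outputs = [after_relu_3, after_relu_2, after_relu_1]

# Gradient of the loss w.r.t. the ReLU output nearest the loss, thresholded.
grad = tf.gradients(loss, relu_outputs[0])[0]
grad = tf.where(grad < 0.0, tf.zeros_like(grad), grad)

# Propagate to each earlier ReLU output, thresholding again at every step.
for higher, lower in zip(relu_outputs[:-1], relu_outputs[1:]):
    grad = tf.gradients(higher, lower, grad_ys=grad)[0]
    grad = tf.where(grad < 0.0, tf.zeros_like(grad), grad)

# Finally, the guided gradients w.r.t. the inputs (or params).
Dinputs = tf.gradients(relu_outputs[-1], inputs, grad_ys=grad)[0]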

keveman
  • Hi Kaveman, thanks a lot for the prompt reply. In your last tf.gradients call you are intentionally passing the Dafter_relu as the first argument? – Peter Jul 13 '16 at 02:57
    Also, I am still confused w.r.t. how to generalize this to a network with many layers in a way that works for any NN that has ReLU elements. For this part, wouldn't you need to trace all the inputs/outputs of each ReLU element and 'chain' your previously described logic? Thanks. – Peter Jul 13 '16 at 03:04
  • @Peter, sorry, that was a typo. The second call to `tf.gradients` is `after_relu` w.r.t. `params`. – keveman Jul 13 '16 at 03:55
    Yes, to generalize, you have to identify all the relu layers, and run the above logic on each of the layers. – keveman Jul 13 '16 at 03:57
  • @keveman Could you please elaborate on what the functions f1 and f2 compute, essentially how do I calculate the loss? For example, if the loss at the final layer is cross-entropy, then what would the loss be at any intermediate CNN layer? – Shubham Singh rawat Mar 23 '18 at 21:11