
The Keras examples directory contains a lightweight version of a stacked what-where autoencoder (SWWAE), trained on MNIST data. (https://github.com/fchollet/keras/blob/master/examples/mnist_swwae.py)

In the original SWWAE paper, the authors compute the "what" and "where" using soft functions. The Keras implementation, however, uses a trick to get the "where" locations. I would like to understand this trick.

Here is the code of the trick.

from keras import backend as K

def getwhere(x):
    ''' Calculate the 'where' mask that contains switches indicating which
    index contained the max value when MaxPool2D was applied.  Using the
    gradient of the sum is a nice trick to keep everything high level.'''
    y_prepool, y_postpool = x
    return K.gradients(K.sum(y_postpool), y_prepool)  # How exactly does this line work?

Here y_prepool is an M x N matrix and y_postpool is an M/2 x N/2 matrix (let's assume canonical pooling with a 2x2 window and stride 2).

I have verified that the output of getwhere() is a bed of nails matrix where the nails indicate the position of the max (the local argmax if you will).

Can someone construct a small example demonstrating how getwhere works using this "trick"?

Avedis

1 Answer


Let's focus on the simplest example and leave convolutions aside. Say we have a vector

x = [1 4 2]

which we max-pool with a single window spanning the whole vector; we get

mx = 4

Mathematically speaking, it is:

mx = x[argmax(x)]

Now, the "trick" to recover the one-hot mask used by pooling is to compute

magic = d mx / dx

There is no gradient for argmax itself (it is piecewise constant), but indexing x with it "passes" the gradient through to the element of the vector at the location of the maximum, so:

d mx / dx = [d mx/dx[1], d mx/dx[2], d mx/dx[3]] = [0, dx[2]/dx[2], 0] = [0 1 0]

As you can see, the gradient for every non-maximum element is zero (mx does not depend on them, because of the argmax), and a 1 appears at the position of the maximum because dx[2]/dx[2] = 1.
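As a sanity check of the derivation above, here is a small pure-Python sketch (no Keras involved) that estimates d mx / dx by forward finite differences. The helper names max_pool_1d and numerical_grad are made up for illustration, and the estimate is only meaningful away from ties between equal maxima.

```python
# Hypothetical helpers: numerically estimate d mx / dx for mx = max(x).
def max_pool_1d(x):
    # A single pooling window covering the whole vector.
    return max(x)

def numerical_grad(f, x, eps=1e-6):
    """Forward-difference estimate of df/dx[i] for each i, rounded to 6 places."""
    base = f(x)
    grads = []
    for i in range(len(x)):
        bumped = list(x)
        bumped[i] += eps  # nudge one element at a time
        grads.append(round((f(bumped) - base) / eps, 6))
    return grads

x = [1, 4, 2]
magic = numerical_grad(max_pool_1d, x)
print(magic)  # [0.0, 1.0, 0.0]
```

Nudging a non-maximum element leaves the max unchanged (derivative 0), while nudging the maximum moves mx by the same amount (derivative 1), which is exactly the one-hot mask.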

Now, for "proper" maxpool you have many pooling regions, each connected to its own set of input locations. Since the gradient of a sum is the sum of the gradients, taking the gradient of the sum of all pooled values recovers all of the argmax indices at once.
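To illustrate, here is a minimal pure-Python sketch of how backprop routes the gradient of sum(maxpool(x)) back to x, assuming non-overlapping 2x2 windows with stride 2 (the setting in the question). The function name grad_of_sum_maxpool and the 4x4 input are hypothetical, and breaking ties by taking the first maximum in row-major order is an assumption of this sketch. The sum gives every pooled value an upstream gradient of 1, and each window hands that 1 to its argmax.

```python
# Hypothetical helper: emulates backprop of sum(maxpool(x)) w.r.t. x
# for k x k windows moved by `stride`.
def grad_of_sum_maxpool(x, k=2, stride=2):
    rows, cols = len(x), len(x[0])
    grad = [[0] * cols for _ in range(rows)]
    for r0 in range(0, rows - k + 1, stride):
        for c0 in range(0, cols - k + 1, stride):
            # Cells of this window in row-major order.
            cells = [(r, c) for r in range(r0, r0 + k) for c in range(c0, c0 + k)]
            # d sum / d pooled value = 1; the max op sends it to the argmax.
            # max() keeps the first maximal cell (assumed tie rule).
            r, c = max(cells, key=lambda rc: x[rc[0]][rc[1]])
            grad[r][c] += 1
    return grad

x = [[1, 5, 2, 0],
     [3, 2, 8, 1],
     [7, 0, 1, 4],
     [2, 6, 3, 9]]
mask = grad_of_sum_maxpool(x)
for row in mask:
    print(row)
# [0, 1, 0, 0]
# [0, 0, 1, 0]
# [1, 0, 0, 0]
# [0, 0, 0, 1]
```

Because the windows do not overlap, every input cell belongs to exactly one window, so each entry is 0 or 1: a bed-of-nails mask with one nail per 2x2 block.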

Note, however, that this trick will not work if you have heavily overlapping kernels (pooling windows): you might end up with values bigger than 1. If a pixel is the maximum of K pooling windows, then it will have value K, not 1. For example:

     [ 1, 2, 3]
x =  [13, 3, 1]
     [ 4, 2, 9]

if we max-pool with a 2x2 window and stride 1 (so neighbouring windows overlap), we get

mx = [13,3]
     [13,9]

and the gradient trick gives you

        [0, 0, 1]
magic = [2, 0, 0]
        [0, 0, 1]
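The same result can be reproduced in pure Python by asking for the gradient of each pooled value separately (a one-hot mask per window) and adding the masks. The helper names window_mask and summed_masks are hypothetical. Note that the top-right window contains two 3s; this sketch breaks ties by taking the first maximum in row-major order, which yields the mask above, but whether a given backend breaks ties the same way is an assumption here.

```python
# Hypothetical helpers: per-window gradients of the max, summed over all windows.
def window_mask(x, r0, c0, k):
    """One-hot mask marking the argmax of the k x k window at (r0, c0)."""
    mask = [[0] * len(x[0]) for _ in x]
    cells = [(r, c) for r in range(r0, r0 + k) for c in range(c0, c0 + k)]
    # max() keeps the first maximal cell in row-major order (assumed tie rule).
    r, c = max(cells, key=lambda rc: x[rc[0]][rc[1]])
    mask[r][c] = 1
    return mask

def summed_masks(x, k=2, stride=1):
    """Add up the masks of all windows; equals the gradient of the sum of pooled values."""
    total = [[0] * len(x[0]) for _ in x]
    for r0 in range(0, len(x) - k + 1, stride):
        for c0 in range(0, len(x[0]) - k + 1, stride):
            m = window_mask(x, r0, c0, k)
            total = [[a + b for a, b in zip(tr, mr)] for tr, mr in zip(total, m)]
    return total

x = [[1, 2, 3],
     [13, 3, 1],
     [4, 2, 9]]
magic = summed_masks(x)
for row in magic:
    print(row)
# [0, 0, 1]
# [2, 0, 0]
# [0, 0, 1]
```

The 13 at (row 1, column 0) is the maximum of two overlapping windows, so it receives a gradient of 1 from each of them and ends up as 2.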
lejlot
  • @lejlot Thanks for taking the time to write this up despite it being downvoted (it's a great writeup). One question: I don't quite understand the K.sum(), and I don't quite follow "Now, for 'proper' maxpool....". Can you explain it another way? – Avedis Aug 25 '17 at 13:32
  • What I described works for a single maxpool window. Now you have many of these, and since the gradient of the sum is the sum of the gradients, K.sum() is a natural way to express many such "calls" jointly. Instead of doing K.sum(), you could also independently ask for the gradient of each max-pooled region and then add all the masks (this is how you end up with values bigger than 1). Is this clearer? – lejlot Aug 25 '17 at 20:40