The Keras examples directory contains a lightweight version of a stacked what-where autoencoder (SWWAE), which is trained on MNIST data: https://github.com/fchollet/keras/blob/master/examples/mnist_swwae.py
In the original SWWAE paper, the authors compute the "what" and "where" using soft functions. The Keras implementation instead uses a trick to get these locations, and I would like to understand that trick.
Here is the code for the trick:
    from keras import backend as K

    def getwhere(x):
        ''' Calculate the 'where' mask that contains switches indicating which
        index contained the max value when MaxPool2D was applied. Using the
        gradient of the sum is a nice trick to keep everything high level.'''
        y_prepool, y_postpool = x
        return K.gradients(K.sum(y_postpool), y_prepool)  # How exactly does this line work?
Here y_prepool is an M x N matrix and y_postpool is an (M/2) x (N/2) matrix (let's assume the canonical 2 x 2 pooling).
I have verified that the output of getwhere() is a bed-of-nails matrix in which the nails mark the position of the max in each pooling window (the local argmax, if you will).
Can someone construct a small example demonstrating how getwhere works using this "trick"?
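To make the request concrete, here is a framework-free sketch of what I believe the gradient of the sum computes. It is an illustration under assumptions, not the Keras code: max_pool_2x2, pooled_sum and numeric_where are hypothetical helpers, the 4 x 4 input is made up, and the gradient is approximated with finite differences instead of K.gradients.

```python
def max_pool_2x2(x):
    """Non-overlapping 2x2 max pooling on a list-of-lists matrix."""
    rows, cols = len(x), len(x[0])
    return [[max(x[i][j], x[i][j + 1], x[i + 1][j], x[i + 1][j + 1])
             for j in range(0, cols, 2)]
            for i in range(0, rows, 2)]


def pooled_sum(x):
    """Scalar playing the role of K.sum(y_postpool)."""
    return sum(sum(row) for row in max_pool_2x2(x))


def numeric_where(x, eps=1e-3):
    """Finite-difference gradient of pooled_sum w.r.t. every entry of x.

    Assumes each window has a unique max that beats the runner-up by
    more than eps; ties are not handled.
    """
    base = pooled_sum(x)
    grad = []
    for i in range(len(x)):
        grad_row = []
        for j in range(len(x[0])):
            bumped = [row[:] for row in x]   # copy, then nudge one entry
            bumped[i][j] += eps
            grad_row.append(round((pooled_sum(bumped) - base) / eps))
        grad.append(grad_row)
    return grad


# Made-up 4x4 input with a unique max in every 2x2 window.
y_prepool = [[1, 2, 5, 3],
             [4, 6, 7, 8],
             [9, 1, 2, 0],
             [3, 2, 4, 1]]

print(max_pool_2x2(y_prepool))   # [[6, 8], [9, 4]]
for row in numeric_where(y_prepool):
    print(row)
# [0, 0, 0, 0]
# [0, 1, 0, 1]
# [1, 0, 0, 0]
# [0, 0, 1, 0]
```

If this reading is right, each pooled value is just a copy of the largest input in its window, so the sum changes one-for-one with that input and not at all with the other three. That would explain the bed-of-nails output, but I would like confirmation that K.gradients is doing the same thing symbolically.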