I am trying to implement another pooling function for a neural network with Theano besides the already existing max pooling, for example average pooling.
Following this source, where average pooling is already implemented, my code looks like this:
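Imports, and a random initialization just to test:
import numpy
import theano
import theano.tensor as T
import theano.sandbox.neighbours as TSN
invals = numpy.random.RandomState(1).rand(3,2,5,5)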
Definition of the Theano variables and the pooling function:
pdim = T.scalar('pool dim', dtype='float32')
pool_inp = T.tensor4('pool input', dtype='float32')
pool_sum = TSN.images2neibs(pool_inp, (pdim, pdim))
pool_out = pool_sum.mean(axis=-1)
pool_fun = theano.function([pool_inp, pdim], pool_out, name='pool_fun', allow_input_downcast=True)
Here TSN is theano.sandbox.neighbours.
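For reference, here is a rough NumPy sketch of what this graph computes, assuming the input height and width are exact multiples of the pool size (what images2neibs requires in its default 'valid' mode) and that each non-overlapping patch becomes one flattened row, ordered by image, channel, patch row, then patch column (consistent with the output printed at the end of this post). The helper `neibs_nonoverlapping` is hypothetical and not part of Theano.

```python
import numpy as np

def neibs_nonoverlapping(x, p):
    """Hypothetical NumPy stand-in for TSN.images2neibs(x, (p, p)) in 'valid' mode.

    x has shape (N, C, H, W), with H and W exact multiples of p.
    Returns shape (N*C*(H//p)*(W//p), p*p): one flattened p x p patch per row,
    ordered by image, channel, patch row, then patch column.
    """
    n, c, h, w = x.shape
    if h % p or w % p:
        raise ValueError("this sketch assumes H and W are multiples of p")
    blocks = x.reshape(n, c, h // p, p, w // p, p)   # split rows and columns into p-sized chunks
    blocks = blocks.transpose(0, 1, 2, 4, 3, 5)      # -> (N, C, H//p, W//p, p, p)
    return blocks.reshape(-1, p * p)

x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
rows = neibs_nonoverlapping(x, 2)   # first row is [0, 1, 4, 5], the top-left 2x2 patch
means = rows.mean(axis=-1)          # the analogue of pool_sum.mean(axis=-1)
print(means.reshape(1, 1, 2, 2))    # patch means 2.5, 4.5, 10.5, 12.5
```

The final reshape back to (N, C, H//p, W//p) is the same step that the temp.shape assignment below performs on the compiled function's flat output.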
And the call to the function:
pool_dim = 2
temp = pool_fun(invals, pool_dim)
# integer division: shape entries must be ints (plain / gives floats in Python 3)
temp.shape = (invals.shape[0], invals.shape[1], invals.shape[2] // pool_dim,
              invals.shape[3] // pool_dim)
print('invals[1,0,:,:]=\n', invals[1,0,:,:])
print('output[1,0,:,:]=\n', temp[1,0,:,:])
And I am getting an error:
TypeError: neib_shape[0]=2, neib_step[0]=2 and ten4.shape[2]=5 not consistent
Apply node that caused the error: Images2Neibs{valid}(pool input, MakeVector.0, MakeVector.0)
Inputs shapes: [(3, 2, 5, 5), (2,), (2,)]
Inputs strides: [(200, 100, 20, 4), (4,), (4,)]
Inputs types: [TensorType(float32, 4D), TensorType(float32, vector), TensorType(float32, vector)]
Use the Theano flag 'exception_verbosity=high' for a debugprint of this apply node.
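The message comes from the default 'valid' mode, which (per the Theano documentation) requires the input to be a multiple of the pooling factor in each direction: 2x2 patches taken every 2 pixels cannot tile a side of length 5. A minimal sketch of that condition, assuming it takes the form `(size - neib) % step == 0`, which for step equal to the patch size reduces to "size is a multiple of the patch size". The functions `valid_mode_consistent` and `num_patches` are hypothetical illustrations, not Theano code.

```python
def valid_mode_consistent(size, neib, step):
    """Hypothetical check mirroring the 'valid' mode requirement along one axis:
    patches of length `neib` placed every `step` pixels must end exactly at the border."""
    return size >= neib and (size - neib) % step == 0

def num_patches(size, neib, step):
    # Number of patch positions along one axis when the check above passes.
    return (size - neib) // step + 1

# The failing call: ten4.shape[2]=5, neib_shape[0]=2, neib_step[0]=2.
print(valid_mode_consistent(5, 2, 2))   # False: 2 + 2 covers 4 pixels, 1 pixel is left over
print(valid_mode_consistent(4, 2, 2))   # True
print(num_patches(4, 2, 2))             # 2
```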
I don't really understand this error. I would be glad of any suggestions on how to fix it, or of examples of other pooling techniques programmed in Theano.
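Once the neighbourhoods are laid out one per row, as images2neibs does, other pooling techniques differ only in how each row is reduced. In the Theano graph above that would mean swapping `pool_sum.mean(axis=-1)` for, e.g., `pool_sum.max(axis=-1)` or `pool_sum.sum(axis=-1)`. The NumPy sketch below only illustrates the reductions; the `rows` values and the `poolers` table are made up for this example.

```python
import numpy as np

# Each row holds one flattened pooling neighbourhood (illustrative values only).
rows = np.array([[1.0, -2.0, 3.0, 0.0],
                 [4.0, 4.0, 4.0, 4.0]])

# Different pooling techniques differ only in how each row is reduced.
poolers = {
    "max": lambda r: r.max(axis=-1),
    "average": lambda r: r.mean(axis=-1),
    "sum": lambda r: r.sum(axis=-1),
    # L2 pooling: square root of the sum of squares in each neighbourhood.
    "l2": lambda r: np.sqrt((r ** 2).sum(axis=-1)),
}

for name, pool in poolers.items():
    print(name, pool(rows))
```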
Thanks!
Edit: with the border ignored, it works perfectly:
pool_sum = TSN.images2neibs(pool_inp, (pdim, pdim), mode='ignore_borders')
invals[1,0,:,:]=
[[ 0.01936696 0.67883553 0.21162812 0.26554666 0.49157316]
[ 0.05336255 0.57411761 0.14672857 0.58930554 0.69975836]
[ 0.10233443 0.41405599 0.69440016 0.41417927 0.04995346]
[ 0.53589641 0.66379465 0.51488911 0.94459476 0.58655504]
[ 0.90340192 0.1374747 0.13927635 0.80739129 0.39767684]]
output[1,0,:,:]=
[[ 0.33142066 0.30330223]
[ 0.42902038 0.64201581]]
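As a cross-check, a NumPy sketch of average pooling that discards the leftover border, assuming 'ignore_borders' simply drops the trailing rows and columns that do not fill a whole patch (here the last row and column of each 5x5 map), reproduces the printed 2x2 output from the printed 5x5 input. The function `avg_pool_ignore_borders` is hypothetical.

```python
import numpy as np

def avg_pool_ignore_borders(img, p):
    """Hypothetical sketch of non-overlapping p x p average pooling over the last
    two axes, discarding incomplete border rows/columns like mode='ignore_borders'."""
    h, w = img.shape[-2:]
    hc, wc = (h // p) * p, (w // p) * p               # largest multiples of p
    cropped = img[..., :hc, :wc]                      # drop the incomplete border
    blocks = cropped.reshape(*img.shape[:-2], hc // p, p, wc // p, p)
    return blocks.mean(axis=(-3, -1))                 # average inside each p x p block

# invals[1,0,:,:] exactly as printed above
img = np.array([
    [0.01936696, 0.67883553, 0.21162812, 0.26554666, 0.49157316],
    [0.05336255, 0.57411761, 0.14672857, 0.58930554, 0.69975836],
    [0.10233443, 0.41405599, 0.69440016, 0.41417927, 0.04995346],
    [0.53589641, 0.66379465, 0.51488911, 0.94459476, 0.58655504],
    [0.90340192, 0.1374747, 0.13927635, 0.80739129, 0.39767684],
])
print(avg_pool_ignore_borders(img, 2))
```

The same function also accepts the full (3, 2, 5, 5) array and returns shape (3, 2, 2, 2), matching the reshaped temp.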