I am trying to create a custom layer that is similar to Max Pooling or the first step of a separable convolution.

For example, take a 2-tensor from which I want to extract the non-overlapping 2x2 patches. Given the [4,4] tensor

[[ 0, 1, 2, 3],
 [ 4, 5, 6, 7],
 [ 8, 9,10,11],
 [12,13,14,15]]

I want to end up with the following [2,2,4] tensor:

[[[ 0, 1, 4, 5],[ 2, 3, 6, 7]],
 [[ 8, 9,12,13],[10,11,14,15]]]

For a 3-Tensor, I want something similar but to also separate out the 3rd dimension. tf.extract_image_patches almost does what I want, but it folds the "depth" dimension into each patch.

Ideally, if I had a tensor of shape [32,64,7] and extracted all the non-overlapping [2,2] patches from it, I would end up with a tensor of shape [16,32,7,4].

To be clear, I just want to extract the patches, not actually perform max pooling or a separable convolution.

Since I am not actually augmenting the data, I suspect this can be done with some tf.reshape trickery. Is there a nice way to achieve this in TensorFlow without resorting to slicing and stitching or for loops?

Also, what is the correct terminology for this operation? Windowing? Tiling?

Ross

1 Answer

Turns out this is really easy to do by combining tf.extract_image_patches with tf.reshape and tf.transpose. The solution that ended up working for me is:

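import tensorflow as tf

# Assume x is in BHWC form with statically known H, W and C.
# (In TF 2.x the same op is tf.image.extract_patches, which takes `sizes` instead of `ksizes`.)
def pool(x,size=2):
  channels = x.get_shape().as_list()[-1]
  # Each output position holds one flattened patch of length size*size*channels.
  x = tf.extract_image_patches(
    x,
    ksizes=[1,size,size,1],
    strides=[1,size,size,1],
    rates=[1,1,1,1],
    padding="SAME"
  )
  # Split the flattened patch axis into [size**2, channels]; the target shape must be a single list.
  x = tf.reshape(x,[-1]+x.get_shape().as_list()[1:3]+[size**2,channels])
  # Swap the last two axes so the result is [B, H/size, W/size, C, size**2].
  x = tf.transpose(x,[0,1,2,4,3])
  return x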
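
For comparison, the "reshape trickery" mentioned in the question can be sketched without any patch-extraction op. The sketch below uses NumPy only to keep it self-contained; tf.reshape and tf.transpose follow the same row-major semantics, so the same three steps should carry over. The function name extract_patches_reshape is made up for illustration, and the sketch assumes H and W are exact multiples of size (there is no padding).

```python
import numpy as np

def extract_patches_reshape(x, size=2):
    """Split a BHWC array into non-overlapping size x size patches.

    Returns an array of shape [B, H // size, W // size, C, size * size].
    Assumes H and W are both divisible by `size` (no padding is applied).
    """
    b, h, w, c = x.shape
    # Split each spatial axis into (block index, offset inside the block).
    x = x.reshape(b, h // size, size, w // size, size, c)
    # Reorder to [B, block_row, block_col, C, row_offset, col_offset].
    x = x.transpose(0, 1, 3, 5, 2, 4)
    # Flatten the two offsets into a single patch axis of length size * size.
    return x.reshape(b, h // size, w // size, c, size * size)

# The 4x4 example from the question, as a batch of one single-channel image.
x = np.arange(16).reshape(1, 4, 4, 1)
patches = extract_patches_reshape(x)
print(patches[0, :, :, 0, :])
```

For the 4x4 example this gives the [2,2,4] result from the question, and a [1,32,64,7] input gives shape [1,16,32,7,4].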
Ross