First off, I know that I should use top_k, but what makes k-max pooling hard to implement in TF is that it has to preserve the original order of the selected elements.
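For reference, "preserve the order" means: keep the k largest values of each row, but emit them in their original left-to-right positions rather than sorted by magnitude. Here is a minimal pure-Python sketch of that definition, contrasted with plain top-k. The helper names are hypothetical, and breaking ties by the earlier position is an assumption.

```python
def top_k_values(values, k):
    """Plain top-k: the k largest values, largest first (order NOT preserved)."""
    return sorted(values, reverse=True)[:k]


def kmax_pool_1d(values, k):
    """k-max pooling: the k largest values, in their original order."""
    # Positions of the k largest values; ties go to the earlier position.
    top_positions = sorted(range(len(values)), key=lambda i: (-values[i], i))[:k]
    # Sorting the positions restores the original left-to-right order.
    return [values[i] for i in sorted(top_positions)]


print(top_k_values([5, 1, 10, 2], 2))  # [10, 5]
print(kmax_pool_1d([5, 1, 10, 2], 2))  # [5, 10]
print(kmax_pool_1d([3, 11, 2, 6], 2))  # [11, 6]
```

The "sort the indices, then look up" step is exactly what the TF and PyTorch snippets in this question try to do with tensors.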
What I have so far:
import tensorflow as tf
from tensorflow.contrib.framework import sort
sess = tf.Session()
a = tf.convert_to_tensor([[[5, 1, 10, 2], [3, 11, 2, 6]]])
# indices of the 2 largest values in each row, re-sorted into their original order
b = sort(tf.nn.top_k(a, k=2)[1])
# gather along the last axis
print(tf.gather(a, b, axis=-1).eval(session=sess))
It's close, but not there yet.
What I get:
[[[[[ 5, 10], [ 1, 2]]], [[[ 3, 2], [11, 6]]]]]
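For context on the 5-D result: with axis=-1 and no batch dimensions, tf.gather applies the whole index tensor b to every row of a, so the output shape is a.shape[:-1] + b.shape, here [1, 2] + [1, 2, 2] = [1, 2, 1, 2, 2]. The pure-Python emulation below is a hypothetical helper, not TF's implementation. It hard-codes b as the sorted top-2 indices [[[0, 2], [1, 3]]] and reproduces the output above.

```python
def gather_last_axis_unbatched(params, indices):
    """Mimic gather(params, indices, axis=-1) with no batch dimensions.

    Every last-axis row of `params` is indexed with the *whole* `indices`
    tensor, so the output shape is params.shape[:-1] + indices.shape.
    """
    def take(row, idx):
        if isinstance(idx, int):
            return row[idx]
        return [take(row, sub) for sub in idx]

    def walk(p):
        if isinstance(p[0], int):  # p is a last-axis row
            return take(p, indices)
        return [walk(sub) for sub in p]

    return walk(params)


a = [[[5, 1, 10, 2], [3, 11, 2, 6]]]
b = [[[0, 2], [1, 3]]]  # sorted top-2 column indices per row
print(gather_last_axis_unbatched(a, b))
# [[[[[5, 10], [1, 2]]], [[[3, 2], [11, 6]]]]]
```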
What I want:
[[[5, 10], [11, 6]]]
I am almost a hundred percent sure that gather_nd is required, but I can't figure out how to use it here. Also, I am a PyTorch user, and it's really easy there:
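import torch
a = torch.LongTensor([[[5, 1, 10, 2], [3, 11, 2, 6]]])
# indices of the top-2 values per row, sorted back into their original order
b = a.topk(2, dim=-1)[1].sort(dim=-1)[0]
# out[i][j][p] = a[i][j][b[i][j][p]]
print(a.gather(-1, b))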
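The PyTorch version works because torch.gather along dim=-1 is a batched lookup: out[i][j][p] = a[i][j][b[i][j][p]], so each row sees only its own slice of b and the output has b's shape. A pure-Python sketch of those semantics on nested lists (the helper name is hypothetical):

```python
def gather_last_axis_batched(params, indices):
    """Mimic torch.gather(params, -1, indices) for nested lists.

    out[i][j][p] = params[i][j][indices[i][j][p]]: each row is indexed only
    by its own slice of `indices`, so the output has the shape of `indices`.
    """
    if isinstance(indices[0], int):  # reached a row: look up its own indices
        return [params[i] for i in indices]
    return [gather_last_axis_batched(p, ix) for p, ix in zip(params, indices)]


a = [[[5, 1, 10, 2], [3, 11, 2, 6]]]
b = [[[0, 2], [1, 3]]]  # sorted top-2 column indices per row
print(gather_last_axis_batched(a, b))  # [[[5, 10], [11, 6]]]
```

If your TF version's tf.gather accepts a batch_dims argument, tf.gather(a, b, batch_dims=2) should express this same batched lookup. Verify that against the TF release you are on, since the snippet above uses tf.contrib.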
Oh, and also: every implementation I found was not order preserving (which is semantically wrong).
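On the gather_nd route: gather_nd picks elements by their full coordinates, so the sorted column indices b have to be expanded into [batch, row, column] triples, giving an index tensor of shape [B, R, k, 3]. Below is a pure-Python sketch of that expansion plus a tiny gather_nd stand-in. Both helper names are hypothetical. In TF the triples would presumably be assembled with tensor ops such as tf.range, tf.tile and tf.stack and then passed to tf.gather_nd.

```python
def build_gather_nd_indices(cols):
    """Turn column indices of shape [B, R, k] into [b, r, c] coordinates.

    The result has shape [B, R, k, 3], the layout needed to pick single
    elements out of a rank-3 tensor by full coordinate.
    """
    return [
        [[[bi, ri, c] for c in row] for ri, row in enumerate(batch)]
        for bi, batch in enumerate(cols)
    ]


def gather_nd(params, indices):
    """Tiny pure-Python stand-in for gather_nd with scalar lookups."""
    if isinstance(indices[0], int):  # innermost level: one full coordinate
        value = params
        for i in indices:
            value = value[i]
        return value
    return [gather_nd(params, sub) for sub in indices]


a = [[[5, 1, 10, 2], [3, 11, 2, 6]]]
b = [[[0, 2], [1, 3]]]  # sorted top-2 column indices per row
idx = build_gather_nd_indices(b)
print(idx)                # [[[[0, 0, 0], [0, 0, 2]], [[0, 1, 1], [0, 1, 3]]]]
print(gather_nd(a, idx))  # [[[5, 10], [11, 6]]]
```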