2

First off, I know that I should use `top_k`, but what makes k-max pooling hard to implement in TensorFlow is that it has to preserve the original order of the selected elements.

What I have so far:

import tensorflow as tf
from tensorflow.contrib.framework import sort
sess = tf.Session()
a = tf.convert_to_tensor([[[5, 1, 10, 2], [3, 11, 2, 6]]])

# Indices of the 2 largest values in each row, re-sorted ascending so the
# selected values would keep their original left-to-right order.
_, top_idx = tf.nn.top_k(a, k=2)
b = sort(top_idx)

# NOTE: tf.gather(..., axis=-1) with a rank-3 index tensor and no batch_dims
# yields shape a.shape[:-1] + b.shape == (1, 2, 1, 2, 2): every row of `a`
# is indexed with every row of `b`, hence the extra nesting in the output.
picked = tf.gather(a, b, axis=-1)
print(picked.eval(session=sess))

It's close, but not there yet.

What I get:

[[[[[ 5, 10], [ 1, 2]]], [[[ 3, 2], [11, 6]]]]]

What I want:

[[[5, 10], [11, 6]]]

I am almost a hundred percent sure that `gather_nd` is required, but I can't figure it out. I'm also a PyTorch user, and it's really easy there:

import torch
a = torch.tensor([[[5, 1, 10, 2], [3, 11, 2, 6]]], dtype=torch.long)
# topk returns the indices of the 2 largest entries per row (ordered by value);
# sorting those indices ascending restores the original left-to-right order.
top_idx = torch.topk(a, 2, dim=-1).indices
b, _ = torch.sort(top_idx, dim=-1)
print(torch.gather(a, -1, b))

Oh, and also: every implementation I found was not order-preserving (which is semantically wrong).

Melike
  • 468
  • 1
  • 7
  • 15
Separius
  • 1,226
  • 9
  • 24

2 Answers

2

Weird... this should be super easy, and yet there doesn't seem to be a ready-made solution.

Try this:

import tensorflow as tf
from tensorflow.contrib.framework import sort

# Order-preserving k-max pooling along the last axis (TensorFlow 1.x graph mode).
k = 2
a = tf.convert_to_tensor([[[5, 1, 10, 2], [3, 11, 2, 6]]])

# Per-row indices of the k largest values, re-sorted ascending so the picked
# values keep their original left-to-right order.
b = tf.nn.top_k(a, k=k, sorted=True)[1]
b = sort(b)

flatA = tf.reshape(a, (-1,))
shapeA = tf.shape(a)
lenA = tf.shape(flatA)[0]
# Output shape: same as `a`, except the last axis has length k.
kShape = tf.concat([shapeA[:-1], tf.constant([k])], axis=-1)

# Flat offset at which each row (along the last axis) starts: 0, n, 2n, ...,
# shaped like `a` without its last axis. Adding it turns per-row indices into
# indices into `flatA`, so a single 1-D gather picks from the right row.
rowStarts = tf.reshape(tf.range(0, lenA, shapeA[-1]), shapeA[:-1])
b += tf.expand_dims(rowStarts, axis=-1)
b = tf.reshape(b, (-1,))

gat = tf.gather(flatA, b)
gat = tf.reshape(gat, kShape)

# The context manager closes the session once the result has been printed.
with tf.Session() as sess:
    print(gat.eval(session=sess))
Daniel Möller
  • 84,878
  • 18
  • 192
  • 214
0

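Try this custom layer: I just adapted the code above into a custom TensorFlow Keras layer. It assumes `tf` and Keras `layers` (e.g. `from tensorflow.keras import layers`) are already imported.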

class KMaxPooling(layers.Layer):
    """Order-preserving k-max pooling over one axis of a rank-3 input.

    Keeps the ``k`` largest activations along ``axis`` (1 or 2) of an input
    with dimensions ``(samples, filters, convolved_values)`` and returns them
    in the order they appear in the input, not sorted by value.
    TensorFlow backend.

    Args:
        k: Number of activations to keep along ``axis``.
        axis: Axis to pool over; must be 1 or 2 (the samples axis 0 cannot
            be pooled).

    Raises:
        ValueError: If ``axis`` is not 1 or 2.
    """

    def __init__(self, k=1, axis=1, **kwargs):
        super(KMaxPooling, self).__init__(**kwargs)
        self.input_spec = layers.InputSpec(ndim=3)
        self.k = k

        # Raise instead of ``assert`` so the check also runs under ``python -O``.
        if axis not in (1, 2):
            raise ValueError(
                "expected dimensions (samples, filters, convolved_values), "
                "cannot fold along samples dimension or axis not in list [1,2]"
            )
        self.axis = axis

        # top_k only works on the last axis, so swap ``axis`` with axis 2.
        # A single swap is its own inverse, so the same perm undoes it later.
        self.transpose_perm = [0, 1, 2]
        self.transpose_perm[self.axis] = 2
        self.transpose_perm[2] = self.axis

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with the pooled axis replaced by ``k``."""
        output_shape = list(input_shape)
        output_shape[self.axis] = self.k
        return tuple(output_shape)

    def get_config(self):
        """Return the layer config including ``k`` and ``axis``.

        Without this, ``k`` and ``axis`` would not be restored when a model
        containing this layer is serialized and reloaded.
        """
        config = super(KMaxPooling, self).get_config()
        config.update({"k": self.k, "axis": self.axis})
        return config

    def call(self, x):
        """Return the ``k`` largest values along ``self.axis`` in input order.

        Args:
            x: Rank-3 tensor (enforced by ``input_spec``).

        Returns:
            A tensor shaped like ``x`` except that ``self.axis`` has length ``k``.
        """
        # Move the pooled axis last, since top_k operates on the last axis.
        transposed = tf.transpose(x, perm=self.transpose_perm)

        # Indices of the k largest entries per row, re-sorted ascending so the
        # selected values keep their original order (as in the paper).
        _, top_k_indices = tf.math.top_k(transposed, k=self.k, sorted=True)
        sorted_top_k_ind = tf.sort(top_k_indices)

        # Turn per-row indices into indices into the flattened tensor by adding
        # the flat offset at which each row starts (0, n, 2n, ...).
        flat = tf.reshape(transposed, (-1,))
        shape = tf.shape(transposed)
        row_starts = tf.reshape(
            tf.range(0, tf.shape(flat)[0], shape[-1]), shape[:-1]
        )
        flat_indices = sorted_top_k_ind + tf.expand_dims(row_starts, axis=-1)

        # Gathering from a 1-D tensor returns a result shaped like the indices,
        # i.e. the transposed input with k entries on its last axis.
        k_max_out = tf.gather(flat, flat_indices)

        # Re-apply the (self-inverse) permutation to restore the original axis
        # order; self.axis now has size k.
        return tf.transpose(k_max_out, perm=self.transpose_perm)
Vineet Suryan
  • 101
  • 1
  • 9