Here is an implementation based on the Lua version. (There is also a PyTorch implementation, but I think it has an error: it averages the max and min terms, i.e. halves their sum.) I'm assuming the Lua version is correct: the average of the top-kmax values plus the average of the bottom-kmin values. I haven't tested all of the custom-layer plumbing, but it should be close enough to get something going. Comments welcome.
Tony
class WeldonPooling(Layer):
    """Weldon selective spatial pooling with negative evidence.

    Global pooling layer for 4D feature maps. For every channel, the
    ``h * w`` spatial activations are ranked and the channel's output is::

        mean(top kmax activations) + mean(bottom kmin activations)

    The second (negative-evidence) term is only added when ``kmin > 0``;
    the default ``kmin=-1`` disables it. The output shape is
    ``(batch, channels)`` for either data format.

    Args:
        kmax: Number of highest-scoring spatial positions averaged per
            channel. Must be >= 1. ``tf.nn.top_k`` fails at run time if it
            exceeds ``h * w``.
        kmin: Number of lowest-scoring spatial positions averaged per
            channel and added to the result. Values <= 0 disable the term.
            When enabled it must not exceed ``h * w``.
        data_format: ``'channels_last'`` or ``'channels_first'``. It is
            normalized through ``conv_utils.normalize_data_format``
            (presumably ``None`` maps to the Keras ``image_data_format``
            setting -- confirm for the Keras version in use).
        **kwargs: Forwarded to ``Layer.__init__``.

    Raises:
        ValueError: If ``kmax`` is less than 1.
    """

    def __init__(self, kmax, kmin=-1, data_format=None, **kwargs):
        super(WeldonPooling, self).__init__(**kwargs)
        if kmax < 1:
            # kmax == 0 would average an empty selection (0 / 0 -> NaN).
            raise ValueError('kmax must be >= 1, got %r' % (kmax,))
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.input_spec = InputSpec(ndim=4)
        self.kmax = kmax
        self.kmin = kmin

    def compute_output_shape(self, input_shape):
        """Return ``(batch, channels)`` for the configured data format."""
        if self.data_format == 'channels_last':
            return (input_shape[0], input_shape[3])
        return (input_shape[0], input_shape[1])

    def get_config(self):
        """Return the constructor arguments so ``from_config`` can rebuild the layer."""
        config = {
            'kmax': self.kmax,
            'kmin': self.kmin,
            'data_format': self.data_format,
        }
        # Must name this class: WeldonPooling does not derive from any
        # other pooling class, so super() of another class would fail.
        base_config = super(WeldonPooling, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs):
        """Pool each channel to mean(top kmax) [+ mean(bottom kmin)].

        Args:
            inputs: 4D tensor laid out according to ``self.data_format``.

        Returns:
            Tensor of shape ``(batch, channels)``.
        """
        if self.data_format == 'channels_last':
            # Move channels to axis 1 so each channel's map can be flattened.
            inputs = tf.transpose(inputs, [0, 3, 1, 2])

        batch_size = tf.shape(inputs)[0]
        # Prefer the static channel count so the output keeps a known last
        # dimension (layers such as Dense need it); fall back to dynamic.
        num_channels = inputs.get_shape().as_list()[1]
        if num_channels is None:
            num_channels = tf.shape(inputs)[1]

        # (batch, channels, h * w): one row of region scores per channel.
        regions = tf.reshape(inputs, [batch_size, num_channels, -1])

        # Only the kmax largest scores are needed, not a full sort.
        top_scores, _ = tf.nn.top_k(regions, k=self.kmax, sorted=False)
        output = tf.reduce_mean(top_scores, axis=2)

        if self.kmin > 0:
            # top_k of the negated scores selects the kmin smallest scores;
            # subtracting their mean adds the mean of the originals.
            neg_bottom_scores, _ = tf.nn.top_k(-regions, k=self.kmin, sorted=False)
            output = output - tf.reduce_mean(neg_bottom_scores, axis=2)

        return output