I have a Layer that computes the mean over timesteps and supports masking. My problem is that the mask may be empty (no padded timesteps), and I don't know how to check for zeros when working with tensors.

For the few training examples where the mask is empty, I get a NaN loss and the program crashes.

This is my Layer:

from keras import backend as K
from keras.engine.topology import Layer


class MeanOverTime(Layer):
    def __init__(self, **kwargs):
        self.supports_masking = True
        super(MeanOverTime, self).__init__(**kwargs)

    def call(self, x, mask=None):
        if mask is not None:
            return K.cast(x.sum(axis=1) / mask.sum(axis=1, keepdims=True), K.floatx())  # this may result in division by zero
        else:
            return K.mean(x, axis=1)

    def get_output_shape_for(self, input_shape):
        return input_shape[0], input_shape[-1]

    def compute_mask(self, input, input_mask=None):
        return None

The problem is that mask.sum(axis=1, keepdims=True) becomes zero. To work around this I increased input_length so that it covers all my training examples, but that is not a real solution. I also tried adding a try/except, but that didn't work either.

Christos Baziotis

1 Answer

try/except won't work because this piece of code only builds the symbolic tensor graph, and building the graph raises no exception. The evaluation, and hence the division by zero, happens later, inside the fit/evaluate/predict functions. You need to put the decision logic into the symbolic graph itself.
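As a toy illustration of that point, here is a minimal plain-Python sketch of deferred evaluation. The `Placeholder` and `Div` classes and the `build_graph` and `run` helpers are hypothetical stand-ins for symbolic tensors, not part of Keras or any backend. It assumes Python 3 and uses plain floats, which raise on division by zero; real tensor backends typically produce inf or NaN instead of raising.

```python
# Toy stand-ins for symbolic tensors (hypothetical, not a Keras API).
class Placeholder:
    """A named input whose value is only known at evaluation time."""
    def __init__(self, name):
        self.name = name

    def __truediv__(self, other):
        # Dividing two placeholders only records the operation.
        return Div(self, other)

    def eval(self, feed):
        return feed[self.name]


class Div:
    """A graph node that performs the division when eval() is called."""
    def __init__(self, num, den):
        self.num = num
        self.den = den

    def eval(self, feed):
        return self.num.eval(feed) / self.den.eval(feed)


def build_graph():
    x_sum = Placeholder("x_sum")
    mask_sum = Placeholder("mask_sum")
    try:
        # No arithmetic happens here, so there is nothing to catch.
        return x_sum / mask_sum
    except ZeroDivisionError:
        return None


def run(graph, feed):
    try:
        return graph.eval(feed)
    except ZeroDivisionError:
        # Plain Python floats raise here; this is where the failure surfaces.
        return "failed at evaluation time"


graph = build_graph()
print(run(graph, {"x_sum": 6.0, "mask_sum": 3.0}))
print(run(graph, {"x_sum": 6.0, "mask_sum": 0.0}))
```

The try/except inside `build_graph` never fires, because the division is only recorded there. The failure shows up in `run`, which plays the role of fit/evaluate/predict.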

You can use switch(condition, then_expression, else_expression) to express the if/else inside the graph:

def call(self, x, mask=None):
    if mask is not None:
        mask_sum = mask.sum(axis=1, keepdims=True)
        # True wherever the mask sums to zero
        cond = K.equal(mask_sum, 0)
        _the_other_tensor_ = ....
        # use the replacement divisor where the sum is zero, the sum itself elsewhere
        div = K.switch(cond, _the_other_tensor_, mask_sum)
        return K.cast(x.sum(axis=1) / div, K.floatx())
    else:
        return K.mean(x, axis=1)

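Or just use clip(x, min_value, max_value) to bound the divisor from below by a very small number epsilon, which keeps the division numerically stable. The upper bound must be at least as large as the largest possible mask sum, otherwise the divisor is truncated and the result is no longer a mean.

def call(self, x, mask=None):
    if mask is not None:
        mask_sum = K.cast(mask.sum(axis=1, keepdims=True), K.floatx())
        # bound the divisor from below only; float("inf") leaves large sums untouched
        div = K.clip(mask_sum, K.epsilon(), float("inf"))
        return K.cast(x.sum(axis=1) / div, K.floatx())
    else:
        return K.mean(x, axis=1)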
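To see what a lower bound on the divisor does numerically, here is a plain-Python sketch of a masked mean over time on nested lists. The names `masked_mean_over_time` and `EPSILON` are hypothetical, and the value `1e-7` is only an assumed stand-in for a backend epsilon. The sketch also assumes that masked timesteps should be excluded from the sum, which the layer above does not do explicitly.

```python
EPSILON = 1e-7  # assumed small constant, standing in for a backend epsilon


def masked_mean_over_time(x, mask):
    """Mean over the time axis, counting only timesteps where mask == 1.

    x:    nested list of shape [batch][time][features]
    mask: nested list of shape [batch][time] with 0/1 entries
    """
    out = []
    for sample, sample_mask in zip(x, mask):
        n_features = len(sample[0])
        totals = [0.0] * n_features
        for step, keep in zip(sample, sample_mask):
            for j, value in enumerate(step):
                # Masked timesteps (keep == 0) contribute nothing to the sum.
                totals[j] += value * keep
        # Bound the divisor from below so an all-zero mask cannot divide by zero.
        count = max(float(sum(sample_mask)), EPSILON)
        out.append([t / count for t in totals])
    return out


x = [
    [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]],  # last timestep is masked out
    [[5.0, 5.0], [5.0, 5.0], [5.0, 5.0]],  # every timestep is masked out
]
mask = [
    [1, 1, 0],
    [0, 0, 0],
]
print(masked_mean_over_time(x, mask))
```

For the fully masked sample the numerator is zero as well, so the result is a vector of zeros instead of NaN. Whether zeros are the right output for such a sample depends on the model.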
indraforyou