
I am trying to implement a discriminative loss function for instance segmentation of images based on this paper: https://arxiv.org/pdf/1708.02551.pdf (This link is just for the readers' reference; I don't expect anyone to read it to help me out!)
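For reference, here is a summary of the loss as defined in the paper. C is the number of clusters (instances), N_c is the number of pixels in cluster c, x_i is a pixel embedding, μ_c is the mean embedding of cluster c, and [x]_+ = max(0, x). The code below departs from this definition: it also divides all distances by L_reg.

```latex
\begin{aligned}
L_{var}  &= \frac{1}{C}\sum_{c=1}^{C}\frac{1}{N_c}\sum_{i=1}^{N_c}\big[\lVert\mu_c - x_i\rVert - \delta_v\big]_+^2\\
L_{dist} &= \frac{1}{C(C-1)}\sum_{c_A=1}^{C}\sum_{\substack{c_B=1\\ c_B\neq c_A}}^{C}\big[2\delta_d - \lVert\mu_{c_A} - \mu_{c_B}\rVert\big]_+^2\\
L_{reg}  &= \frac{1}{C}\sum_{c=1}^{C}\lVert\mu_c\rVert\\
L        &= \alpha\,L_{var} + \beta\,L_{dist} + \gamma\,L_{reg}
\end{aligned}
```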

My problem: once I move from a simple loss function to a more complicated one (like the one in the code snippet below), the loss drops to zero after the first epoch. I checked the weights, and almost all of them are very close to -300. They are not exactly identical, but they differ from each other only in the decimal places.

Relevant code that implements the discriminative loss function:

import tensorflow as tf

# alpha, beta, gamma, delta_v, delta_d, image_height, image_width and nDim are
# module-level hyperparameters; they are only read here, so `global` is not needed.
def regDLF(y_true, y_pred):

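    # Flatten labels to [H*W] and embeddings to [H*W, nDim];
    # uniqueInd maps every pixel to the index of its cluster (instance).
    y_true = tf.reshape(y_true, [image_height*image_width])

    X = tf.reshape(y_pred, [image_height*image_width, nDim])
    uniqueLabels, uniqueInd = tf.unique(y_true)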

    numUnique = tf.size(uniqueLabels)

    # Cluster means: mu_c = (sum of embeddings in c) / (number of pixels in c)
    Sigma = tf.unsorted_segment_sum(X, uniqueInd, numUnique)
    ones_Sigma = tf.ones((tf.shape(X)[0], 1))
    ones_Sigma = tf.unsorted_segment_sum(ones_Sigma,uniqueInd, numUnique)
    mu = tf.divide(Sigma, ones_Sigma)

    # Lreg: mean norm of the cluster means (also used to normalize the distances below)
    Lreg = tf.reduce_mean(tf.norm(mu, axis = 1))

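    # Lvar: hinged, squared distance of each pixel to its cluster mean,
    # averaged within each cluster and then over clusters
    T = tf.norm(tf.subtract(tf.gather(mu, uniqueInd), X), axis = 1)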
    T = tf.divide(T, Lreg)
    T = tf.subtract(T, delta_v)
    T = tf.clip_by_value(T, 0, T)
    T = tf.square(T)

    ones_Sigma = tf.ones_like(uniqueInd, dtype = tf.float32)
    ones_Sigma = tf.unsorted_segment_sum(ones_Sigma,uniqueInd, numUnique)
    clusterSigma = tf.unsorted_segment_sum(T, uniqueInd, numUnique)
    clusterSigma = tf.divide(clusterSigma, ones_Sigma)

    Lvar = tf.reduce_mean(clusterSigma, axis = 0)

    # Ldist: differences between all C*C ordered pairs of cluster means
    mu_interleaved_rep = tf.tile(mu, [numUnique, 1])
    mu_band_rep = tf.tile(mu, [1, numUnique])
    mu_band_rep = tf.reshape(mu_band_rep, (numUnique*numUnique, nDim))

    mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep)
    mu_diff = tf.norm(mu_diff, axis = 1)
    mu_diff = tf.divide(mu_diff, Lreg)

    mu_diff = tf.subtract(2*delta_d, mu_diff)
    mu_diff = tf.clip_by_value(mu_diff, 0, mu_diff)
    mu_diff = tf.square(mu_diff)

    numUniqueF = tf.cast(numUnique, tf.float32)
    Ldist = tf.reduce_mean(mu_diff)        

    L = alpha * Lvar + beta * Ldist + gamma * Lreg

    return L

Question: I know it's hard to understand what the code does without reading the paper, but I have a couple of questions:

  1. Is there something glaringly wrong with the loss function defined above?

  2. Does anyone have a general idea why the loss could zero out after the first epoch?

Thank you very much for your time and help!

Shai
It appears that your loss is composed of three terms. Why not change the weights of the three terms and see which one is problematic? – Shai Sep 24 '17 at 06:21
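Following this suggestion, the sketch below is a minimal NumPy reference that mirrors the forward pass of regDLF (values only, no gradients) and returns the three terms separately, so each one can be logged or reweighted independently. The function name `dlf_terms` is hypothetical. It assumes `labels` is an integer array of shape [N] and `emb` is a float array of shape [N, D].

```python
import numpy as np

def dlf_terms(labels, emb, delta_v, delta_d):
    """NumPy forward pass of regDLF, returning (Lvar, Ldist, Lreg) separately."""
    # Map each pixel to a cluster index, like tf.unique (np.unique sorts labels,
    # so the cluster order may differ; none of the three terms depend on it).
    _, idx = np.unique(labels, return_inverse=True)
    C = idx.max() + 1
    counts = np.bincount(idx, minlength=C).astype(float)

    # Cluster means mu_c = sum of embeddings in c / number of pixels in c
    mu = np.zeros((C, emb.shape[1]))
    np.add.at(mu, idx, emb)
    mu /= counts[:, None]

    # Lreg: mean norm of the cluster means (regDLF also divides distances by it)
    Lreg = np.linalg.norm(mu, axis=1).mean()

    # Lvar: hinged, squared distance of each pixel to its cluster mean,
    # averaged per cluster and then over clusters
    d = np.linalg.norm(mu[idx] - emb, axis=1) / Lreg
    t = np.maximum(d - delta_v, 0.0) ** 2
    Lvar = (np.bincount(idx, weights=t, minlength=C) / counts).mean()

    # Ldist: all C*C ordered pairs, including the C diagonal (zero) pairs,
    # as the tile/reshape construction in regDLF produces them
    diff = mu[:, None, :] - mu[None, :, :]
    dm = np.linalg.norm(diff, axis=2) / Lreg
    Ldist = (np.maximum(2 * delta_d - dm, 0.0) ** 2).mean()

    return Lvar, Ldist, Lreg
```

For example, with two clusters whose embeddings sit exactly on their means at [1, 0] and [0, 1], and with delta_v = 0.5 and delta_d = 1.5, Lvar is 0 and Lreg is 1. The diagonal pairs alone contribute (2*delta_d)**2 / C = 4.5 to Ldist, whatever the embeddings are.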

2 Answers


I think your problem comes from tf.norm, which is not numerically safe: for an all-zero input vector its gradient, x / ||x||, evaluates to 0/0, which produces NaN gradients. It would be better to replace tf.norm with this custom function, which adds a small epsilon inside the square root:

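import tensorflow as tf

def tf_norm(inputs, axis=1, epsilon=1e-7, name='safe_norm'):
    # sqrt(sum(x^2) + epsilon) has a finite gradient at x = 0, unlike tf.norm.
    # keepdims=False so the output shape matches tf.norm(inputs, axis=axis).
    squared_norm = tf.reduce_sum(tf.square(inputs), axis=axis, keepdims=False)
    safe_norm = tf.sqrt(squared_norm + epsilon)
    return tf.identity(safe_norm, name=name)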
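The NumPy sketch below illustrates why the epsilon helps. Its gradients are written out analytically by hand, not computed by TensorFlow's autodiff, and the helper names are hypothetical. The gradient of ||x|| is x / ||x||, which is 0/0 = NaN at x = 0. The gradient of sqrt(||x||² + eps) is x / sqrt(||x||² + eps), which is simply 0 there.

```python
import numpy as np

def grad_norm(x):
    # d/dx ||x|| = x / ||x||  (0/0 = NaN when x is the zero vector)
    with np.errstate(invalid="ignore", divide="ignore"):
        return x / np.sqrt(np.sum(x ** 2))

def grad_safe_norm(x, epsilon=1e-7):
    # d/dx sqrt(||x||^2 + eps) = x / sqrt(||x||^2 + eps)  (finite everywhere)
    return x / np.sqrt(np.sum(x ** 2) + epsilon)

zero = np.zeros(3)
print(grad_norm(zero))       # [nan nan nan]
print(grad_safe_norm(zero))  # [0. 0. 0.]
```

Away from zero the two gradients agree to within a negligible amount. For [3, 4], for example, both are approximately [0.6, 0.8].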

In your Ldist calculation, you use tf.tile and tf.reshape to compute the difference between every pair of cluster means, in the following order (assuming three clusters):

mu_1 - mu_1
mu_1 - mu_2
mu_1 - mu_3
mu_2 - mu_1
mu_2 - mu_2
mu_2 - mu_3
mu_3 - mu_1
mu_3 - mu_2
mu_3 - mu_3
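The pair ordering can be checked with a NumPy analogue of the tile/reshape construction. np.tile and reshape follow the same row-major semantics as tf.tile and tf.reshape. The diagonal pairs (each mean minus itself) land at indices 0, 4 and 8:

```python
import numpy as np

C, nDim = 3, 2
mu = np.arange(C * nDim, dtype=float).reshape(C, nDim)  # rows: mu_1, mu_2, mu_3

# Same construction as in regDLF, using NumPy instead of TensorFlow
mu_interleaved_rep = np.tile(mu, (C, 1))                # mu_1, mu_2, mu_3, mu_1, ...
mu_band_rep = np.tile(mu, (1, C)).reshape(C * C, nDim)  # mu_1, mu_1, mu_1, mu_2, ...
mu_diff = np.linalg.norm(mu_band_rep - mu_interleaved_rep, axis=1)

print(np.flatnonzero(mu_diff == 0))  # [0 4 8]
```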

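The problem is that this list contains zero vectors (mu_c - mu_c for every cluster c), and you take their norm afterwards. tf.norm is numerically unstable at zero because its gradient, x / ||x||, divides by the length of the vector. For a zero vector, the gradient is therefore not finite (NaN or inf). Note also that the paper's Ldist only sums over pairs of different clusters (c_A ≠ c_B), so these terms do not belong in the loss anyway. See this github issue.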

The solution is to remove those zero vectors before taking the norm, for example as in this Stack Overflow question.
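The sketch below computes Ldist over the C(C-1) ordered pairs with c_A ≠ c_B only, as in the paper, by masking out the diagonal before the norm. It uses NumPy rather than TensorFlow, and `ldist_offdiag` is a hypothetical name. In TensorFlow, the same kind of mask could be applied with tf.boolean_mask before calling the norm. The mask only removes the mu_c - mu_c pairs. If two different cluster means happen to coincide, their difference is still a zero vector, which is the case the safe norm from the other answer also covers.

```python
import numpy as np

def ldist_offdiag(mu, delta_d):
    """Ldist over ordered pairs c_A != c_B only (no Lreg normalization)."""
    C = mu.shape[0]
    if C < 2:
        return 0.0  # a single cluster has no pairs to push apart
    diff = mu[:, None, :] - mu[None, :, :]          # [C, C, D] pairwise differences
    off_diag = ~np.eye(C, dtype=bool)               # drop the mu_c - mu_c pairs
    dist = np.linalg.norm(diff[off_diag], axis=1)   # [C*(C-1)] distances
    return np.mean(np.maximum(2 * delta_d - dist, 0.0) ** 2)
```

With mu = [[1, 0], [0, 1]] and delta_d = 1.5, both pairs are sqrt(2) apart, so the result is (3 - sqrt(2))**2 ≈ 2.515.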

KonArtist