I am trying to fit a Keras model whose output variable is always positive, and I want to use a Gamma distribution to model it. The problem is that the loss is always NaN.
I built the following Keras model:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

model_max = tf.keras.Sequential([
    tf.keras.layers.Dense(20, input_dim=10, activation="relu"),
    tf.keras.layers.Dense(15, activation="relu"),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(5, activation="relu"),
    tf.keras.layers.Dense(2),
    # Map the two raw outputs to strictly positive Gamma parameters.
    tfp.layers.DistributionLambda(lambda t:
        tfd.Gamma(concentration=tf.math.softplus(0.005 * t[..., :1]) + 0.001,
                  rate=tf.math.softplus(0.005 * t[..., 1:]) + 0.001)
    ),
])
```
Note that I used softplus because both parameters of the distribution must be positive, and I added 0.001 to make sure they are always strictly greater than zero.
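To illustrate the positivity argument, here is a small NumPy sketch of the same parameter mapping. It assumes that `tf.math.softplus` computes log(1 + exp(x)), which NumPy can evaluate stably as `np.logaddexp(0, x)`; the helper name `positive_param` is hypothetical and not part of TensorFlow.

```python
import numpy as np

def positive_param(t, scale=0.005, eps=0.001):
    """Hypothetical NumPy mirror of tf.math.softplus(scale * t) + eps."""
    # softplus(x) = log(1 + exp(x)), computed stably as logaddexp(0, x)
    return np.logaddexp(0.0, scale * t) + eps

# Raw outputs of the last Dense layer, from very negative to very positive
raw = np.array([-1e5, -100.0, 0.0, 100.0, 1e5])
params = positive_param(raw)
print(params)
# Every entry is at least eps, so the Gamma parameters stay strictly positive.
print(np.all(params >= 0.001))
```

With the 0.005 factor, raw outputs near 0 map to roughly log(2) + 0.001 ≈ 0.694.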
My loss function is as follows:
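```python
def gamma_loss(y_true, my_dist):
    # Recover the Gamma parameters from the predicted mean and standard deviation
    # (for Gamma(alpha, beta): mean = alpha / beta, stddev = sqrt(alpha) / beta).
    dist_mean = my_dist.mean()
    dist_stddev = my_dist.stddev()
    alpha = (dist_mean / dist_stddev)**2
    beta = dist_mean / dist_stddev**2
    gamma_distr = tfd.Gamma(concentration=alpha, rate=beta)
    # Negative log-likelihood averaged over the batch
    return -tf.reduce_mean(gamma_distr.log_prob(y_true))
```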
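For a Gamma distribution with concentration α and rate β, the mean is α/β and the standard deviation is √α/β, so the two lines computing `alpha` and `beta` should recover exactly the parameters of `my_dist`. A pure-Python check of that identity (the function names `gamma_moments` and `moments_to_params` are hypothetical):

```python
import math

def gamma_moments(alpha, beta):
    """Mean and standard deviation of Gamma(concentration=alpha, rate=beta)."""
    return alpha / beta, math.sqrt(alpha) / beta

def moments_to_params(mean, stddev):
    """Same formulas as in gamma_loss: alpha = (mean/std)^2, beta = mean/std^2."""
    return (mean / stddev) ** 2, mean / stddev ** 2

# Round trip: parameters -> moments -> parameters
for alpha, beta in [(0.694, 0.694), (2.0, 0.5), (10.0, 3.0)]:
    mean, stddev = gamma_moments(alpha, beta)
    a, b = moments_to_params(mean, stddev)
    print(alpha, beta, "->", a, b)
```

Up to floating-point rounding the round trip returns the original parameters, so when `my_dist` is a Gamma distribution this loss is mathematically equivalent to `-tf.reduce_mean(my_dist.log_prob(y_true))`.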
This function seems to work. For example, the following code runs without errors:
```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def gamma_loss(y_true, my_dist):
    dist_mean = my_dist.mean()
    dist_stddev = my_dist.stddev()
    alpha = (dist_mean / dist_stddev)**2
    beta = dist_mean / dist_stddev**2
    #print(alpha)
    gamma_distr = tfd.Gamma(concentration=alpha, rate=beta)
    return -tf.reduce_mean(gamma_distr.log_prob(y_true)).numpy()

dist = tfd.Gamma(1, 1)
gamma_loss(100, dist)
```
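As a sanity check on the value returned above, the Gamma log-density can be written out by hand: log p(y) = (α − 1) log y − βy + α log β − log Γ(α), valid for y > 0. The pure-Python sketch below (the function name `gamma_log_prob` is hypothetical) evaluates it for `Gamma(1, 1)` at `y = 100`; it should agree with `gamma_loss(100, dist)` up to floating-point rounding.

```python
import math

def gamma_log_prob(y, alpha, beta):
    """Log-density of Gamma(concentration=alpha, rate=beta) at y; only defined for y > 0."""
    if y <= 0:
        # log(y) is undefined here, so the formula does not give a finite number.
        raise ValueError("Gamma log-density requires y > 0")
    return (alpha - 1) * math.log(y) - beta * y + alpha * math.log(beta) - math.lgamma(alpha)

# Negative log-likelihood for a single target y = 100 under Gamma(1, 1)
loss = -gamma_log_prob(100.0, alpha=1.0, beta=1.0)
print(loss)  # 100.0
```

Because the formula contains log y, it only applies to strictly positive targets. The Normal log-density, by contrast, is finite for any real y.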
However, if I compile the model with the following line:

```python
model_max.compile(optimizer=tf.optimizers.Adam(learning_rate=0.001), loss=gamma_loss)
```

the loss is always NaN during training.
What am I doing wrong? I have tried different forms of the loss function, but nothing seems to work. I think the problem is related to the concentration argument, because I already have a similar model working with a Normal distribution. In that model, I did not apply softplus to the mean (loc), because that distribution accepts any positive or negative value. For the standard deviation (scale), which must also be positive in the Normal distribution, I used exactly the same structure as above, and that model works fine. Why doesn't it work for the Gamma distribution?
Thanks in advance to anyone who can help me understand what I'm doing wrong.