You could write a custom loss function: temporarily replace the missing labels with zeros, compute the element-wise cross-entropy, and then zero out the loss at the positions where the label was missing.
import numpy as np
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x only; eager execution is on by default in TF 2.

def missing_values_cross_entropy_loss(y_true, y_pred):
    # Add a small epsilon to avoid computing log(0) (consider y_hat == 0.0 or y_hat == 1.0).
    epsilon = tf.constant(1.0e-30, dtype=tf.float32)
    # Assert that the predictions contain no NaN values (the network itself shouldn't output NaNs).
    y_pred = tf.debugging.assert_all_finite(y_pred, 'y_pred contains NaN')
    # Temporarily replace missing labels with zeros, keeping a mask of the valid positions for later.
    y_true_not_nan_mask = tf.logical_not(tf.math.is_nan(y_true))
    y_true_nan_replaced = tf.where(tf.math.is_nan(y_true), tf.zeros_like(y_true), y_true)
    # Binary cross-entropy, split into several lines for readability:
    # y * log(y_hat)
    positive_predictions_cross_entropy = y_true_nan_replaced * tf.math.log(y_pred + epsilon)
    # (1 - y) * log(1 - y_hat)
    negative_predictions_cross_entropy = (1.0 - y_true_nan_replaced) * tf.math.log(1.0 - y_pred + epsilon)
    # c(y, y_hat) = -(y * log(y_hat) + (1 - y) * log(1 - y_hat))
    cross_entropy_loss = -(positive_predictions_cross_entropy + negative_predictions_cross_entropy)
    # Use the mask to zero out the loss wherever the label was missing.
    # (Casting the boolean mask to float yields 1.0 for present labels and 0.0 for missing ones.)
    cross_entropy_loss_discarded_nan_labels = cross_entropy_loss * tf.cast(y_true_not_nan_mask, tf.float32)
    mean_loss_per_row = tf.reduce_mean(cross_entropy_loss_discarded_nan_labels, axis=1)
    mean_loss = tf.reduce_mean(mean_loss_per_row)
    return mean_loss
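For example, with a batch of four samples, four labels each, and NaN marking the missing labels: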
y_true = tf.constant([
    [0, 1, np.nan, 0],
    [0, 1, 1, 0],
    [np.nan, 1, np.nan, 0],
    [1, 1, 0, np.nan],
])
y_pred = tf.constant([
    [0.1, 0.7, 0.1, 0.3],
    [0.2, 0.6, 0.1, 0.0],
    [0.1, 0.9, 0.3, 0.2],
    [0.1, 0.4, 0.4, 0.2],
])
loss = missing_values_cross_entropy_loss(y_true, y_pred)
# Extract value from EagerTensor.
print(loss.numpy())
which outputs:
0.4945919
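If you want to double-check the number, the same masked computation in plain NumPy reproduces it (a minimal sketch reusing the eager tensors defined above). Note one design consequence of this reduction: masked positions contribute zero loss but still count in the denominator of reduce_mean, so rows with more missing labels yield a smaller per-row mean.

# Optional sanity check in plain NumPy.
mask = ~np.isnan(y_true.numpy())
y = np.where(mask, y_true.numpy(), 0.0)
p = y_pred.numpy()
eps = 1.0e-30
elementwise = -(y * np.log(p + eps) + (1.0 - y) * np.log(1.0 - p + eps))
# Zero out the masked positions, then average per row and across rows.
masked = elementwise * mask
print(masked.mean(axis=1).mean())  # ~0.4945919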
Pass the loss function when compiling the Keras model, as described in the Keras documentation:
model.compile(loss=missing_values_cross_entropy_loss, optimizer='sgd')
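For instance, here is a minimal end-to-end sketch; the architecture and the random training data are placeholder assumptions, and only the loss= argument is the point:

# Hypothetical multi-label model: 8 input features, 4 independent labels,
# with a sigmoid output so each label gets its own probability.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(4, activation='sigmoid'),
])
model.compile(loss=missing_values_cross_entropy_loss, optimizer='sgd')

# Random training data with roughly 20% of the labels knocked out (set to NaN).
x_train = np.random.rand(64, 8).astype(np.float32)
y_train = np.random.randint(0, 2, size=(64, 4)).astype(np.float32)
y_train[np.random.rand(64, 4) < 0.2] = np.nan
model.fit(x_train, y_train, epochs=1)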