I'm implementing an MLP with Keras
and a custom loss function.
I notice model.compile()
takes very much time: it seems doesn't end.
The loss that I passed to the compile()
function is custom.
I'm also using another function that is used in the loss function.
This is my custom loss:
def get_top_one_probability(vector):
return (K.exp(vector) / K.sum(K.exp(vector)))
def custom_loss(groups_id_count, tf_session):
def listnet_loss(real_labels, predicted_labels):
losses = tf.Variable([[0.0]], tf.float32)
for group in groups_id_count:
start_range = 0
end_range = (start_range + group[1])
batch_real_labels = real_labels[start_range:end_range]
batch_predicted_labels = predicted_labels[start_range:end_range]
loss = -K.sum(get_top_one_probability(batch_real_labels)) * tf.math.log(get_top_one_probability(batch_predicted_labels))
losses = tf.concat([losses, loss], axis=0)
start_range = end_range
return K.mean(losses)
return listnet_loss
And this is the MLP code:
mlp = keras.models.Sequential()
# add input layer
mlp.add(
keras.layers.Dense(
units=training_dataset.shape[1],
input_shape = (training_dataset.shape[1], ),
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
activation='tanh')
)
# add hidden layer
mlp.add(
keras.layers.Dense(
units=training_dataset.shape[1] + 10,
input_shape = (training_dataset.shape[1] + 10,),
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
activation='relu')
)
# add output layer
mlp.add(
keras.layers.Dense(
units=1,
input_shape = (1, ),
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
activation='softmax')
)
# define SGD optimizer
sgd_optimizer = keras.optimizers.SGD(
lr=0.01, decay=0.01, momentum=0.9, nesterov=True
)
# compile model
print('Compiling model...\n')
mlp.compile(
optimizer=sgd_optimizer,
loss=custom_loss(groups_id_count, tf.compat.v1.Session())
)
mlp.summary() # print model settings
# Training
with tf.device('/GPU:0'):
print('Start training')
mlp.fit(training_dataset, training_dataset_labels, epochs=50, verbose=2, batch_size=training_dataset.shape[0], workers=10)
Why the compile()
function takes very very much time? Thanks in advance