I am training a linear TensorFlow model with an additional custom loss term that depends on derivatives of the outputs. The network receives an input (x, y) and produces the output (u, v), with some hidden layers in between. I use automatic differentiation to add u_x (the derivative of u with respect to x) and v_y (analogously) to the loss function.
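For reference, model, mse and optimizer used throughout are set up roughly like this (a minimal sketch; the layer sizes, activations and optimizer settings are placeholders, not my exact configuration):

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="tanh", input_shape=(2,)),  # input (x, y)
    tf.keras.layers.Dense(32, activation="tanh"),
    tf.keras.layers.Dense(2),                                        # output (u, v)
])
mse = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()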
I made this function:
def continuous_equation_constraint(features):
    """Calculate the constraint based on the current prediction of the neural network.
    Ideally one should always have u_x + v_y = 0.
    Args:
        features (tf.Tensor): Batch of (x, y) position data, shape (N, 2)
    Returns:
        tf.Tensor: The value of the constraint equation for the given features
            and the current state of the network.
    """
    with tf.GradientTape(persistent=False) as tape:
        tape.watch(features)
        prediction = model(features)
    grads = tape.jacobian(prediction, features)
    u_x = tf.constant([grads[i, 0, i, 0].numpy() for i in range(len(features))], dtype=tf.float32)
    v_y = tf.constant([grads[i, 1, i, 1].numpy() for i in range(len(features))], dtype=tf.float32)
    constraint = tf.add(u_x, v_y)  # This should be equal to 0
    return constraint
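To make the indexing explicit: for a batch of N samples, tape.jacobian(prediction, features) has shape (N, 2, N, 2), holding d prediction[i, j] / d features[k, l]. Since sample i's output depends only on sample i's input, only the i == k entries are non-zero, so grads[i, 0, i, 0] is u_x and grads[i, 1, i, 1] is v_y for sample i. A quick shape check (sketch with dummy data):

features = tf.random.uniform((8, 2))  # N = 8 sample points, placeholder data
with tf.GradientTape() as tape:
    tape.watch(features)
    prediction = model(features)
grads = tape.jacobian(prediction, features)
print(grads.shape)  # (8, 2, 8, 2)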
The function above works fine and I can train my network. However, it doesn't allow me to use @tf.function because of the .numpy() calls. Therefore I wrote the following function:
def continuous_equation_constraint(features):
    with tf.GradientTape(persistent=False) as tape:
        tape.watch(features)
        prediction = model(features)
    grads = tape.jacobian(prediction, features)
    l = len(features)
    features_tensor = tf.constant([i for i in range(l)])
    # Create a tensor containing the first and third dimension indices
    # with an additional axis to match the required format for 'tf.gather_nd'.
    indices = tf.stack([features_tensor, tf.zeros_like(features_tensor),
                        features_tensor, tf.zeros_like(features_tensor)], axis=1)
    # Gather the required elements from 'grads'
    u_x = tf.gather_nd(grads, indices)
    indices = tf.stack([features_tensor, tf.ones_like(features_tensor),
                        features_tensor, tf.ones_like(features_tensor)], axis=1)
    v_y = tf.gather_nd(grads, indices)
    # Convert to float32 (if needed)
    u_x = tf.cast(u_x, tf.float32)
    v_y = tf.cast(v_y, tf.float32)
    constraint = tf.add(u_x, v_y)  # This should be equal to 0
    return constraint
This version uses only TensorFlow operations, which allows the use of @tf.function.
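As an aside, tape.batch_jacobian should compute exactly these per-sample blocks directly, which would avoid the gather_nd index juggling; this is a sketch I have not tested in my full setup:

def continuous_equation_constraint_bj(features):
    # batch_jacobian returns d prediction[i, :] / d features[i, :] per sample,
    # i.e. shape (N, 2, 2), so the diagonal entries can be sliced directly.
    with tf.GradientTape() as tape:
        tape.watch(features)
        prediction = model(features)
    grads = tape.batch_jacobian(prediction, features)  # (N, 2, 2)
    u_x = grads[:, 0, 0]
    v_y = grads[:, 1, 1]
    return u_x + v_y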
Now, however, with this gather_nd version the network isn't training properly. I compared the outputs of the two functions (the losses) and they are always the same, but the calculated gradients are different. The gradients are applied as shown below:
with tf.GradientTape(persistent=True) as tape2:
    prediction = model(features)
    model_loss = mse(labels, prediction)
    constraint_loss = tf.reduce_mean(tf.abs(continuous_equation_constraint(features)))
    complete_loss = tf.add(model_loss, constraint_loss)
gradients = tape2.gradient(complete_loss, model.trainable_weights)
optimizer.apply_gradients(zip(gradients, model.trainable_weights))
I want to use @tf.function to improve the training speed, but for some reason when I do, the training loss goes down super fast and the network doesn't really learn anything.
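Concretely, the wrapped training step looks roughly like this (a sketch with the same names as above; the exact signature is an assumption):

@tf.function
def train_step(features, labels):
    with tf.GradientTape() as tape2:
        prediction = model(features)
        model_loss = mse(labels, prediction)
        constraint_loss = tf.reduce_mean(tf.abs(continuous_equation_constraint(features)))
        complete_loss = model_loss + constraint_loss
    gradients = tape2.gradient(complete_loss, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    return complete_loss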
Thanks for the help!
I made the following function, which again gives the same results for all losses but still produces a different gradient:
def continuous_equation_constraint3(features):
    """Calculate the constraint based on the current prediction of the neural network.
    Ideally one should always have u_x + v_y = 0.
    Args:
        features (tf.Tensor): Batch of (x, y) position data, shape (N, 2)
    Returns:
        tf.Tensor: The value of the constraint equation for the given features
            and the current state of the network.
    """
    x = tf.cast(features[:, 0], tf.float32)
    y = tf.cast(features[:, 1], tf.float32)
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(x)
        tape.watch(y)
        x = tf.reshape(x, [-1, 1])
        y = tf.reshape(y, [-1, 1])
        pos = tf.concat([x, y], axis=1)
        prediction = model(pos)
        u = prediction[:, 0]
        v = prediction[:, 1]
    u_x = tape.gradient(u, x)
    v_y = tape.gradient(v, y)
    constraint = tf.add(u_x, v_y)  # This should be equal to 0
    return constraint
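For completeness, this is the kind of comparison I ran (a sketch; I renamed the three versions v1/v2/3 here for the test, and the dummy batch stands in for real training data): the losses agree, but the gradients with respect to the weights do not.

# Dummy batch standing in for real training data
features = tf.random.uniform((8, 2))
labels = tf.random.uniform((8, 2))

def loss_with(constraint_fn):
    # Same total loss as in the training step above, for a given constraint version.
    prediction = model(features)
    return mse(labels, prediction) + tf.reduce_mean(tf.abs(constraint_fn(features)))

for fn in (continuous_equation_constraint_v1,   # numpy version (renamed)
           continuous_equation_constraint_v2,   # gather_nd version (renamed)
           continuous_equation_constraint3):
    with tf.GradientTape() as tape:
        loss = loss_with(fn)
    grads = tape.gradient(loss, model.trainable_weights)
    print(loss.numpy(), tf.linalg.global_norm(grads).numpy())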