I'm trying to approximate the sine function with a neural network (Keras), using both training data and the fact that d²/dx² sin(x) = -sin(x). This property of the sine function is used in the custom loss function of the network.
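Written out, the loss I'm aiming for is roughly

loss = MSE(y_true, y_pred) + MSE(y_pred, -d²y_pred/dx²)

i.e. the usual data term plus a term that should be zero when the network output satisfies y'' = -y.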
My code currently looks like this:
import tensorflow as tf
import numpy as np
from tensorflow import keras
from numpy import random
# --- Disable eager execution
tf.compat.v1.disable_eager_execution()
# --- Settings
x_min = 0
x_max = 2*np.pi
n_train = 64
n_test = 64
# --- Generate dataset
x_train = random.uniform(x_min, x_max, n_train)
y_train = np.sin(x_train)
x_test = random.uniform(x_min, x_max, n_test)
y_test = np.sin(x_test)
# --- Create model
model = keras.Sequential()
model.add(keras.layers.Dense(64, activation="tanh", input_dim=1))
model.add(keras.layers.Dense(64, activation="tanh"))
model.add(keras.layers.Dense(1, activation="tanh"))
def grad(input_tensor, output_tensor):
    # d(output)/d(input); note that keras.backend.gradients returns a list of tensors
    return keras.layers.Lambda(
        lambda z: keras.backend.gradients(z[0], z[1]),
        output_shape=[1])([output_tensor, input_tensor])
def custom_loss_wrapper(input_tensor, output_tensor):
    def custom_loss(y_true, y_pred):
        mse_loss = keras.losses.mean_squared_error(y_true, y_pred)
        derivative_loss = keras.losses.mean_squared_error(
            input_tensor, -grad(input_tensor, grad(input_tensor, output_tensor))[0])
        return mse_loss + derivative_loss
    return custom_loss
# --- Configure learning process
model.compile(
    optimizer=keras.optimizers.Adam(0.01),
    loss=custom_loss_wrapper(model.input, model.output),
    metrics=['MeanSquaredError'])
# --- Train from dataset
model.fit(x_train, y_train, batch_size=32, epochs=1000, validation_data=(x_test, y_test))
# --- Evaluate model
model.evaluate(x_test, y_test)
Especially important is the custom loss function. The Lambda definition of the derivative comes from this question. Sadly, the model doesn't seem to train correctly: the loss does not approach zero and stays above 10.
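To make clear what I expect the derivative term to compute, here is a minimal standalone sketch (run as a separate script in eager mode, i.e. without disable_eager_execution, and using tf.GradientTape instead of the Lambda/K.gradients construction above) of the second derivative of sin(x):

import numpy as np
import tensorflow as tf

x = tf.constant(np.linspace(0.0, 2 * np.pi, 5).reshape(-1, 1), dtype=tf.float32)
with tf.GradientTape() as tape2:
    tape2.watch(x)
    with tf.GradientTape() as tape1:
        tape1.watch(x)
        y = tf.sin(x)                  # stand-in for the model output
    dy_dx = tape1.gradient(y, x)       # cos(x)
d2y_dx2 = tape2.gradient(dy_dx, x)     # should equal -sin(x)
print(np.allclose(d2y_dx2.numpy(), -np.sin(x.numpy()), atol=1e-5))  # True

The derivative term in my custom loss is supposed to penalise deviations from exactly this relationship for the network output.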
Without the derivative term the network trains fine, but I can't seem to find the mistake in the derivative calculation. Thank you for your help!