Background
Currently, I'm using an LSTM to perform a regression. I'm using small batch sizes with a reasonably large number of timesteps (but still much, much fewer than the total number of timesteps I have).
I'm attempting to transition to larger batches with fewer timesteps, but with stateful enabled, so that a larger amount of the generated training data can be used.
However, I'm currently using a regularization based on sqrt(timestep) (this is ablation-tested and helps with convergence speed; it works because of the statistical nature of the problem, where the expected error decreases by a factor of sqrt(timestep)). This is implemented by using tf.range inside the loss function to generate a vector of the proper length. This approach will no longer be correct once stateful is enabled, because it will count the wrong number of timesteps (the number of timesteps in the current batch, rather than the number seen so far overall).
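To make the mismatch concrete, here is a minimal, standalone sketch of the weight vector the loss builds for each batch versus what it would need under stateful training (the offset argument is hypothetical; getting hold of it inside the loss is exactly the problem):

import tensorflow as tf

def per_timestep_weights(length, offset=0):
    # The weight for (1-indexed) timestep t of the whole sequence is sqrt(t).
    # Under stateful training, 'offset' would be the number of timesteps
    # already consumed by earlier batches of the same sequence.
    steps = tf.range(offset + 1, offset + length + 1, dtype=tf.float32)
    return tf.sqrt(steps)

print(per_timestep_weights(4))            # what my loss computes for every batch: [1, 1.41, 1.73, 2]
print(per_timestep_weights(4, offset=4))  # what the 2nd stateful batch should use: [2.24, 2.45, 2.65, 2.83]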
Question
Is there a way to pass an offset or list of ints or floats to the loss function? Preferably without modifying the model, but I recognize that a hack of this nature might be required.
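To illustrate the kind of hack I mean: something along the lines of the (entirely untested) sketch below, where the loss closes over a backend variable and a callback keeps that variable in sync with how many timesteps have been consumed so far. The names here are mine, not anything from the Keras API, and whether this plays nicely with compiled/stateful training is essentially what I'm asking.

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import Callback

timestep_offset = K.variable(0.0)  # timesteps already consumed in the current sequence

def make_loss(offset_var):
    def loss_function(y_true, y_pred):
        length = K.shape(y_pred)[1]
        # Timestep indices continue from the offset instead of restarting at 1.
        steps = tf.cast(tf.range(1, length + 1), tf.float32) + offset_var
        seq = K.reshape(K.sqrt(steps), (-1, 1))
        # ... build the angle/norm/energy/stability/act losses exactly as in the
        # real loss_function below, then dot them with seq as before ...
        per_step_loss = K.mean(K.square(y_true - y_pred), axis=-1)  # placeholder
        return K.sum(K.dot(per_step_loss, seq))
    return loss_function

class AdvanceOffset(Callback):
    # Advances the offset by the number of timesteps in each batch.
    def __init__(self, offset_var, timesteps_per_batch):
        super().__init__()
        self.offset_var = offset_var
        self.timesteps_per_batch = timesteps_per_batch

    def on_train_batch_begin(self, batch, logs=None):
        K.set_value(self.offset_var, float(batch * self.timesteps_per_batch))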
Code
Simplified model:
from tensorflow.keras.layers import Input, Dense, Dropout, TimeDistributed, LSTM, add
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

def create_model():
    inputs = Input(shape=(None, input_nodes))
    next_input = inputs

    # Stack of time-distributed dense layers with dropout.
    for i in range(dense_layers):
        dense = TimeDistributed(Dense(units=dense_nodes,
                                      activation='relu',
                                      kernel_regularizer=l2(regularization_weight),
                                      activity_regularizer=l2(regularization_weight)))(next_input)
        next_input = TimeDistributed(Dropout(dropout_dense))(dense)

    # Stack of stateful LSTM layers with additive residual connections.
    for i in range(lstm_layers):
        prev_input = next_input
        next_input = LSTM(units=lstm_nodes,
                          dropout=dropout_lstm,
                          recurrent_dropout=dropout_lstm,
                          kernel_regularizer=l2(regularization_weight),
                          recurrent_regularizer=l2(regularization_weight),
                          activity_regularizer=l2(regularization_weight),
                          stateful=True,
                          return_sequences=True)(prev_input)
        next_input = add([prev_input, next_input])

    outputs = TimeDistributed(Dense(output_nodes,
                                    kernel_regularizer=l2(regularization_weight),
                                    activity_regularizer=l2(regularization_weight)))(next_input)

    model = Model(inputs=inputs, outputs=outputs)
    return model
Loss function
import tensorflow as tf
from tensorflow.keras import backend as K

def loss_function(y_true, y_pred):
    # Per-timestep weights: all ones, or sqrt(1..length) when sqrt scaling is enabled.
    length = K.shape(y_pred)[1]
    seq = K.ones(shape=(length,))
    if use_sqrt_loss_scaling:
        seq = tf.range(1, length + 1, dtype='int32')
        seq = K.sqrt(tf.cast(seq, tf.float32))
    seq = K.reshape(seq, (-1, 1))

    if separate_theta_phi:
        angle_loss = phi_loss_weight * phi_metric(y_true, y_pred, angle_loss_fun)
        angle_loss += theta_loss_weight * theta_metric(y_true, y_pred, angle_loss_fun)
    else:
        angle_loss = angle_loss_weight * total_angle_metric(y_true, y_pred, angle_loss_fun)

    norm_loss = norm_loss_weight * norm_loss_fun(y_true, y_pred)
    energy_loss = energy_loss_weight * energy_metric(y_true, y_pred)
    stability_loss = stability_loss_weight * stab_loss_fun(y_true, y_pred)
    act_loss = act_loss_weight * act_loss_fun(y_true, y_pred)

    # Weight each timestep's combined loss by its entry in seq, then sum.
    return K.sum(K.dot(angle_loss
                       + norm_loss
                       + energy_loss
                       + stability_loss
                       + act_loss, seq))
(The functions that calculate the individual pieces of the loss shouldn't be especially relevant here; simply put, they're also loss functions.)
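For completeness, this loss is plugged in as an ordinary Keras loss; the compile call looks roughly like this (the optimizer choice is just illustrative):

model = create_model()
model.compile(optimizer='adam', loss=loss_function)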