I am currently working on a project that predicts the future trajectories of objects from their past trajectories using LSTMs in TensorFlow. (Here, a trajectory means a sequence of 2D positions.)
The input to the LSTM is, of course, the past trajectories, and the output is the future trajectories.
The mini-batch size is fixed during training. However, the number of past trajectories in a mini-batch can vary. For example, let the mini-batch size be 10. If I have only 4 past trajectories for the current training iteration, the remaining 6 of the 10 slots in the mini-batch are padded with zeros.
When calculating the loss for back-propagation, I set the loss from those 6 padded slots to zero so that only the 4 real trajectories contribute to back-propagation.
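To make the masking scheme concrete, here is a minimal NumPy sketch of what I mean (not my actual TensorFlow code). The array names, the random data, and the single-step shapes are assumptions chosen only for illustration.

```python
import numpy as np

batch_size = 10   # fixed mini-batch size
num_valid = 4     # real trajectories available in this iteration (example value)

rng = np.random.default_rng(0)

# Ground-truth 2D positions, one row per batch slot; padded slots stay zero.
gt = np.zeros((batch_size, 2))
gt[:num_valid] = rng.normal(size=(num_valid, 2))

# Stand-in for the network output: it is produced for every slot, padded or not.
pred = rng.normal(size=(batch_size, 2))

# Loss mask: 1 for real trajectories, 0 for zero-padded slots.
loss_mask = np.zeros(batch_size)
loss_mask[:num_valid] = 1.0

# Squared error per slot, computed for all 10 slots.
per_slot_loss = np.sum((pred - gt) ** 2, axis=1)

# Padded slots are multiplied by 0, so only the 4 valid slots contribute.
masked_loss = np.sum(per_slot_loss * loss_mask)
```

Note that the per-slot error is still computed for all 10 slots; the mask only removes the padded slots from the final sum.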
My concern is that TensorFlow still seems to calculate gradients for the 6 padded slots even though their loss is zero. As a result, training becomes slower as I increase the mini-batch size, even when I use the same training data.
I also tried using the tf.where function when calculating the loss, but the training time did not decrease.
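The difference I am after can be pictured with the following NumPy sketch. It is only an analogy under my own assumptions: `step` is a hypothetical stand-in for the per-slot computation, and the shapes and data are made up. As far as I understand, a `where`-style selection only chooses among values that have already been computed, so it cannot skip the work for padded slots, whereas dropping the padded rows before the computation does.

```python
import numpy as np

def step(x, w):
    """Hypothetical stand-in for the per-slot computation (one dense layer)."""
    return np.tanh(x @ w)

rng = np.random.default_rng(0)
batch_size, num_valid, dim = 10, 4, 8

w = rng.normal(size=(dim, 2))

# Inputs and targets; slots beyond num_valid are zero padding.
x = np.zeros((batch_size, dim))
x[:num_valid] = rng.normal(size=(num_valid, dim))
gt = np.zeros((batch_size, 2))
gt[:num_valid] = rng.normal(size=(num_valid, 2))

# Boolean mask marking the real (non-padded) slots.
valid = np.arange(batch_size) < num_valid

# (a) Mask after computing: all 10 rows go through step(),
#     then the padded rows are zeroed out of the loss.
out_all = step(x, w)
loss_masked = np.sum(np.where(valid[:, None], (out_all - gt) ** 2, 0.0))

# (b) Select before computing: only the 4 valid rows go through step().
out_valid = step(x[valid], w)
loss_selected = np.sum((out_valid - gt[valid]) ** 2)
```

Both variants give the same loss value, but (b) does the work for 4 rows instead of 10, which is the behavior I would like to get during training.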
How can I reduce the training time?
Here is my pseudocode for training.
# For each frame in a sequence
for f in range(pred_length):
    # For each element in a batch
    for b in range(batch_size):
        with tf.variable_scope("rnnlm") as scope:
            if (f > 0 or b > 0):
                scope.reuse_variables()
            # For each pedestrian in an element
            for p in range(MNP):
                # Ground-truth position
                cur_gt_pose_dec = ...
                # Loss mask
                loss_mask_ped = ...  # '1' or '0'
                # Go through the RNN decoder
                output_states_dec_list[b][p], zero_states_dec_list[b][p] = cell_dec(cur_embed_frm_dec,
                                                                                   zero_states_dec_list[b][p])
                # Fully connected layer for the output
                cur_pred_pose_dec = tf.nn.xw_plus_b(output_states_dec_list[b][p], output_wd, output_bd)
                # Go through the embedding function for the next input
                prev_embed_frms_dec_list[b][p] = tf.reshape(tf.nn.relu(tf.nn.xw_plus_b(cur_pred_pose_dec, embedding_wd, embedding_bd)), shape=(1, rnn_size))
                # Calculate the MSE loss
                mse_loss = tf.reduce_sum(tf.pow(tf.subtract(cur_pred_pose_dec, cur_gt_pose_dec), 2.0))
                # Only a valid pedestrian's trajectory contributes to the loss
                self.loss += tf.multiply(mse_loss, loss_mask_ped)