My model training involves encoding multiple variants of the same image and then summing the resulting representations over all variants of that image.
The data loader produces tensor batches of shape [batch_size, num_variants, 1, height, width], where the 1 corresponds to the image color channel.
How can I train my model with minibatches in PyTorch? I am looking for a proper way to forward all batch_size×num_variants images through the network and sum the results within each group of variants.
My current solution involves flattening the first two dimensions and using a for loop to sum the representations, but I feel like there should be a better way, and I am not sure the gradients will be tracked correctly through all of these operations.
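
To make the setup concrete, here is a minimal sketch of the flatten-then-sum pattern described above, with the for loop replaced by a reshape and a sum over the variant dimension. The encoder, the tensor sizes, and the loss are hypothetical placeholders chosen only for illustration and are not my actual model. The sketch also compares the result against a per-image loop and checks that every encoder parameter receives a gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes only (assumptions, not taken from the real data loader).
batch_size, num_variants, height, width = 4, 3, 8, 8
feature_dim = 16

# Hypothetical stand-in encoder; any module mapping [N, 1, H, W] -> [N, D] works.
encoder = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * height * width, feature_dim),
)

# Fake batch with the shape produced by the data loader.
x = torch.randn(batch_size, num_variants, 1, height, width)

# Merge the batch and variant dimensions so the encoder sees ordinary 4D input.
flat = x.reshape(batch_size * num_variants, 1, height, width)
encoded = encoder(flat)  # [batch_size * num_variants, feature_dim]

# Restore the variant dimension and sum over it, one vector per original image.
summed = encoded.view(batch_size, num_variants, feature_dim).sum(dim=1)

# Reference: encode each image's variants separately and sum them in a loop.
looped = torch.stack([encoder(x[i]).sum(dim=0) for i in range(batch_size)])

# Placeholder loss, used only to check that gradients reach the encoder.
loss = summed.pow(2).mean()
loss.backward()
grads_ok = all(p.grad is not None for p in encoder.parameters())
```

Since reshape, view, and sum are all differentiable operations recorded by autograd, the gradients should flow back through every variant. One caveat: layers that compute statistics across the batch, such as batch normalization in training mode, would see all batch_size×num_variants images at once in the flattened version, so the two versions would no longer be exactly equivalent.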