
My model training involves encoding multiple variants of the same image and then summing the produced representations over all variants of the image.

The data loader produces tensor batches of shape `[batch_size, num_variants, 1, height, width]`, where the 1 corresponds to the image color channel.

How can I train my model with minibatches in PyTorch? I am looking for a proper way to forward all `batch_size × num_variants` images through the network and then sum the results over each group of variants.

My current solution involves flattening the first two dimensions and using a for loop to sum the representations, but I feel like there should be a better way, and I am not sure the gradients will be tracked correctly through it.
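
Roughly, my current workaround looks like this (simplified sketch; `batch` is one tensor from the loader and `model` is the encoder network):

import torch

# flatten the batch and variant dimensions, encode everything at once,
# then loop over the batch to sum each group of variants
batch_size, num_variants, _, height, width = batch.shape
flat = batch.reshape(batch_size * num_variants, 1, height, width)
features = model(flat)                       # [batch_size*num_variants, ...]
summed = torch.stack([
    features[i * num_variants:(i + 1) * num_variants].sum(dim=0)
    for i in range(batch_size)
])                                           # [batch_size, ...]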

hamza keurti
  • Have you considered 3D Convolution models? They're designed for such approaches. You could consider the no. of variants as `depth` in a 3D Convolution. An input to `nn.Conv3d` is of the form `batch_size*channels*depth*height*width`. – planet_pluto Nov 11 '20 at 17:07
  • @planet_pluto I do not think 3D convolutions are what I'm looking for. Operations on images are shared over variants. It is the same model operating on all variants of the image. – hamza keurti Nov 12 '20 at 11:10

1 Answer


Not sure I understood you correctly, but I guess this is what you want (say the batched image tensor is called `image`):

Nb, Nv, inC, inH, inW = image.shape

# treat each variant as if it's an ordinary image in the batch
image = image.reshape(Nb*Nv, inC, inH, inW)

output = model(image)
_, outC, outH, outW = output.shape

# reshape the output so that dim==1 indexes the variants
output = output.reshape(Nb, Nv, outC, outH, outW)

# sum over the variants and drop the summation dimension: [Nb, outC, outH, outW]
output = output.sum(dim=1, keepdim=False)

I used `inC`, `outC`, `inH`, etc. in case the input and output channels/sizes differ.
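
Regarding the gradient worry from the question: `reshape` and `sum` are ordinary differentiable operations, so autograd tracks them and the gradient of the summed representation flows back to every individual variant; no manual loop is needed. A minimal training-step sketch (here `model`, `criterion`, `optimizer`, `image` and `target` are placeholders for whatever you already use):

def encode_and_sum(model, image):
    # same reshape / encode / sum as above, wrapped in a function
    Nb, Nv, inC, inH, inW = image.shape
    out = model(image.reshape(Nb * Nv, inC, inH, inW))
    _, outC, outH, outW = out.shape
    return out.reshape(Nb, Nv, outC, outH, outW).sum(dim=1)

optimizer.zero_grad()
loss = criterion(encode_and_sum(model, image), target)
loss.backward()   # gradients flow through sum/reshape back to all Nb*Nv encodings
optimizer.step()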

ihdv