
It looks like using torch.nn.DataParallel changes the output size. In the official docs, https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel, the only information about a change in output size is the following:

When module returns a scalar (i.e., 0-dimensional tensor) in forward(), this wrapper will return a vector of length equal to number of devices used in data parallelism, containing the result from each device.

My module returns a tensor of 10 coordinates, and I have 2 GPUs on which I want to run the code. The last layer of my CNN is nn.Linear(500, 10).
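For reference, here is a minimal LeNet-style sketch with that nn.Linear(500, 10) head, so the snippet below is self-contained. The exact convolutional layers are only an assumption for illustration; only the final layer matters for this question.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):   # illustrative sketch only; the real class may differ
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
        self.fc2 = nn.Linear(500, 10)   # last layer, as stated above

    def forward(self, x):
        # x is a batch of shape (N, 1, 28, 28)
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.flatten(start_dim=1)      # (N, 800)
        x = F.relu(self.fc1(x))
        return self.fc2(x)              # (N, 10)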

import torch
import torch.nn as nn

net = LeNet()    #CNN-class written above
device = torch.device("cuda:0")
net.to(device)
net = nn.DataParallel(net)

# some code skipped here, where input and target are loaded from files

output = net(input)
criterion = nn.SmoothL1Loss()
loss = criterion(output, target)

Note that without calling DataParallel this piece of code works fine. With DataParallel, a runtime error occurs when calculating the loss:

RuntimeError: The size of tensor a (20) must match the size of tensor b (10) at non-singleton dimension 0

It seems that the output size for each GPU separately is 10, as stated, but the two outputs are then joined, and that's where the size of 20 comes from.
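A toy module makes the mechanism visible (Flat10 below is a made-up example, not my network): DataParallel scatters the input along dimension 0 and concatenates the per-GPU outputs along dimension 0 when gathering, so a forward() that returns a flat 10-element tensor per replica comes back as a 20-element tensor on 2 GPUs.

import torch
import torch.nn as nn

class Flat10(nn.Module):
    # returns a flat (10,) tensor regardless of the batch it was given
    def forward(self, x):
        return torch.ones(10, device=x.device)

net = nn.DataParallel(Flat10().to("cuda:0"))
x = torch.randn(8, 3, device="cuda:0")   # batch of 8, split across the GPUs
out = net(x)
print(out.shape)   # torch.Size([20]) on 2 GPUs: one (10,) tensor per device,
                   # concatenated along dim 0; a properly batched (N, 10)
                   # output would instead be gathered back to (8, 10)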

When I change the output size in the CNN class from 10 to 5, it works again, but I'm not sure that's the right solution or that the CNN will still work properly.

  • what is `input.shape`? To the best of my knowledge DataParallel splits the data in the first dimension then concatenates the outputs from each gpu along the first dimension before returning. – jodag Aug 22 '19 at 17:26
  • I am facing the same issue, do you know how to solve it? thx! – pyxies Dec 05 '20 at 15:07

1 Answer


The easiest solution to this problem is to take the mean of the losses from all GPUs before calling backward. This can be problematic when different GPUs end up with different numbers of samples, but it will work in your case:

loss = criterion(output, target)
loss = loss.mean()
loss.backward()
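For reference, the docs behavior quoted in the question (one scalar per device) is what makes .mean() natural when the loss is computed inside forward(). A rough sketch of that pattern, using the model, input and target from the question (LossWrapper is a hypothetical name, not something from your code):

import torch
import torch.nn as nn

class LossWrapper(nn.Module):
    # hypothetical wrapper: each replica returns a 0-dim loss tensor, so
    # DataParallel gathers them into a vector of length num_gpus
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = nn.SmoothL1Loss()

    def forward(self, x, target):
        return self.criterion(self.model(x), target)

net = nn.DataParallel(LossWrapper(LeNet()).to("cuda:0"))
loss = net(input, target)   # shape (num_gpus,): one loss per device
loss = loss.mean()          # average the per-device losses
loss.backward()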