
I have a model as such:

netF = timm.create_model(...) #feature extractor
netB = network.feat_bottlenect(...) #bottleneck layer
netC = network.feat_classifier(...) #classifier layer
output = netC(netB(netF(input)))

I want to apply torch.nn.DataParallel to these networks. I tried applying DataParallel to each individual network, as follows.

netF = torch.nn.DataParallel(netF)
netB = torch.nn.DataParallel(netB)
netC = torch.nn.DataParallel(netC)
output = netC(netB(netF(input)))
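Here is a minimal self-contained version of what I am trying, with small toy Linear layers standing in for the real netF/netB/netC (the layer sizes are just placeholders, not my actual model):

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real networks (assumed shapes for illustration only)
netF = nn.Linear(16, 8)   # stand-in for the timm feature extractor
netB = nn.Linear(8, 8)    # stand-in for the bottleneck layer
netC = nn.Linear(8, 4)    # stand-in for the classifier layer

# Wrap each network individually, as in my attempt above.
# (DataParallel falls back to running the wrapped module directly
# when no GPU is available, so this snippet also runs on CPU.)
netF = nn.DataParallel(netF)
netB = nn.DataParallel(netB)
netC = nn.DataParallel(netC)

x = torch.randn(32, 16)          # batch of 32 inputs
output = netC(netB(netF(x)))     # extractor -> bottleneck -> classifier
print(output.shape)              # torch.Size([32, 4])
```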

but it does not seem to work. This is the only change I made to go from a single GPU to multiple GPUs. The overall model trains fine on a single GPU without DataParallel.

What am I doing wrong? Thank you.

I am expecting the model to train as it would without DataParallel.
