I have a model as such:
netF = timm.create_model(...) #feature extractor
netB = network.feat_bottlenect(...) #bottleneck layer
netC = network.feat_classifier(...) #classifier layer
output = netC(netB(netF(input))) #extractor -> bottleneck -> classifier
I want to apply torch.nn.DataParallel to these networks. I tried wrapping each individual network, as follows:
netF = torch.nn.DataParallel(netF)
netB = torch.nn.DataParallel(netB)
netC = torch.nn.DataParallel(netC)
output = netC(netB(netF(input)))
but it does not seem to work. This is the only change I made when going from a single GPU to multiple GPUs; the overall model trains fine on a single GPU without DataParallel.
What am I doing wrong? Thank you.
I am expecting the model to train as it would without DataParallel.
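For context, here is a minimal self-contained sketch of the pattern I am describing, with small `nn.Linear` layers standing in for the real `netF`/`netB`/`netC` (which actually come from timm and the project's `network` module). It moves each module and the input batch to the primary device before/alongside wrapping, which I understand DataParallel expects:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for netF / netB / netC; the real modules come
# from timm.create_model and network.feat_bottlenect / feat_classifier.
netF = nn.Linear(32, 16)   # feature extractor
netB = nn.Linear(16, 8)    # bottleneck
netC = nn.Linear(8, 2)     # classifier

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move each module to the primary device, then wrap; DataParallel
# replicates the module from there onto the other GPUs each forward pass.
netF = nn.DataParallel(netF.to(device))
netB = nn.DataParallel(netB.to(device))
netC = nn.DataParallel(netC.to(device))

# The input batch must also live on the primary device.
x = torch.randn(4, 32, device=device)
output = netC(netB(netF(x)))   # extractor -> bottleneck -> classifier

print(tuple(output.shape))     # batch of 4, two classifier outputs each
```

(On a machine with no GPU, DataParallel simply falls through to the wrapped module, so this snippet runs either way.)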