
I work with PySyft and would like to train a ResNet50 model. Here is part of my code:

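import torch.nn as nn
from torchvision import models

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # ResNet50 backbone (randomly initialised, since pretrained=False)
        self.model = models.resnet50(pretrained=False)
        # Custom classification head: 2048 pooled features -> 3 classes
        self.fc1 = nn.Linear(2048, 2048)
        self.fc2 = nn.Linear(2048, 3)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # ResNet50 stem
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        # ... (rest of the backbone omitted in the original post)
        x = self.model.avgpool(x)
        x = x.view(-1, 2048 * 1 * 1)  # flatten [N, 2048, 1, 1] -> [N, 2048]
        x = nn.functional.relu(self.fc1(x))
        x = self.dropout(x)
        # F.nll_loss in train() expects log-probabilities, so use log_softmax
        x = nn.functional.log_softmax(self.fc2(x), dim=1)
        return x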
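To check the custom head on its own, without loading torchvision, here is a minimal sketch that feeds a random tensor shaped like ResNet50's avgpool output ([N, 2048, 1, 1]) through the same flatten, fc1, dropout and fc2 steps. The class name Head and the batch size of 4 are hypothetical, chosen only for illustration; the head ends in log_softmax so its output pairs with F.nll_loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Head(nn.Module):
    """Hypothetical stand-alone copy of the classification head in Net."""

    def __init__(self, in_features: int = 2048, num_classes: int = 3):
        super().__init__()
        self.fc1 = nn.Linear(in_features, in_features)
        self.fc2 = nn.Linear(in_features, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten [N, 2048, 1, 1] -> [N, 2048], keeping the batch dimension
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        # Log-probabilities over the 3 classes
        return F.log_softmax(self.fc2(x), dim=1)


# Random stand-in for the avgpool output of a batch of 4 images
pooled = torch.randn(4, 2048, 1, 1)
head = Head().eval()  # eval() disables dropout
out = head(pooled)
print(out.shape)  # torch.Size([4, 3])
```

torch.flatten(x, 1) is used here instead of x.view(-1, 2048) because it always keeps the batch dimension: if the feature count were ever wrong, fc1 would raise a shape error instead of the batch being silently reshaped into extra rows.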

import torch.nn.functional as F

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    # train_loader is the PySyft federated loader (a distributed dataset)
    for batch_idx, (data, target) in enumerate(train_loader):
        model.send(data.location)  # send the model to the worker holding this batch
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get()  # get the model back from the worker
        if batch_idx % args.log_interval == 0:
            loss = loss.get()  # get the loss back
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(train_loader) * args.batch_size,
                100. * batch_idx / len(train_loader), loss.item()))
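F.nll_loss expects log-probabilities, which is why the final activation in Net.forward matters. This small sketch, using random logits and targets as stand-ins for real model output, shows that nll_loss on log_softmax output matches F.cross_entropy on raw logits, while nll_loss on plain softmax output gives a different value that is not a log-loss at all.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 3)               # batch of 5 samples, 3 classes
target = torch.tensor([0, 2, 1, 1, 0])   # true class indices

# Reference: cross-entropy computed directly from raw logits
ce = F.cross_entropy(logits, target)

# Correct pairing: nll_loss on log-probabilities
nll_on_log_probs = F.nll_loss(F.log_softmax(logits, dim=1), target)

# Incorrect pairing: nll_loss on probabilities just returns
# minus the mean probability of the true class
nll_on_probs = F.nll_loss(F.softmax(logits, dim=1), target)

print(torch.allclose(ce, nll_on_log_probs))  # True
print(nll_on_probs.item())
```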

The error is raised at model.get(), in self.set_(tensor.native_type(self.dtype)):

RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_set_

I can't understand why, in the train() function above, the model cannot be retrieved with model.get() once I put a ResNet50 model from torchvision inside a new model class.
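The RuntimeError describes a device mismatch: a tensor that lives on CUDA is being overwritten with data that is on the CPU. Before calling model.send() and model.get(), it can help to confirm where every parameter and buffer of the model actually lives. This is a plain PyTorch sketch that does not touch PySyft; the helper param_devices is hypothetical (not part of PyTorch or PySyft), and the small nn.Sequential model is only a stand-in for Net.

```python
import torch
import torch.nn as nn


def param_devices(model: nn.Module) -> set:
    """Hypothetical helper: the set of devices holding the model's parameters and buffers."""
    devices = {p.device for p in model.parameters()}
    devices |= {b.device for b in model.buffers()}
    return devices


# Stand-in model; BatchNorm1d adds buffers (running_mean, running_var, ...)
model = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4), nn.Linear(4, 3))

# Pick one device and move the whole model there in a single call
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(param_devices(model))  # should contain exactly one device

# Moving the whole model back to the CPU is again a single call
model.to("cpu")
print(param_devices(model) == {torch.device("cpu")})  # True
```

If this reports a CUDA device while the tensors coming back from get() are on the CPU, that is the mismatch the error message names. Keeping the model and data on one device (for example the CPU, as an experiment) is one way to narrow the problem down, though whether a given PySyft version supports CUDA models in send()/get() is an assumption you would need to check against that version.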

seni

1 Answer


I think you need to change your runtime type to GPU; it seems you are using the CPU instead.
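Assuming "runtime type" refers to Google Colab's Runtime > Change runtime type setting, you can confirm from code whether a GPU is actually visible to PyTorch after changing it. This is a minimal sketch using only standard torch.cuda calls.

```python
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU available:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("No GPU visible; running on the CPU")

# Tensors created with device=... land on the chosen device
x = torch.zeros(2, 3, device=device)
print(x.device.type == device.type)  # True
```

Whichever device is chosen, the model, the data and the target in train() all need to end up on that same device.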

Charuka Herath