I am trying to implement federated learning using the Flower framework in Python. When I start the process, I get the following error (see the snapshot of the error).
Here is what I tried:
# Number of simulated federated clients, i.e. how many partitions the
# training data is split into by load_datasets.
NUM_CLIENTS: int = 10
#function to load data
def load_datasets(num_clients: int, train_loader, test_loader):
# Split training set into `num_clients` partitions to simulate different local datasets
partition_size = len(train_loader) // num_clients
lengths = [partition_size] * num_clients
datasets = random_split(train_loader, lengths, torch.Generator().manual_seed(42))
# Split each partition into train/val and create DataLoader
trainloaders = []
valloaders = []
for ds in datasets:
len_val = len(ds) // 10 # 10 % validation set
len_train = len(ds) - len_val
lengths = [len_train, len_val]
ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
valloaders.append(DataLoader(ds_val, batch_size=32))
testloader = DataLoader(test_loader, batch_size=32)
return trainloaders, valloaders, testloader
# Build the per-client train/validation loaders and the shared test loader.
trainloaders, valloaders, testloader = load_datasets(
    num_clients=NUM_CLIENTS,
    train_loader=train_loader,
    test_loader=test_loader,
)
#client function thats been passed to start the server
def client_fn(cid) -> CardiacClient:
net = CardiacModel().to(DEVICE)
trainloader = trainloaders[cid]
valloader = valloaders[cid]
return CardiacClient(cid, net, trainloader, valloader)
In the above code, `cid` refers to the client ID.