I have code for multiclass segmentation using PyTorch. The inputs are images and their ground-truth masks. This is a piece of my code:
import torch
import torch.nn as nn

# net, optimizer, device, epochs, train_loader, pbar and global_step are defined elsewhere in my script
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
    net.train()
    epoch_loss = 0
    for batch in train_loader:
        true_masks = batch['mask']
        imgs = batch['image']
        imgs = imgs.to(device=device, dtype=torch.float32)
        # For multiclass, CrossEntropyLoss expects integer class indices (long) of shape (N, H, W)
        mask_type = torch.float32 if net.n_classes == 1 else torch.long
        true_masks = true_masks.to(device=device, dtype=mask_type)

        masks_pred = net(imgs)
        loss = criterion(masks_pred, true_masks)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_value_(net.parameters(), 0.1)
        optimizer.step()

        pbar.update(imgs.shape[0])
        global_step += 1
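Note that `nn.CrossEntropyLoss` does not learn class weights. They are an optional fixed tensor passed through its `weight` argument, and with the default `nn.CrossEntropyLoss()` the stored `criterion.weight` is `None`, meaning every class counts equally. The minimal sketch below, assuming 6 classes and random dummy tensors, shows how to inspect that attribute and print explicit weights; the weight values are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn

n_classes = 6

# Default criterion: no class weights are stored, so every class counts equally.
criterion = nn.CrossEntropyLoss()
print(criterion.weight)  # None

# Explicit per-class weights (hypothetical values, for illustration only).
class_weights = torch.tensor([1.0, 2.0, 1.0, 0.5, 1.0, 3.0])
weighted = nn.CrossEntropyLoss(weight=class_weights)
for c, w in enumerate(weighted.weight.tolist()):
    print(f"class {c} weight={w}")

# Segmentation shapes: logits (N, C, H, W), target class indices (N, H, W) as long.
logits = torch.randn(2, n_classes, 4, 4)
target = torch.randint(0, n_classes, (2, 4, 4))
loss = weighted(logits, target)
print(loss.item())
```

Because these weights are fixed when the criterion is constructed, they stay the same in every epoch unless the criterion is rebuilt with new values.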
Now I want to know, for every epoch, what each class's weight is. I have 6 classes: [0, 1, 2, 3, 4, 5]. For example, I want to get output like this:
class 1 weight=...
class 2 weight=...
class 3 weight=...
class 4 weight=...
...
How can I get and print these weights?
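If "weight" here means how strongly each class is represented in the training masks during an epoch (a common basis for choosing `weight` values), one option is to count pixels per class with `torch.bincount` inside the batch loop and print normalized frequencies or inverse-frequency weights at the end of each epoch. This is a minimal sketch under stated assumptions: `fake_loader` is a hypothetical stand-in for `train_loader` that yields random 6-class masks of shape (N, H, W), and inverse frequency is just one of several weighting heuristics.

```python
import torch

n_classes = 6

def class_pixel_counts(masks: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Count pixels per class in a batch of integer masks with values in [0, n_classes)."""
    return torch.bincount(masks.flatten().long(), minlength=n_classes)

# Hypothetical stand-in for train_loader: 3 batches of 2 random 8x8 masks.
torch.manual_seed(0)
fake_loader = [{'mask': torch.randint(0, n_classes, (2, 8, 8))} for _ in range(3)]

epochs = 2
for epoch in range(epochs):
    counts = torch.zeros(n_classes, dtype=torch.long)
    for batch in fake_loader:
        counts += class_pixel_counts(batch['mask'], n_classes)

    freqs = counts.float() / counts.sum()
    # Inverse-frequency weights; clamp avoids division by zero for absent classes.
    inv = 1.0 / freqs.clamp(min=1e-8)
    weights = inv / inv.sum() * n_classes  # rescale so the weights average to 1

    print(f"epoch {epoch + 1}")
    for c in range(n_classes):
        print(f"class {c} frequency={freqs[c].item():.4f} weight={weights[c].item():.4f}")
```

If such weights should actually affect training, they would have to be passed in explicitly, for example `nn.CrossEntropyLoss(weight=weights.to(device))`; the loss never computes or updates them on its own.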