
I have code for multiclass segmentation using PyTorch. The inputs are images and their ground-truth masks. This is a piece of my code:

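import torch
import torch.nn as nn

# net, optimizer, device, train_loader, epochs, pbar and global_step
# are defined elsewhere in the script.
criterion = nn.CrossEntropyLoss()  # no `weight` argument passed
for epoch in range(epochs):
    net.train()
    epoch_loss = 0
    for batch in train_loader:
        true_masks = batch['mask']
        imgs = batch['image']
        imgs = imgs.to(device=device, dtype=torch.float32)
        # integer class indices for CrossEntropyLoss; float for the single-class case
        mask_type = torch.float32 if net.n_classes == 1 else torch.long
        true_masks = true_masks.to(device=device, dtype=mask_type)
        masks_pred = net(imgs)
        loss = criterion(masks_pred, true_masks)
        epoch_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        # clamp every gradient element to [-0.1, 0.1]
        nn.utils.clip_grad_value_(net.parameters(), 0.1)
        optimizer.step()
        pbar.update(imgs.shape[0])
        global_step += 1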
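One thing worth checking first: `nn.CrossEntropyLoss()` built without a `weight` argument applies no per-class weighting, so there is nothing to read back; its `weight` attribute is `None`. When weights are passed in, they are stored on the loss module as a fixed tensor, and since the optimizer only receives `net.parameters()`, training never changes them. A minimal sketch (the weight values are hypothetical placeholders):

```python
import torch
import torch.nn as nn

# Without a `weight` argument every class contributes equally
# and no weight tensor is stored.
unweighted = nn.CrossEntropyLoss()
print(unweighted.weight)  # None

# Hypothetical weights for the 6 classes 0..5; they are stored as a
# buffer on the loss module and are not touched by the optimizer.
class_weights = torch.tensor([1.0, 2.0, 2.0, 1.5, 3.0, 1.0])
weighted = nn.CrossEntropyLoss(weight=class_weights)

# Print one line per class, in the format asked for in the question.
for c, w in enumerate(weighted.weight.tolist()):
    print(f"class {c} weight={w:.3f}")
```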

Now I want to know, for every epoch, what each class's weight is. I have 6 classes: [0, 1, 2, 3, 4, 5]. For example, I want to get output like this:

class 1 weight=...

class 2 weight=...

class 3 weight=...

class 4 weight=...

...

How can I get and print these weights?
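If "class weight" here means a weight derived from how often each class appears in the masks (one possible reading, not confirmed by the question), a common choice is inverse pixel frequency, `total_pixels / (n_classes * class_pixels)`, the same formula scikit-learn uses for `class_weight="balanced"`. The sketch below assumes each batch is a dict with an integer `'mask'` tensor, as in the loop above; `class_pixel_counts`, `inverse_frequency_weights` and `fake_loader` are hypothetical names used only for illustration.

```python
import torch

N_CLASSES = 6  # classes 0..5, as in the question


def class_pixel_counts(loader, n_classes=N_CLASSES):
    """Count mask pixels per class over one pass through `loader`.

    Assumes batch['mask'] holds integer class indices in [0, n_classes).
    """
    counts = torch.zeros(n_classes, dtype=torch.long)
    for batch in loader:
        masks = batch['mask'].long().flatten()
        counts += torch.bincount(masks, minlength=n_classes)
    return counts


def inverse_frequency_weights(counts):
    """Weight = total / (n_classes * count); absent classes get 0."""
    counts = counts.float()
    total = counts.sum()
    weights = torch.zeros_like(counts)
    present = counts > 0  # avoid dividing by zero for unseen classes
    weights[present] = total / (len(counts) * counts[present])
    return weights


# Hypothetical stand-in for train_loader: two batches of random 4x4 masks.
torch.manual_seed(0)
fake_loader = [{'mask': torch.randint(0, N_CLASSES, (2, 4, 4))} for _ in range(2)]

counts = class_pixel_counts(fake_loader)
weights = inverse_frequency_weights(counts)
for c, w in enumerate(weights.tolist()):
    print(f"class {c} weight={w:.3f}")
```

Because the ground-truth masks do not change between epochs, these counts come out the same every epoch unless random augmentation alters the masks. They are therefore usually computed once before training and passed to the loss, e.g. `nn.CrossEntropyLoss(weight=weights.to(device))`.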

Babak Azad
