I'm working with 3D CT medical data, and I'm trying to split each volume into 2D slices that I can feed into a UNet model.

I've loaded the data into a PyTorch DataLoader, and each iteration currently produces a 4D tensor:

for batch_index, batch_samples in enumerate(train_loader):
    data, target = batch_samples['image'].float().cuda(), batch_samples['label'].float().cuda()
    print(data.size())

# Output:
# torch.Size([1, 333, 512, 512])
# torch.Size([1, 356, 512, 512])

Each volume has a different number of slices along dimension 1 (333 in the first batch, 356 in the second). I want to iterate over the 333 slices, then the 356 slices, so that the model receives a tensor of size [1, 1, 512, 512] each time.

I was hoping something like this would work:

for x in (data[:,x,:,:]):
    ...

but it fails, saying I need to define x first. How can I iterate over a specific dimension of a torch tensor?

ml-yeung

1 Answer

Loop over the indices of the dimension you want and index into it:

for i in range(data.shape[1]):  # dim=1
    x = data[:, i, :, :]        # integer index drops dim 1: shape [1, 512, 512]
    # [...]
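A minimal runnable sketch of the loop above, assuming a small random tensor stands in for the real [1, 333, 512, 512] volume (the shape [1, 5, 64, 64] is an arbitrary choice to keep it light). It shows that the loop runs once per slice and that integer indexing on dim 1 removes that dimension, so each x has shape [1, H, W]:

```python
import torch

# Hypothetical stand-in for one CT batch; the real data is [1, 333, 512, 512].
data = torch.rand(1, 5, 64, 64)

slice_shapes = []
for i in range(data.shape[1]):    # iterate over dim=1, the slice axis
    x = data[:, i, :, :]          # integer indexing drops dim 1 -> [1, 64, 64]
    slice_shapes.append(tuple(x.shape))

print(len(slice_shapes), slice_shapes[0])  # 5 (1, 64, 64)
```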

If you need to keep that dimension as a channel axis (so each slice has shape [1, 1, 512, 512]), add .unsqueeze():

d = 1
for i in range(data.shape[d]):         # dim=1
    x = data[:, i, :, :].unsqueeze(d)  # same dim=1
    # [...]
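The sketch below puts the unsqueeze version into a per-slice forward pass and then rebuilds a volume with torch.cat. It assumes hypothetical stand-ins: small random tensors for data and target, a single nn.Conv2d in place of the UNet, and nn.MSELoss as a placeholder loss (none of these come from the question). The label slice is taken with the same index so image and label stay aligned:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: small random tensors instead of the real
# [1, 333, 512, 512] CT image and label, and one Conv2d instead of a UNet.
data = torch.rand(1, 5, 64, 64)
target = torch.rand(1, 5, 64, 64)
model = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
criterion = nn.MSELoss()  # placeholder loss, not taken from the question

d = 1  # the slice axis
outputs = []
for i in range(data.shape[d]):
    x = data[:, i, :, :].unsqueeze(d)    # image slice, shape [1, 1, 64, 64]
    y = target[:, i, :, :].unsqueeze(d)  # matching label slice, same shape
    out = model(x)                       # padding=1 keeps 64x64 -> [1, 1, 64, 64]
    loss = criterion(out, y)             # in training: loss.backward(), optimizer.step()
    outputs.append(out.detach())

# Concatenate along the slice axis to rebuild a [1, D, H, W] prediction volume.
prediction = torch.cat(outputs, dim=d)
print(prediction.shape)  # torch.Size([1, 5, 64, 64])
```

Indexing with a range, data[:, i:i+1, :, :], keeps dim 1 without needing unsqueeze, and torch.split(data, 1, dim=1) returns the same [1, 1, H, W] slices as a tuple you can loop over directly.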
Berriel