I'm working with 3D CT medical data, and I'm trying to split each volume into 2D slices that I can feed into a UNet model.
I've loaded the data with a PyTorch DataLoader, and each iteration produces a 4D tensor:
for batch_index, batch_samples in enumerate(train_loader):
    data, target = batch_samples['image'].float().cuda(), batch_samples['label'].float().cuda()
    print(data.size())
torch.Size([1, 333, 512, 512])
torch.Size([1, 356, 512, 512])
These are the sizes for the first two volumes. I want to iterate over the 333 slices of the first volume, then over the 356 slices of the second, so that the model receives a tensor of size [1, 1, 512, 512] each time.
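The shape mismatch comes down to how PyTorch indexing treats the slice dimension: an integer index drops it, while a length-1 range keeps it. A minimal sketch, assuming a random stand-in tensor in place of a real batch and a spatial size reduced from 512 to 64 so it runs quickly (the indexing is identical for 512x512, and on CUDA tensors):

```python
import torch

# Hypothetical stand-in for one DataLoader batch: (batch=1, depth=333, H, W).
# Spatial size is reduced to 64x64 for the demo; real data would be 512x512.
data = torch.randn(1, 333, 64, 64)

# Integer indexing removes the indexed dimension.
one_slice = data[:, 0, :, :]
print(one_slice.shape)  # torch.Size([1, 64, 64])

# A length-1 slice keeps the dimension, so it can act as the channel axis.
one_slice_kept = data[:, 0:1, :, :]
print(one_slice_kept.shape)  # torch.Size([1, 1, 64, 64])

# Equivalent: index, then re-insert the channel dimension with unsqueeze.
same = data[:, 0].unsqueeze(1)
print(torch.equal(same, one_slice_kept))  # True
```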
I was hoping something like this would work:
for x in (data[:,x,:,:]):
but it complains that x is not defined. How can I iterate over a specific dimension of a torch tensor?
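One way to iterate over dimension 1 is to loop over an index with `range(data.size(1))`, or to let `torch.split` cut the tensor into length-1 pieces along that dimension. The sketch below shows both, plus a batched variant that uses `transpose(0, 1)` to turn the depth axis into the batch axis. The helper names `iterate_slices` and `iterate_slice_batches`, the chunk size of 16, and the reduced 64x64 spatial size are illustrative assumptions, not part of the original setup:

```python
import torch

def iterate_slices(volume: torch.Tensor):
    """Hypothetical helper: yield each slice of a (1, D, H, W) volume as (1, 1, H, W).

    torch.split with split_size=1 along dim=1 returns length-1 views,
    so the slice dimension is kept and serves as the channel dimension.
    """
    yield from torch.split(volume, 1, dim=1)

def iterate_slice_batches(volume: torch.Tensor, batch_size: int):
    """Hypothetical helper: yield chunks shaped (B, 1, H, W) with B <= batch_size.

    transpose(0, 1) turns (1, D, H, W) into a (D, 1, H, W) view, so each
    depth slice becomes its own sample; split then chunks along dim 0.
    """
    stacked = volume.transpose(0, 1)
    yield from torch.split(stacked, batch_size, dim=0)

# Hypothetical volume; real data would be (1, 333, 512, 512).
data = torch.randn(1, 333, 64, 64)

# Plain index loop, closest to the syntax attempted in the question.
for i in range(data.size(1)):
    slice_i = data[:, i:i + 1]  # (1, 1, H, W)

shapes = [s.shape for s in iterate_slices(data)]
print(len(shapes), shapes[0])  # 333 torch.Size([1, 1, 64, 64])

batches = list(iterate_slice_batches(data, batch_size=16))
print(len(batches), batches[0].shape, batches[-1].shape)
# 21 torch.Size([16, 1, 64, 64]) torch.Size([13, 1, 64, 64])
```

Inside the original DataLoader loop, this would look like `for slice_ in iterate_slices(data): output = model(slice_)`, where `model` stands for the UNet. Feeding chunks from `iterate_slice_batches` instead processes several slices per forward pass, which is usually faster on a GPU as long as the chunk fits in memory. Note that `data.unbind(1)` also iterates over dimension 1, but it drops that dimension and yields (1, H, W) tensors.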