I am building off of MONAI's 3D segmentation tutorial to work with 4D NIfTI data, where the fourth dimension represents the channels to be fed into the proposed 3D network. I have adapted the tutorial to use MONAI's DynUNet (nnU-Net) for better segmentation, but I am having trouble transforming the data into the format needed to train the 3D network with multiple channels.
With my current approach, the previously working DynUNet pipeline gets stuck while loading data (estimated 12+ hours to load before the process was killed by the server, versus roughly 1 minute previously). I cannot tell whether I am transforming/preparing the data correctly for 3D multichannel training.
The current input shape is [num_px_x, num_px_y, num_slices, num_channels], and I want to transform it into a 3D volume that can be used by a multichannel network.
If helpful, the 4th dimension is of length 7, where index 0 represents an intensity value and indices 1-6 represent a one-hot encoded sequence.
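In other words, the channel axis needs to move from last to first for a channel-first 3D network. A minimal NumPy sketch of the intended layout (the spatial sizes below are made up):
import numpy as np

# [num_px_x, num_px_y, num_slices, num_channels] -> [num_channels, x, y, z]
vol = np.zeros((256, 256, 120, 7))            # made-up spatial sizes
vol_channel_first = np.moveaxis(vol, -1, 0)   # shape becomes (7, 256, 256, 120)
print(vol_channel_first.shape)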
A snippet of my transform function:
def get_xforms(mode="train", keys=("image", "label")):
    """Returns a composed transform for train/val/infer."""
    xforms = [
        LoadImaged(keys),                    # load the NIfTI files into arrays
        EnsureChannelFirstd(keys="image"),   # make the image channel-first
        AsChannelFirstd(keys),               # move the last axis to the front for both keys
        Orientationd(keys, axcodes="LPS"),   # reorient to LPS
    ]
    return monai.transforms.Compose(xforms)
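A quick way to check whether this composition actually yields channel-first data would be to apply it to a single sample and print the shapes (the file names below are placeholders; imports are the same as in the tutorial):
# Placeholder file names; quick check of the transform output on one sample.
check = get_xforms("train", keys=("image", "label"))(
    {"image": "case_000_ct.nii.gz", "label": "case_000_seg.nii.gz"}
)
print(check["image"].shape)  # hoping for (7, num_px_x, num_px_y, num_slices)
print(check["label"].shape)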
Training data loader:
# format: [ {'image': 'ct file path', 'label': 'seg file path'} ]
keys = ("image", "label")
train_files = [{keys[0]: img, keys[1]: seg} for img, seg in zip(images[:n_train], labels[:n_train])]
val_files = [{keys[0]: img, keys[1]: seg} for img, seg in zip(images[-n_val:], labels[-n_val:])]
batch_size = 2
train_transforms = get_xforms("train", keys)
train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms)
train_loader = monai.data.DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)
Network function:
def get_net():
    kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]
    strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
    num_classes = 2
    net = monai.networks.nets.DynUNet(
        spatial_dims=3,
        in_channels=7,   # 1 intensity channel + 6 one-hot channels
        out_channels=num_classes,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
    )
    return net
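To confirm that the network itself accepts 7-channel input, a dummy forward pass can serve as a quick shape test (the 64x64x64 patch size is an arbitrary choice, only needing to be divisible by the network's overall stride of 8):
# Dummy forward pass to verify input/output channel counts.
net = get_net()
x = torch.zeros(1, 7, 64, 64, 64)   # [batch, in_channels, x, y, z]
with torch.no_grad():
    y = net(x)
print(y.shape)  # expected: torch.Size([1, 2, 64, 64, 64])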
Other code segments are mostly consistent with MONAI's tutorial