I want to perform k-fold CV, but in my previous approach the augmentations were leaking into the validation set. To avoid this, I am using the WrapperDataset class I found in this post: Augmenting only the training set in K-folds cross validation. However, I had to adjust it a bit because of how I prepare my data:
import glob
import os

import torch
from monai.data import CacheDataset
from monai.transforms import (
    AddChanneld,
    Compose,
    LoadImaged,
    RandAffined,
    RandFlipd,
    RandRotate90d,
    ToTensord,
)
from monai.utils import set_determinism

# Collect images and labels
images = glob.glob(os.path.join(data_dir, 'images', '*.nii.gz'))
labels = glob.glob(os.path.join(data_dir, 'labels', '*.nii.gz'))

# Combine images and labels into dictionaries
files = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(images, labels)]

# Set deterministic training for reproducibility
set_determinism(seed=0)

# Define the data transforms and augmentations (random rotation and flips)
# TODO: Check bone window using ScaleIntensityRanged()
train_transforms = Compose(
    [
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),  # Add a channel as 1st element
        RandAffined(
            keys=['image', 'label'],
            mode=('bilinear', 'nearest'),
            prob=0.5,
            rotate_range=(0.174533, 0.174533),  # +-10° in radians
            shear_range=(0.2, 0.5)
        ),
        RandFlipd(
            keys=['image', 'label'],
            prob=0.5,
            spatial_axis=(0, 1)
        ),
        RandRotate90d(
            keys=['image', 'label'],
            prob=0.5,
            max_k=1,
            spatial_axes=(0, 1)
        ),
        ToTensord(keys=['image', 'label'])
    ]
)

valid_transforms = Compose(
    [
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        ToTensord(keys=['image', 'label'])
    ]
)

# Create the Dataset, either cached or non-cached; no transform is passed here,
# because the split-specific transforms are applied later by WrapperDataset
ds = CacheDataset(data=files, cache_rate=1.0)
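If I understand CacheDataset correctly, since no transform is passed, indexing ds still just returns the dictionary of file paths (the example paths below are made up):

print(ds[0])
# e.g. {'image': '.../images/case_001.nii.gz', 'label': '.../labels/case_001.nii.gz'}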
The updated WrapperDataset class now looks like this:
class WrapperDataset:
    def __init__(self, dataset, transform=None, target_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        data = self.dataset[index]
        image, label = data['image'], data['label']
        if self.transform is not None:
            data = {'image': self.transform(image), 'label': label}
        if self.target_transform is not None:
            data = {'image': image, 'label': self.target_transform(label)}
        return data

    def __len__(self):
        return len(self.dataset)
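For context, the train_subsampler and valid_subsampler used below come from a standard k-fold split over the whole dataset, roughly like this (a minimal sketch, assuming sklearn's KFold and torch's SubsetRandomSampler; n_splits=5 is just an example):

from sklearn.model_selection import KFold
from torch.utils.data import SubsetRandomSampler

kfold = KFold(n_splits=5, shuffle=True, random_state=0)
for fold, (train_ids, valid_ids) in enumerate(kfold.split(files)):
    # Each sampler restricts its DataLoader to the indices of the current fold
    train_subsampler = SubsetRandomSampler(train_ids)
    valid_subsampler = SubsetRandomSampler(valid_ids)
    ...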
However, when I create my DataLoaders like this:
train_loader = torch.utils.data.DataLoader(
    WrapperDataset(ds, transform=train_transforms),
    batch_size=4, sampler=train_subsampler)
valid_loader = torch.utils.data.DataLoader(
    WrapperDataset(ds, transform=valid_transforms),
    batch_size=1, sampler=valid_subsampler)
I get the following errors :D
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\monai\transforms\transform.py:141, in apply_transform(transform, data, map_items, unpack_items, log_stats, lazy, overrides)
140 return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
--> 141 return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
142 except Exception as e:
143 # if in debug mode, don't swallow exception so that the breakpoint
144 # appears where the exception was raised.
File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\monai\transforms\transform.py:98, in _apply_transform(transform, data, unpack_parameters, lazy, overrides, logger_name)
96 return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data)
---> 98 return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)
File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\monai\transforms\io\dictionary.py:162, in LoadImaged.__call__(self, data, reader)
157 """
158 Raises:
159 KeyError: When not ``self.overwriting`` and key already exists in ``data``.
160
161 """
--> 162 d = dict(data)
163 for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
ValueError: dictionary update sequence element #0 has length 1; 2 is required
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
Cell In[10], line 57
54 labelsTr = []
55 outputsTr = []
---> 57 for batch_data in train_loader:
59 train_step += 1
61 print(f"Train Batch: {train_step}/{len(train_loader)}")
File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\torch\utils\data\dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
630 if self._sampler_iter is None:
631 # TODO(https://github.com/pytorch/pytorch/issues/76750)
632 self._reset() # type: ignore[call-arg]
--> 633 data = self._next_data()
634 self._num_yielded += 1
635 if self._dataset_kind == _DatasetKind.Iterable and \
636 self._IterableDataset_len_called is not None and \
637 self._num_yielded > self._IterableDataset_len_called:
File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\torch\utils\data\dataloader.py:677, in _SingleProcessDataLoaderIter._next_data(self)
675 def _next_data(self):
676 index = self._next_index() # may raise StopIteration
--> 677 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
678 if self._pin_memory:
679 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\torch\utils\data\_utils\fetch.py:51, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
49 data = self.dataset.__getitems__(possibly_batched_index)
50 else:
---> 51 data = [self.dataset[idx] for idx in possibly_batched_index]
52 else:
53 data = self.dataset[possibly_batched_index]
File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\torch\utils\data\_utils\fetch.py:51, in <listcomp>(.0)
49 data = self.dataset.__getitems__(possibly_batched_index)
50 else:
---> 51 data = [self.dataset[idx] for idx in possibly_batched_index]
52 else:
53 data = self.dataset[possibly_batched_index]
Cell In[3], line 11, in WrapperDataset.__getitem__(self, index)
9 image, label = data['image'], data['label']
10 if self.transform is not None:
---> 11 data = {'image': self.transform(image), 'label': label}
12 if self.target_transform is not None:
13 data = {'image': image, 'label': self.target_transform(label)}
File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\monai\transforms\compose.py:322, in Compose.__call__(self, input_, start, end, threading, lazy)
321 def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None):
--> 322 result = execute_compose(
323 input_,
324 transforms=self.transforms,
325 start=start,
326 end=end,
327 map_items=self.map_items,
328 unpack_items=self.unpack_items,
329 lazy=self.lazy,
330 overrides=self.overrides,
331 threading=threading,
332 log_stats=self.log_stats,
333 )
335 return result
File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\monai\transforms\compose.py:111, in execute_compose(data, transforms, map_items, unpack_items, start, end, lazy, overrides, threading, log_stats)
109 if threading:
110 _transform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
--> 111 data = apply_transform(
112 _transform, data, map_items, unpack_items, lazy=lazy, overrides=overrides, log_stats=log_stats
113 )
114 data = apply_pending_transforms(data, None, overrides, logger_name=log_stats)
115 return data
File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\monai\transforms\transform.py:171, in apply_transform(transform, data, map_items, unpack_items, log_stats, lazy, overrides)
169 else:
170 _log_stats(data=data)
--> 171 raise RuntimeError(f"applying transform {transform}") from e
RuntimeError: applying transform <monai.transforms.io.dictionary.LoadImaged object at 0x...>
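As far as I can tell, the same error can be triggered directly, without any DataLoader involved, by indexing the wrapped dataset:

# Raises the same "RuntimeError: applying transform ..." as above
sample = WrapperDataset(ds, transform=train_transforms)[0]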
My first guess was that the cause is my use of MONAI transforms. But they are supposed to work with PyTorch, so I don't think that's it...
I already tried using a concatenated dataset, where I apply the split and the transforms beforehand:
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset

# Split off 20% as a validation set
train_images, valid_images, train_labels, valid_labels = train_test_split(
    images, labels, test_size=0.2, shuffle=False)

# Combine images and labels into dictionaries
train_files = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(train_images, train_labels)]
valid_files = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(valid_images, valid_labels)]

# Create the Datasets, either cached or non-cached, each with its own transforms
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)
valid_ds = CacheDataset(data=valid_files, transform=valid_transforms, cache_rate=1.0)

dataset = ConcatDataset([train_ds, valid_ds])
return dataset
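For reference, ConcatDataset simply chains the two datasets end to end, so the (augmented) training items and the validation items end up in contiguous index ranges:

# dataset[0] .. dataset[len(train_ds) - 1]             -> items from train_ds (train_transforms)
# dataset[len(train_ds)] .. dataset[len(dataset) - 1]  -> items from valid_ds (valid_transforms)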
But it did not do the trick.