0

I want to perform k-fold CV, but in my past approach the augmentations were leaking into the validation dataset. To avoid this, I am using the WrapperDataset class I found in this post: Augmenting only the training set in K-folds cross validation. However, I had to adjust it a bit because of how I prepare my data:

# Collect images and labels.
# glob.glob() returns paths in arbitrary, OS-dependent order, so both lists are
# sorted to make the positional pairing below deterministic.
# NOTE(review): pairing by sorted position assumes every image and its label
# sort to the same position in their respective folders (e.g. identical file
# names) — confirm against the data layout.
images = sorted(glob.glob(os.path.join(data_dir, 'images', '*.nii.gz')))
labels = sorted(glob.glob(os.path.join(data_dir, 'labels', '*.nii.gz')))

# Combine images and labels into one dictionary per sample
files = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(images, labels)]

# Fix the seed so that training is reproducible
set_determinism(seed=0)

# Training pipeline: loading, channel insertion, then the random augmentations
# (affine, flip, 90° rotation; each applied with probability 0.5) and finally
# the conversion to tensors.
# TODO: Check bone window using ScaleIntensityRanged()
train_transforms = Compose([
    LoadImaged(keys=['image', 'label']),
    # Add a channel as 1st element
    AddChanneld(keys=['image', 'label']),
    RandAffined(
        keys=['image', 'label'],
        mode=('bilinear', 'nearest'),
        prob=0.5,
        # +-10° in radians
        rotate_range=(0.174533, 0.174533),
        shear_range=(0.2, 0.5),
    ),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=(0, 1)),
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=1, spatial_axes=(0, 1)),
    ToTensord(keys=['image', 'label']),
])

# Validation pipeline: the same loading steps, without any random augmentation
valid_transforms = Compose([
    LoadImaged(keys=['image', 'label']),
    AddChanneld(keys=['image', 'label']),
    ToTensord(keys=['image', 'label']),
])

# Put all file dictionaries into a CacheDataset; no transform is passed here
ds = CacheDataset(
    data=files,
    cache_rate=1.0,
)

The updated WrapperDataset class now looks like this:

class WrapperDataset:
    """Apply transforms lazily on top of an existing map-style dataset.

    Each item of the wrapped dataset is expected to be a mapping with the
    keys 'image' and 'label'. Wrapping the same underlying dataset several
    times with different transforms gives each wrapper its own pipeline
    (e.g. one for training and one for validation).

    Args:
        dataset: Indexable object supporting ``len()`` whose items are
            mappings with 'image' and 'label' entries.
        transform: Optional callable applied to the 'image' entry only.
        target_transform: Optional callable applied to the 'label' entry only.
        sample_transform: Optional callable applied to the whole sample
            before the per-entry transforms. Use this for dictionary-style
            transforms that need the complete sample rather than one entry.
    """

    def __init__(self, dataset, transform=None, target_transform=None, sample_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform
        self.sample_transform = sample_transform

    def __getitem__(self, index):
        """Return the sample at ``index`` with all configured transforms applied.

        Without ``transform`` and ``target_transform`` the sample is returned
        as is (after ``sample_transform``, if one is set). Otherwise a new
        dict with exactly the keys 'image' and 'label' is returned.
        """
        data = self.dataset[index]
        if self.sample_transform is not None:
            data = self.sample_transform(data)
        if self.transform is None and self.target_transform is None:
            return data

        image, label = data['image'], data['label']
        # Transform each entry independently so that setting both callables
        # keeps both results, instead of the label branch discarding the
        # already transformed image.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return {'image': image, 'label': label}

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

However, when I call my DataLoaders like this:

# Both loaders wrap the same underlying dataset; only the transform pipeline
# and the sampler differ between training and validation.
# NOTE(review): WrapperDataset.__getitem__ calls `transform` with
# data['image'] only, while train_transforms / valid_transforms are built from
# dictionary transforms (keys=['image', 'label']). This looks like the cause of
# the ValueError raised inside LoadImaged — confirm.
train_loader = torch.utils.data.DataLoader(
    dataset=WrapperDataset(dataset, transform=train_transforms),
    sampler=train_subsampler,
    batch_size=4,
)
valid_loader = torch.utils.data.DataLoader(
    dataset=WrapperDataset(dataset, transform=valid_transforms),
    sampler=valid_subsampler,
    batch_size=1,
)

I get a lot of errors :D

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\monai\transforms\transform.py:141, in apply_transform(transform, data, map_items, unpack_items, log_stats, lazy, overrides)
    140         return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
--> 141     return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
    142 except Exception as e:
    143     # if in debug mode, don't swallow exception so that the breakpoint
    144     # appears where the exception was raised.

File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\monai\transforms\transform.py:98, in _apply_transform(transform, data, unpack_parameters, lazy, overrides, logger_name)
     96     return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data)
---> 98 return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)

File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\monai\transforms\io\dictionary.py:162, in LoadImaged.__call__(self, data, reader)
    157 """
    158 Raises:
    159     KeyError: When not ``self.overwriting`` and key already exists in ``data``.
    160 
    161 """
--> 162 d = dict(data)
    163 for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):

ValueError: dictionary update sequence element #0 has length 1; 2 is required

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[10], line 57
     54 labelsTr = []
     55 outputsTr = []
---> 57 for batch_data in train_loader:
     59     train_step += 1
     61     print(f"Train Batch: {train_step}/{len(train_loader)}")

File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\torch\utils\data\dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
    630 if self._sampler_iter is None:
    631     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    632     self._reset()  # type: ignore[call-arg]
--> 633 data = self._next_data()
    634 self._num_yielded += 1
    635 if self._dataset_kind == _DatasetKind.Iterable and \
    636         self._IterableDataset_len_called is not None and \
    637         self._num_yielded > self._IterableDataset_len_called:

File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\torch\utils\data\dataloader.py:677, in _SingleProcessDataLoaderIter._next_data(self)
    675 def _next_data(self):
    676     index = self._next_index()  # may raise StopIteration
--> 677     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    678     if self._pin_memory:
    679         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\torch\utils\data\_utils\fetch.py:51, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\torch\utils\data\_utils\fetch.py:51, in (.0)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

Cell In[3], line 11, in WrapperDataset.__getitem__(self, index)
      9 image, label = data['image'], data['label']
     10 if self.transform is not None:
---> 11     data = {'image': self.transform(image), 'label': label}
     12 if self.target_transform is not None:
     13     data = {'image': image, 'label': self.target_transform(label)}

File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\monai\transforms\compose.py:322, in Compose.__call__(self, input_, start, end, threading, lazy)
    321 def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None):
--> 322     result = execute_compose(
    323         input_,
    324         transforms=self.transforms,
    325         start=start,
    326         end=end,
    327         map_items=self.map_items,
    328         unpack_items=self.unpack_items,
    329         lazy=self.lazy,
    330         overrides=self.overrides,
    331         threading=threading,
    332         log_stats=self.log_stats,
    333     )
    335     return result

File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\monai\transforms\compose.py:111, in execute_compose(data, transforms, map_items, unpack_items, start, end, lazy, overrides, threading, log_stats)
    109     if threading:
    110         _transform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
--> 111     data = apply_transform(
    112         _transform, data, map_items, unpack_items, lazy=lazy, overrides=overrides, log_stats=log_stats
    113     )
    114 data = apply_pending_transforms(data, None, overrides, logger_name=log_stats)
    115 return data

File c:\Users\BKeoh\AppData\Local\anaconda3\envs\hsl\Lib\site-packages\monai\transforms\transform.py:171, in apply_transform(transform, data, map_items, unpack_items, log_stats, lazy, overrides)
    169     else:
    170         _log_stats(data=data)
--> 171 raise RuntimeError(f"applying transform {transform}") from e

RuntimeError: applying transform 

My first guess at the cause is that I am using the transforms from MONAI. But they should be wrapped with PyTorch, so I guess that is not it...

I already tried to use a concatenated dataset, where I apply the split and transformations beforehand:

# Split into {test_size} % validation set
# The call is wrapped in parentheses so that the assignment and its
# arguments form one valid statement.
# NOTE(review): the transforms are attached to this one fixed, unshuffled
# 80/20 split. If k-fold indices are later drawn over the concatenated
# dataset, each sample presumably keeps the pipeline it was given here
# regardless of the fold it ends up in — confirm against the fold sampler.
train_images, valid_images, train_labels, valid_labels = train_test_split(
    images,
    labels,
    test_size=0.2,
    shuffle=False,
)

# Combine images and labels into dictionaries
train_files = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(train_images, train_labels)]
valid_files = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(valid_images, valid_labels)]

# Create one cached dataset per split, each with its own transforms, and
# concatenate them (training samples first, validation samples after).
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)
valid_ds = CacheDataset(data=valid_files, transform=valid_transforms, cache_rate=1.0)
dataset = ConcatDataset([train_ds, valid_ds])
    
return dataset

But it did not do the trick.

b_k_
  • 1
  • 1

0 Answers0