
I am trying to multiply each image by its saliency map and make a new dataset.

trainloader = utilsxai.load_data_cifar10(batch_size=1,test=False)
testloader =  utilsxai.load_data_cifar10(batch_size=128, test=True)

This `load_data_cifar10` wraps the torchvision CIFAR10 dataset. Then I replace the image data:

data = trainloader.dataset.data 

trainloader.dataset.data = (data * sal_maps_hf).reshape(data.shape)

`sal_maps_hf` has shape (50000, 32, 32, 3), and `trainloader.dataset.data` has the same shape, (50000, 32, 32, 3).
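Matching shapes are only half of what a replacement array needs; the element type matters too. A minimal sketch, using small random stand-ins (the helper name `replacement_problems` and the float32 saliency maps are assumptions for illustration), that compares the new array with the original on both shape and dtype before assigning it:

```python
import numpy as np
from typing import List


def replacement_problems(old: np.ndarray, new: np.ndarray) -> List[str]:
    """Hypothetical helper: list differences that could break a dataset
    expecting `new` to look exactly like `old`."""
    problems = []
    if new.shape != old.shape:
        problems.append(f"shape {new.shape} != {old.shape}")
    if new.dtype != old.dtype:
        problems.append(f"dtype {new.dtype} != {old.dtype}")
    return problems


# Small stand-ins: CIFAR-10 style uint8 images and assumed float32 saliency maps.
data = np.random.randint(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
sal_maps = np.random.rand(4, 32, 32, 3).astype(np.float32)

new_data = (data * sal_maps).reshape(data.shape)
print(replacement_problems(data, new_data))  # the shape matches, the dtype does not
```

Because both inputs already share a shape, the element-wise product has that shape too, so the `.reshape(data.shape)` is a no-op; only the dtype changes.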

But when I iterate over the loader:

for idx,(img,target) in enumerate(trainloader):

I get this error:
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/venv/lib/python3.7/site-packages/PIL/Image.py in fromarray(obj, mode)
   2644             typekey = (1, 1) + shape[2:], arr["typestr"]
-> 2645             mode, rawmode = _fromarray_typemap[typekey]
   2646         except KeyError:

KeyError: ((1, 1, 3), '<f4')

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-142-9410d0967245> in <module>
----> 1 show_images(trainloader)

<ipython-input-117-a32f5bd33032> in show_images(trainloader)
      1 def show_images(trainloader):
----> 2     for idx,(img,target) in enumerate(trainloader):
      3         img = img.squeeze()
      4         #pritn(img)
      5         img = torch.tensor(img)

~/venv/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    344     def __next__(self):
    345         index = self._next_index()  # may raise StopIteration
--> 346         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    347         if self._pin_memory:
    348             data = _utils.pin_memory.pin_memory(data)

~/venv/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~/venv/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~/venv/lib/python3.7/site-packages/torchvision/datasets/cifar.py in __getitem__(self, index)
    120         # doing this so that it is consistent with all other datasets
    121         # to return a PIL Image
--> 122         img = Image.fromarray(img)
    123 
    124         if self.transform is not None:

~/venv/lib/python3.7/site-packages/PIL/Image.py in fromarray(obj, mode)
   2645             mode, rawmode = _fromarray_typemap[typekey]
   2646         except KeyError:
-> 2647             raise TypeError("Cannot handle this data type")
   2648     else:
   2649         rawmode = mode

TypeError: Cannot handle this data type
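The key in the `KeyError`, `((1, 1, 3), '<f4')`, is built from the array itself, exactly as the `typekey = (1, 1) + shape[2:], arr["typestr"]` line in the traceback shows: a 3-channel image whose elements are 4-byte little-endian floats (float32). PIL has an RGB entry for 3-channel `'|u1'` (uint8) data but none for 3-channel float32. A minimal numpy-only sketch that recreates that key (the helper name `pil_style_typekey` is hypothetical, not part of PIL):

```python
import numpy as np


def pil_style_typekey(arr: np.ndarray):
    """Hypothetical helper: rebuild the lookup key shown in the traceback."""
    interface = arr.__array_interface__
    return (1, 1) + interface["shape"][2:], interface["typestr"]


img_uint8 = np.zeros((32, 32, 3), dtype=np.uint8)
img_float = img_uint8.astype(np.float32)

print(pil_style_typekey(img_uint8))  # ((1, 1, 3), '|u1'): an RGB image for PIL
print(pil_style_typekey(img_float))  # ((1, 1, 3), '<f4') on little-endian machines
```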


    trainloader.dataset.__getitem__

<bound method CIFAR10.__getitem__ of Dataset CIFAR10
    Number of datapoints: 50000
    Root location: /mnt/3CE35B99003D727B/input/pytorch/data
    Split: Train
    StandardTransform Transform: Compose(
               Resize(size=32, interpolation=PIL.Image.BILINEAR)
               ToTensor()
           )
edited by desertnaut
  • Are you sure your `dataloader` stores the data as an `nd.array` or `torch.tensor`? It seems like your data is stored as `PIL.Image`s. – Shai Dec 30 '19 at 11:57
  • `data = trainloader.dataset.data` is a `numpy.ndarray`. –  Dec 30 '19 at 12:35
  • I think the type does not matter; there should be some way to just assign the new dataset. –  Dec 30 '19 at 12:37
  • Yet the error you get comes from `PIL.Image`. Are you sure `dataset.__getitem__` actually uses `dataset.data`? Is it possible there is an additional representation of the data? You'll have to look at the code of the dataset. – Shai Dec 30 '19 at 12:37
  • `trainloader.dataset.__getitem__`: –  Dec 30 '19 at 12:39
  • what is the `dtype` of `dataset.data` **before** the change? and after? what is the `dtype` of `sal_maps_hf`? – Shai Dec 30 '19 at 12:42
  • `type(trainloader.dataset.data)` is `numpy.ndarray`; `type(trainloader.dataset)` is `torchvision.datasets.cifar.CIFAR10`. –  Dec 30 '19 at 12:44
  • I checked both types, and they are the same. –  Dec 30 '19 at 12:45
  • you checked the `type` not the `dtype`: `dataset.data.dtype` and `sal_maps_hf.dtype` – Shai Dec 30 '19 at 12:46
  • BTW, why don't you [format](https://stackoverflow.com/help/formatting) code in your comments? – Shai Dec 30 '19 at 12:47
  • `trainloader = utilsxai.load_data_cifar10(batch_size=1,test=False)`; `mask = np.random.rand(50000,32,32,3)`; `trainloader.dataset.data = mask`; `for idx,(img,target) in enumerate(trainloader): img = img.squeeze()`. This gives the same error; try it. –  Dec 30 '19 at 12:51
  • I'm not sure how to put line breaks in a comment. –  Dec 30 '19 at 12:52
  • I'm not sure it is possible in comments, but you can use "`" to indicate code. – Shai Dec 30 '19 at 12:52
  • `trainloader = utilsxai.load_data_cifar10(batch_size=1,test=False)`; `testloader = utilsxai.load_data_cifar10(batch_size=128, test=True)`; `mask = np.random.rand(50000,32,32,3)`; then either `trainloader.dataset.data = mask * trainloader.dataset.data` or `trainloader.dataset.data = mask`; then `for idx,(img,target) in enumerate(trainloader): ...` shows the same error. –  Dec 30 '19 at 12:56
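The distinction Shai draws between `type` and `dtype` is easy to see on a scaled-down version of the commenter's experiment (small random arrays stand in for the real CIFAR-10 data; `np.random.rand` always returns float64):

```python
import numpy as np

data = np.random.randint(0, 256, size=(2, 32, 32, 3), dtype=np.uint8)
mask = np.random.rand(2, 32, 32, 3)  # np.random.rand returns float64

# type() reports the container, which is identical for both arrays...
print(type(data), type(mask))  # numpy.ndarray for both
# ...while .dtype reports the element type, which is what PIL checks.
print(data.dtype, mask.dtype)  # uint8 float64
print((mask * data).dtype)     # float64, so PIL would see '<f8' here
```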

1 Answer

Your `sal_maps_hf` is not of dtype `np.uint8`.

Based on the partial information in the question and in the comments, I guess that your mask has a floating-point dtype (the `'<f4'` in the error points to `np.float32`, or similar). Multiplying `data * sal_maps_hf` therefore casts your data to a dtype other than `np.uint8`, which later makes `PIL.Image` throw an exception.
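NumPy's type promotion can be checked directly with `np.result_type`: combining uint8 with any float type yields that float type, never uint8. The exact dtype of `sal_maps_hf` is not shown in full, so the float32 stand-in below is an assumption based on the `'<f4'` in the traceback:

```python
import numpy as np

# Result dtype when uint8 image data meets each float type.
promoted = {
    np.dtype(t).name: np.result_type(np.uint8, t)
    for t in (np.float16, np.float32, np.float64)
}
for name, result in promoted.items():
    print(f"uint8 * {name} -> {result}")

# The same promotion happens in an actual element-wise product.
data = np.zeros((32, 32, 3), dtype=np.uint8)
sal = np.ones((32, 32, 3), dtype=np.float32)  # assumed float32 saliency map
print((data * sal).dtype)  # float32
```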

Try:

trainloader.dataset.data = (data * sal_maps_hf).reshape(data.shape).astype(np.uint8)
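A plain `.astype(np.uint8)` truncates fractional values toward zero, and it does not clip values above 255 if a saliency map is not guaranteed to lie in [0, 1]. A minimal sketch of a more careful conversion, rounding and clipping before the cast (the helper name `to_uint8_image` and the sample values are hypothetical):

```python
import numpy as np


def to_uint8_image(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: round to nearest, clip into [0, 255], then cast."""
    return np.clip(np.rint(x), 0, 255).astype(np.uint8)


data = np.array([9, 200, 255], dtype=np.uint8)
sal = np.array([0.75, 0.5, 1.5], dtype=np.float32)  # 1.5 is deliberately > 1

product = data * sal                 # float32: [6.75, 100.0, 382.5]
print(product[:2].astype(np.uint8))  # plain cast truncates 6.75 to 6
print(to_uint8_image(product))       # rounded and clipped: 7, 100, 255
```

Applied to the question, this would read `trainloader.dataset.data = to_uint8_image(data * sal_maps_hf)`.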
Shai
    OMG, you solved it! How did you check the dtype of `sal_maps_hf`? `type(sal_maps_hf)` only shows that it's a `numpy.ndarray`. –  Dec 30 '19 at 12:59
  • @jakeMonk it's **not** the [`type`](https://www.geeksforgeeks.org/python-type-function/) but rather the [`dtype`](https://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html). `dtype` is the "type" of the elements of the `ndarray` (or the `tensor`). – Shai Dec 30 '19 at 13:01
  • I see, sal_maps_hf.dtype = dtype(' –  Dec 30 '19 at 13:04
  • @jakeMonk as expected. `uint8` is a very restrictive data type. Make sure the data after the casting is not corrupted beyond repair for your needs. – Shai Dec 30 '19 at 13:05
  • Thank you, I did not even know ' –  Dec 30 '19 at 13:07
  • I see, I should learn more about data types. Thank you, Shai. –  Dec 31 '19 at 03:23
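Shai's warning that `uint8` is restrictive is worth taking literally: NumPy does not saturate when values leave the 0 to 255 range. A minimal illustration with made-up values of two ways image data can be silently corrupted, plus a range check that catches them before casting:

```python
import numpy as np

# Casting between integer types wraps modulo 256 instead of saturating.
wide = np.array([250, 300, -1], dtype=np.int16)
print(wide.astype(np.uint8))  # 250, 44, 255

# Arithmetic between uint8 arrays wraps the same way.
a = np.array([250], dtype=np.uint8)
b = np.array([10], dtype=np.uint8)
print(a + b)  # 4, not 260

# Checking the value range first catches both cases.
print(wide.min() >= 0 and wide.max() <= 255)  # False
```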