
I'm working on semantic image segmentation with U-Net-based models. The input images vary in size (between 300 and 600 pixels along each axis). My approach so far has been to rescale the images to a fixed standard size and work from there.

Now I want to try a sliding-window approach instead: extracting patches (e.g. 64x64) from the original images, without rescaling, and training a model on those. I'm not sure how to implement this efficiently.

For the training phase, I already have an online augmentation object (a Keras Sequence) that applies random transforms. Should I add the patch extraction step there? If I do, I'll be slicing NumPy arrays and yielding the patches, which doesn't sound very efficient. Is there a better way to do this?

And for the prediction phase: should I again extract patches from the images in NumPy and feed them to the model? If I use overlapping windows (e.g. 64x64 patches with a 32x32 stride), should I manually weight, average, and stitch the raw patch predictions in NumPy to produce a full-resolution segmentation? Or is there a better way to handle this?

I'm using TF 2.1 btw. Any help is appreciated.

bogota999

1 Answer


Although it might sound inefficient to break the images into smaller patches before training, doing so has one big benefit. The patches from all images end up in a single pool, and that pool is shuffled before training (Keras model.fit does this by default for NumPy arrays), so each batch mixes patches from many different images, which leads to a less biased model. If instead you feed the model one image at a time and cut it into patches on the fly, every batch still consists of patches from a single image.
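A minimal sketch of that idea, assuming each image is a NumPy array of shape (H, W, C) with a matching (H, W) label mask; the helper names extract_patches and build_patch_pool are hypothetical. Non-overlapping 64x64 patches are cut from every image, concatenated into one pool, and shuffled with a single shared permutation so each patch stays aligned with its mask:

```python
import numpy as np
from skimage.util import view_as_windows


def extract_patches(image, mask, patch=64, stride=64):
    """Cut an (H, W, C) image and its (H, W) mask into aligned patches.

    Returns arrays of shape (N, patch, patch, C) and (N, patch, patch).
    Edge pixels that do not fill a whole window are dropped.
    """
    channels = image.shape[-1]
    # Slide over rows and columns only; the channel axis is kept whole.
    img_win = view_as_windows(image, (patch, patch, channels),
                              step=(stride, stride, channels))
    mask_win = view_as_windows(mask, (patch, patch), step=stride)
    # img_win is (rows, cols, 1, patch, patch, C): flatten the window grid.
    return (img_win.reshape(-1, patch, patch, channels),
            mask_win.reshape(-1, patch, patch))


def build_patch_pool(images, masks, patch=64, stride=64, seed=0):
    """Pool patches from many differently sized images and shuffle them."""
    pairs = [extract_patches(img, msk, patch, stride)
             for img, msk in zip(images, masks)]
    x_all = np.concatenate([x for x, _ in pairs])
    y_all = np.concatenate([y for _, y in pairs])
    # One shared permutation mixes patches across images while keeping
    # every patch paired with its own mask.
    order = np.random.default_rng(seed).permutation(len(x_all))
    return x_all[order], y_all[order]


# Example with random data of varying sizes.
rng = np.random.default_rng(42)
sizes = [(300, 420), (512, 600), (350, 350)]
images = [rng.random((h, w, 3), dtype=np.float32) for h, w in sizes]
masks = [rng.integers(0, 2, size=(h, w), dtype=np.uint8) for h, w in sizes]
x, y = build_patch_pool(images, masks)
print(x.shape, y.shape)  # (121, 64, 64, 3) (121, 64, 64)
```

The pooled arrays can then be passed straight to model.fit, or indexed batch by batch inside the existing Sequence. Note that this sketch drops the right and bottom strips of each image that do not fill a whole patch.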

To break an image into small patches efficiently, you can use:

skimage.util.view_as_windows(arr_in, window_shape, step=1)

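You can set the window shape and the step (stride) of the sliding window. The result is a strided view of the input array, so no pixel data is copied when it is created. For example: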

>>> import numpy as np
>>> from skimage.util.shape import view_as_windows
>>> A = np.arange(4*4).reshape(4,4)
>>> A
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])
>>> window_shape = (2, 2)
>>> B = view_as_windows(A, window_shape)
>>> B[0, 0]
array([[0, 1],
       [4, 5]])
>>> B[0, 1]
array([[1, 2],
       [5, 6]])
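For the prediction side of the question, here is a minimal sketch of overlapping-window inference with plain averaging. It assumes a single (H, W, C) image and a hypothetical predict_fn that stands in for model.predict, mapping (N, 64, 64, C) patches to (N, 64, 64, K) per-pixel scores. The image is edge-padded so the windows tile it exactly, all windows are predicted in one batch, and each pixel's scores are averaged over every window that covers it:

```python
import numpy as np
from skimage.util import view_as_windows


def predict_full(image, predict_fn, patch=64, stride=32):
    """Segment an (H, W, C) image with overlapping windows, averaging overlaps.

    predict_fn (hypothetical stand-in for model.predict) must map a batch of
    shape (N, patch, patch, C) to scores of shape (N, patch, patch, K).
    Returns an (H, W, K) array of averaged scores.
    """
    h, w, c = image.shape

    def padded(n):
        # Smallest size >= n that windows of `patch` at `stride` tile exactly.
        if n <= patch:
            return patch
        return patch + -(-(n - patch) // stride) * stride

    big_h, big_w = padded(h), padded(w)
    padded_img = np.pad(image, ((0, big_h - h), (0, big_w - w), (0, 0)),
                        mode="edge")

    windows = view_as_windows(padded_img, (patch, patch, c),
                              step=(stride, stride, c))
    n_rows, n_cols = windows.shape[:2]
    scores = np.asarray(predict_fn(windows.reshape(-1, patch, patch, c)))
    scores = scores.reshape(n_rows, n_cols, patch, patch, -1)

    total = np.zeros((big_h, big_w, scores.shape[-1]))
    count = np.zeros((big_h, big_w, 1))
    for i in range(n_rows):
        for j in range(n_cols):
            r, s = i * stride, j * stride
            total[r:r + patch, s:s + patch] += scores[i, j]
            count[r:r + patch, s:s + patch] += 1
    # Every padded pixel is covered by at least one window, so count > 0.
    return (total / count)[:h, :w]


# Example: a dummy "model" that returns the first channel as its score.
img = np.random.default_rng(0).random((300, 420, 3))
out = predict_full(img, lambda batch: batch[..., :1])
print(out.shape)  # (300, 420, 1)
```

With a trained Keras model this could be called as predict_full(img, model.predict), followed by an argmax over the last axis to get class labels. Replacing the count of 1 per window with a weight that decays towards the patch borders would turn the plain average into a weighted blend, which is one common way to reduce seams; that variant is not shown here.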
aminrd