
I am trying to feed two images into a network, and I want to apply an identical transform to both. transforms.Compose() takes one image at a time and produces outputs that are independent of each other, but I want the same transformation applied to both images. I already wrote my own code for hflip(); now I am interested in random crop. Is there any way to do that without writing custom functions?

Joy Mazumder

2 Answers


I would use a workaround like this: make my own crop class inherited from RandomCrop, redefining __call__ (forward in recent torchvision versions) with

    …
        if self.call_is_even:
            self.ijhw = self.get_params(img, self.size)
        i, j, h, w = self.ijhw
        self.call_is_even = not self.call_is_even

instead of

    i, j, h, w = self.get_params(img, self.size)

The idea is to suppress the randomizer on odd calls.
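
For completeness, here is a minimal runnable sketch of that idea (the class name PairedRandomCrop is my own; in recent torchvision versions transforms are nn.Modules, so the method to override is forward rather than __call__). It skips the parent class's padding logic, so it assumes the inputs are at least as large as the crop size:

    import torch
    import torchvision.transforms as T
    import torchvision.transforms.functional as F

    class PairedRandomCrop(T.RandomCrop):
        """Draws new crop parameters only on even calls, so each pair of
        consecutive calls shares the same crop window."""
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.call_is_even = True
            self.ijhw = None

        def forward(self, img):
            if self.call_is_even:
                # even call: draw fresh crop parameters
                self.ijhw = self.get_params(img, self.size)
            i, j, h, w = self.ijhw
            self.call_is_even = not self.call_is_even
            return F.crop(img, i, j, h, w)

    crop = PairedRandomCrop((16, 16))
    img1, img2 = torch.rand(3, 32, 32), torch.rand(3, 32, 32)
    out1 = crop(img1)  # draws new crop parameters
    out2 = crop(img2)  # reuses the same window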

Alexey Birukov

ptrblck recommends using the functional PyTorch API to build your own transform that does what you want [1]; however, I think in most cases there is a cleaner way:

If the image is torch Tensor, it is expected to have […, H, W] shape, where … means an arbitrary number of leading dimensions
-- torchvision.RandomCrop

You can stack your images along the channel dimension. (Given the "…" in the quote, you can probably even stack them along a new leading dimension.) That way, you can apply the transform once, to all images at the same time.
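
A minimal sketch of that with random tensors (all names are my own; torchvision's tensor transforms treat everything before H and W as batch-like dimensions):

    import torch
    from torchvision.transforms import RandomCrop

    img1, img2 = torch.rand(3, 32, 32), torch.rand(3, 32, 32)
    crop = RandomCrop((16, 16))

    # concatenate along the channel dimension -> shape (6, 32, 32);
    # one call then crops the same window out of both images
    both = crop(torch.cat([img1, img2], dim=0))
    out1, out2 = both[:3], both[3:]
    assert out1.shape == out2.shape == (3, 16, 16)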


Addendum for More Images

For two images, this works fine, and for flat lists of images you can do the same. But if you have nested lists of images, it gets annoying.
I wrote a function that operates on nested lists (of exactly depth 2) for that use case. Note that this approach only works on tensors, not PIL images, so it first converts the PIL images to tensors unless you tell it not to -- see the test function below:

    from typing import List, Optional

    import torch
    import torchvision
    from PIL import Image


    def apply_same_transform_to_all_images(
        transform: torch.nn.Module,
        images: List[List[Image.Image]],
        to_ten: Optional[torch.nn.Module] = None,
    ) -> List[List[torch.Tensor]]:
        """Applies the transform to all images in the nested list in a single run.
        Useful e.g. for a consistent RandomCrop.

        Note that this calls torchvision.transforms.ToTensor() on your images
        unless you set `to_ten`! ToTensor rescales values, and it also means the
        passed-in transform must assume tensor inputs, not PIL inputs.
        """
        if to_ten is None:
            to_ten = torchvision.transforms.ToTensor()

        # traverse the nested list to collect all images in a flat list
        all_images = []
        for inner_list in images:
            for image in inner_list:
                all_images.append(to_ten(image))

        # stack all images along a new leading dimension and transform them
        # together, so random parameters are drawn only once
        stacked = torch.stack(all_images, dim=0)
        transformed = transform(stacked)
        transformed_iter = iter(transformed)

        # undo the traversal in the same order
        output_nested_list = []
        for inner_list in images:
            output_nested_list.append([next(transformed_iter) for _ in inner_list])

        return output_nested_list


    def test_apply_same_transform_to_all_images():
        from torchvision.transforms import RandomCrop, Compose

        identity_transform = Compose([])  # the inputs are already tensors
        img1 = torch.arange(4 * 4 * 3).reshape((3, 4, 4))  # 4x4 image, 3 channels
        img2 = img1 * 2
        crop_transform = RandomCrop((2, 2))
        result = apply_same_transform_to_all_images(
            crop_transform, to_ten=identity_transform, images=[[img1], [img1, img2]]
        )
        # the same window was cropped from every image
        assert torch.allclose(result[0][0], result[1][0])
        assert torch.allclose(result[0][0] * 2, result[1][1])


    if __name__ == "__main__":
        # just for debugging
        test_apply_same_transform_to_all_images()

lucidbrot