I'd like to randomly rotate an image tensor (B, C, H, W) around its center (2D rotation, I think?). I would like to avoid using NumPy and Kornia, so that I basically only need to import from the torch module. I'm also not using torchvision.transforms, because I need it to be autograd compatible. Essentially, I'm trying to create an autograd-compatible version of torchvision.transforms.RandomRotation() for visualization techniques like DeepDream (so I need to avoid artifacts as much as possible).

import torch
import math
import random
import torchvision.transforms as transforms
from PIL import Image


# Load image
def preprocess_simple(image_name, image_size):
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    image = Image.open(image_name).convert('RGB')
    return Loader(image).unsqueeze(0)
    
# Save image   
def deprocess_simple(output_tensor, output_name):
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.squeeze(0))
    image.save(output_name)


# Somehow rotate tensor around its center
def rotate_tensor(tensor, radians):
    ...
    return rotated_tensor

# Get a random integer angle (in degrees) within a specified range;
# random.randint includes both endpoints
r_degrees = 5
angle = random.randint(-r_degrees, r_degrees)

# Convert angle from degrees to radians
ang_rad = angle * math.pi / 180


# test_tensor = preprocess_simple('path/to/file', (512,512))
test_tensor = torch.randn(1,3,512,512)


# Rotate input tensor somehow
output_tensor = rotate_tensor(test_tensor, ang_rad)


# Optionally use this to check rotated image
# deprocess_simple(output_tensor, 'rotated_image.jpg')

Some example outputs of what I'm trying to accomplish:

First example of rotated image Second example of rotated image

ProGamerGov
  • I suggest you take a look at the Spatial Transformer, especially the grid generator & the sampler modules. See https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html for implementation details. – Gil Pinsky Oct 04 '20 at 19:12
  • @GilPinsky That looks like it has to do with creating trainable layers. I'm going to be only optimizing a single image/tensor, and using rotations to help that optimization. – ProGamerGov Oct 04 '20 at 19:37
  • I would like to point out that kornia depends solely on pytorch, so it would not bring any additional dependency overhead – old-ufo Jun 10 '21 at 10:16

3 Answers

The grid generator and the sampler are sub-modules of the Spatial Transformer (Jaderberg et al.). These sub-modules are not trainable; they let you apply a learnable, as well as a non-learnable, spatial transformation. Here I take these two sub-modules and use them to rotate an image by theta using PyTorch's functions torch.nn.functional.affine_grid and torch.nn.functional.grid_sample (these functions are implementations of the generator and the sampler, respectively):

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

def get_rot_mat(theta):
    theta = torch.tensor(theta)
    return torch.tensor([[torch.cos(theta), -torch.sin(theta), 0],
                         [torch.sin(theta), torch.cos(theta), 0]])


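def rot_img(x, theta, dtype):
    # Repeat the 2x3 rotation matrix along the batch dimension to match x
    rot_mat = get_rot_mat(theta)[None, ...].type(dtype).repeat(x.shape[0],1,1)
    # align_corners=False is the default; passing it explicitly avoids the warning
    grid = F.affine_grid(rot_mat, x.size(), align_corners=False).type(dtype)
    x = F.grid_sample(x, grid, align_corners=False)
    return x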


#Test:
dtype =  torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
#im should be a 4D tensor of shape B x C x H x W with type dtype, range [0,255]:
plt.imshow(im.squeeze(0).permute(1,2,0)/255) #To plot it im should be 1 x C x H x W
plt.figure()
#Rotation by np.pi/2 with autograd support:
rotated_im = rot_img(im, np.pi/2, dtype) # Rotate image by 90 degrees.
plt.imshow(rotated_im.squeeze(0).permute(1,2,0)/255)

In the example above, assume we take our image, im, to be a dancing cat in a skirt: [image]

rotated_im will be a 90-degrees CCW rotated dancing cat in a skirt:

[image]

And this is what we get if we call rot_img with theta equal to np.pi/4: [image]

And the best part is that it's differentiable w.r.t. the input and has autograd support! Hooray!
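To tie this back to the question, here is a minimal sketch of a random-rotation helper built on the same two functions. The name random_rotate and the max_degrees parameter are hypothetical (they do not come from the answer), and the sketch assumes one shared angle for the whole batch. It ends with a quick check that gradients reach the input tensor.

```python
import math
import random

import torch
import torch.nn.functional as F


def random_rotate(x, max_degrees=5.0):
    """Rotate a (B, C, H, W) tensor by one random angle in [-max_degrees, max_degrees]."""
    # Hypothetical helper: draw a single angle and convert it to radians.
    theta = math.radians(random.uniform(-max_degrees, max_degrees))
    cos, sin = math.cos(theta), math.sin(theta)
    # 2x3 affine matrix with zero translation, so the rotation is about the center.
    rot = torch.tensor([[cos, -sin, 0.0],
                        [sin, cos, 0.0]], dtype=x.dtype, device=x.device)
    # Same matrix for every item in the batch.
    rot = rot.unsqueeze(0).repeat(x.shape[0], 1, 1)
    grid = F.affine_grid(rot, x.size(), align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)


# Check that the output keeps its shape and that gradients flow to the input.
image = torch.randn(2, 3, 64, 64, requires_grad=True)
rotated = random_rotate(image, max_degrees=5.0)
rotated.sum().backward()
print(rotated.shape, image.grad.shape)
```

Because affine_grid works in normalized [-1, 1] coordinates, a pure rotation matrix like this one would likely distort non-square images; the sketch assumes H equals W, as in the question's 512 x 512 example.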

Gil Pinsky
  • Thanks, your code works amazingly well! Though how should I be handling multiple images stacked across the batch dimension? Should I do each individually with the same rotation matrix, or can it be modified slightly to work without a for statement? – ProGamerGov Oct 04 '20 at 23:52
  • My pleasure :), I've just updated my code so that it also works across the batch dimension. All you need to do is use .repeat(x.shape[0],1,1) in rot_img to repeat the rotation matrix such that it has the same batch dimension as x. – Gil Pinsky Oct 05 '20 at 05:43
  • Sir, would you please explain in the answer how to rotate a single grayscale image? – ToughMind Jan 16 '21 at 13:06
  • @ToughMind This should work with a single greyscale image as well. Just make sure the image (im) has the appropriate dimensions: 1 x 1 x H x W where H & W are height and width respectively. – Gil Pinsky Jan 17 '21 at 21:34
  • Hi, quick question, if theta was a weight, this could be differentiable w.r.t theta? – Roy Dec 10 '21 at 19:56
  • @iyop45 Unfortunately not (though I could think of a custom differentiable implementation) – Gil Pinsky Dec 11 '21 at 22:45
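On the last question: rot_img as written builds its matrix with torch.tensor(...), which copies values and drops any autograd history, so no gradient reaches theta. One possible sketch of a variant that keeps theta in the graph is shown below. The function name rot_img_learnable is hypothetical, and the sketch assumes theta is a 0-dim floating-point tensor with the same dtype as x.

```python
import torch
import torch.nn.functional as F


def rot_img_learnable(x, theta):
    """Rotate x (B, C, H, W) by a 0-dim tensor theta, keeping theta in the graph."""
    cos, sin = torch.cos(theta), torch.sin(theta)
    zero = torch.zeros_like(theta)
    # torch.stack preserves autograd history, unlike torch.tensor([...]).
    rot = torch.stack([torch.stack([cos, -sin, zero]),
                       torch.stack([sin, cos, zero])])
    # Same matrix for every item in the batch.
    rot = rot.unsqueeze(0).repeat(x.shape[0], 1, 1)
    grid = F.affine_grid(rot, x.size(), align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)


theta = torch.tensor(0.3, requires_grad=True)
x = torch.rand(1, 3, 32, 32)
rot_img_learnable(x, theta).sum().backward()
print(theta.grad)  # a scalar gradient with respect to the angle
```

The gradient comes from bilinear interpolation, so it only reflects local pixel differences; whether that signal is useful for optimizing an angle depends on the image and the loss.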

With torchvision it should be simple:

import torch
import torchvision.transforms.functional as TF

angle = 30  # degrees
x = torch.randn(1,3,512,512)

out = TF.rotate(x, angle)

For example if x is:

Kite

out with a 30 degree rotation is (NOTE: counterclockwise):

Kite rotated

cacti5

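There is a PyTorch function for that, though it only rotates in multiples of 90 degrees:

import torch

x = torch.tensor([[0, 1],
                  [2, 3]])

# Rotate once by 90 degrees counterclockwise in the plane of dims 0 and 1
x = torch.rot90(x, 1, [0, 1])
# x is now:
# tensor([[1, 3],
#         [0, 2]])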

Here are the docs: https://pytorch.org/docs/stable/generated/torch.rot90.html
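For the (B, C, H, W) layout from the question, the rotation plane would be the last two dimensions. The sketch below shows that on a tiny batch; the tensor names are illustrative only. It also checks that gradients pass through, since rot90 only rearranges elements.

```python
import torch

# Tiny batch: 2 images, 3 channels, 2 x 2 pixels.
x = torch.arange(2 * 3 * 2 * 2, dtype=torch.float32).reshape(2, 3, 2, 2)
x.requires_grad_(True)

# Rotate every image once by 90 degrees in the H-W plane (dims 2 and 3).
rotated = torch.rot90(x, k=1, dims=[2, 3])
print(rotated[0, 0])

# For non-square images, H and W swap places.
wide = torch.zeros(1, 3, 4, 6)
wide_rotated = torch.rot90(wide, k=1, dims=[2, 3])
print(wide_rotated.shape)

# rot90 is a pure rearrangement, so the gradient of a sum is all ones.
rotated.sum().backward()
print(x.grad.shape)
```

This covers only quarter-turn rotations, so it would not replace the small random angles the question asks for.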

Theodor Peifer