I'm trying to implement the shift property of the Fourier transform in PyTorch. By the shift property I mean this:

[Image: shift property of the Fourier transform]
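
For reference, the discrete form of this property can be written as below. The notation is chosen here for illustration and is not taken from the original image: $f$ is an $M \times N$ image, $F$ its 2-D DFT, and $(a, b)$ the circular shift in pixels along rows and columns.

```latex
g[m, n] = f[(m - a) \bmod M,\ (n - b) \bmod N]
\quad\Longleftrightarrow\quad
G[k, l] = F[k, l]\, e^{-2\pi i \left( \frac{k a}{M} + \frac{l b}{N} \right)},
\qquad k = 0, \dots, M-1,\quad l = 0, \dots, N-1
```

In other words, a shift of $a$ pixels adds a phase of $-2\pi k a / M$ to frequency index $k$, with $k$ counted from 0 in the order the FFT returns it.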

I think I have most of it right, but the result is a noisy image, and I'm having a hard time tracking down the cause. Could it be a numerical issue? Or something to do with odd versus even pixel counts? (My images are 1020 x 678 x 3.)

These are the shifted image and the original image.

This is my implementation code:

import cv2
import numpy as np
import torch
from math import pi

def phase_shifters(y_alpha=0, x_alpha=0, shape=None):
    # HxWxC
    line = torch.zeros(shape)
    # x shift
    line_x = torch.linspace(-shape[1]/2,shape[1]/2,shape[1])
    line_x = line_x.expand(shape[0], shape[2], shape[1]).transpose(1, 2)
    line_x = line_x/shape[1]
    line_x = x_alpha * line_x


    # y shift
    line_y = torch.linspace(-shape[0]/2,shape[0]/2,shape[0])

    line_y = line_y.expand(shape[2], shape[1], shape[0]).transpose(0, 2)
    line_y = line_y/shape[0]
    line_y = y_alpha * line_y


    return x_alpha*line_x + y_alpha*line_y


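# cv2.imread returns a NumPy array; torch.fft.fft2 needs a tensor (conversion assumed here)
img = torch.from_numpy(cv2.imread("test.png")).float()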
img_fft = torch.fft.fft2(img, dim=(0,1))
mag = torch.abs(img_fft)
phase = torch.angle(img_fft)

# alpha means pixel shift amount in spatial domain!
p_shift = phase_shifters(y_alpha=0,x_alpha=50, shape=phase.shape)
phase = (phase+p_shift) % (2*pi) # for wrapping

recon = torch.polar(mag,phase)
recon = torch.fft.ifft2(recon, dim=(0,1)).real
recon = torch.clamp(recon,0,255)

cv2.imshow("recon",np.array(recon, dtype=np.uint8))
cv2.waitKey(0)
Cris Luengo
이민규
  • It looks like you define the frequency axis in the Fourier domain as running from -n/2 to n/2-1, whereas it really runs from 0 to n-1, where frequency n-1 is equivalent to frequency -1. – Cris Luengo Nov 26 '21 at 18:39
  • @CrisLuengo Thanks. The image quality is much better, but still not perfect. I've changed the code as follows, but it still has some artifacts: line_x = torch.linspace(0, shape[1]-1, shape[1]) – 이민규 Nov 27 '21 at 05:22
  • @CrisLuengo Found the problem: I should have returned 2*pi*(line_x + line_y). Many thanks! :) – 이민규 Nov 27 '21 at 05:56
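
Putting the two fixes from the comments together (frequency axis in FFT order, and the factor of 2*pi), a minimal sketch might look like the following. The function names `phase_ramp` and `fourier_shift` are hypothetical. Two choices here are assumptions rather than the asker's code: `torch.fft.fftfreq` is used to build the frequency axis, which gives the same result as 0 to n-1 for whole-pixel shifts, and the sign is chosen so that a positive shift moves the image the same way as `torch.roll`.

```python
import math
import torch

def phase_ramp(y_shift, x_shift, shape):
    """Phase in radians to add to the spectrum of an H x W x C image."""
    h, w, _ = shape
    # fftfreq returns frequencies in cycles per sample, in the same order
    # as the FFT output (zero first, negative frequencies last).
    fy = torch.fft.fftfreq(h, dtype=torch.float64).reshape(h, 1, 1)
    fx = torch.fft.fftfreq(w, dtype=torch.float64).reshape(1, w, 1)
    # Note the factor 2*pi and the minus sign (sign convention assumed here).
    return -2 * math.pi * (y_shift * fy + x_shift * fx)

def fourier_shift(img, y_shift, x_shift):
    """Circularly shift an H x W x C tensor via the Fourier shift property."""
    spectrum = torch.fft.fft2(img, dim=(0, 1))
    ramp = phase_ramp(y_shift, x_shift, img.shape)
    # Multiplying by exp(i * ramp) adds the ramp to the phase directly,
    # so no separate magnitude/angle split or wrapping is needed.
    shifted = spectrum * torch.polar(torch.ones_like(ramp), ramp)
    return torch.fft.ifft2(shifted, dim=(0, 1)).real

# Synthetic stand-in for an image, so the sketch runs without a file.
img = torch.arange(6 * 8 * 3, dtype=torch.float64).reshape(6, 8, 3)
out = fourier_shift(img, y_shift=2, x_shift=3)
```

For whole-pixel shifts the result should match `torch.roll(img, shifts=(2, 3), dims=(0, 1))` up to floating-point error, which is a convenient way to check the sign and scale of the ramp.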

0 Answers