I'm trying to implement the phase shift property of the Fourier transform in PyTorch. By the shift property I mean this:
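In its standard discrete form (assumed here to be the intended statement), take an $M \times N$ image $f[m, n]$ with 2-D DFT $F[k, l] = \sum_{m,n} f[m, n]\, e^{-i 2\pi (k m / M + l n / N)}$. A circular shift by $(m_0, n_0)$ pixels, moving the content down and to the right, changes only the phase:

$$
f\big[(m - m_0) \bmod M,\ (n - n_0) \bmod N\big] \;\longleftrightarrow\; F[k, l]\, e^{-i 2\pi \left( \frac{k m_0}{M} + \frac{l n_0}{N} \right)}
$$

So shifting by $n_0$ columns adds $-2\pi l n_0 / N$ to the phase of coefficient $l$. For integer shifts, $k$ and $l$ can be the plain indices $0, \dots, M-1$ and $0, \dots, N-1$. For fractional shifts they are usually taken as signed frequencies so that the phase ramp stays (almost) conjugate-symmetric.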
I think I've got most things right, but I end up with a noisy image and I'm having a hard time tracking down the cause. Could it be a numerical issue? Or something to do with odd or even pixel counts? (My images are 1020 x 678 x 3.)
These are the shifted image and the original image.
This is my implementation code:
```python
import torch

def phase_shifters(y_alpha=0, x_alpha=0, shape=None):
    # shape is (H, W, C): the layout of the image and of its FFT over dim=(0, 1)
    # x shift: W evenly spaced values in [-W/2, W/2], divided by W, broadcast to H x W x C
    line_x = torch.linspace(-shape[1]/2, shape[1]/2, shape[1])
    line_x = line_x.expand(shape[0], shape[2], shape[1]).transpose(1, 2)
    line_x = line_x/shape[1]
    line_x = x_alpha * line_x
    # y shift: H evenly spaced values in [-H/2, H/2], divided by H, broadcast to H x W x C
    line_y = torch.linspace(-shape[0]/2, shape[0]/2, shape[0])
    line_y = line_y.expand(shape[2], shape[1], shape[0]).transpose(0, 2)
    line_y = line_y/shape[0]
    line_y = y_alpha * line_y
    return x_alpha*line_x + y_alpha*line_y
```
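For comparison, here is a minimal sketch of the same phase ramp written the way the shift theorem prescribes. It uses signed frequencies from `torch.fft.fftfreq`, which follow the same unshifted order that `torch.fft.fft2` returns (no `fftshift`). It also includes the factor of $-2\pi$ and applies each shift only once. The name `phase_ramp` and its parameters are hypothetical, not part of any library.

```python
import math

import torch


def phase_ramp(shape, y_shift=0.0, x_shift=0.0):
    """Phase (radians) to add to the phase of an H x W x C spectrum computed with
    torch.fft.fft2(..., dim=(0, 1)) so the image moves y_shift rows down and
    x_shift columns right (circularly)."""
    H, W, C = shape
    # Signed frequencies in cycles/pixel, in fft2's unshifted order:
    # 0, 1/N, 2/N, ..., then the negative frequencies.
    fy = torch.fft.fftfreq(H)  # shape (H,)
    fx = torch.fft.fftfreq(W)  # shape (W,)
    # Linear phase -2*pi*(fy*y0 + fx*x0) on an H x W grid ...
    ramp = -2 * math.pi * (fy[:, None] * y_shift + fx[None, :] * x_shift)
    # ... repeated for every channel to give H x W x C.
    return ramp[:, :, None].expand(H, W, C)
```

Used in place of `phase_shifters`, e.g. `phase + phase_ramp(phase.shape, x_shift=50)`, the `% (2*pi)` wrap should become optional, since `torch.polar` is periodic in the angle anyway.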
```python
import math

import cv2
import numpy as np
import torch

img = cv2.imread("test.png")  # H x W x 3, uint8 (BGR)
img = torch.from_numpy(img).float()  # torch.fft needs a tensor, not a NumPy array
img_fft = torch.fft.fft2(img, dim=(0, 1))
mag = torch.abs(img_fft)
phase = torch.angle(img_fft)
# alpha means the pixel shift amount in the spatial domain
p_shift = phase_shifters(y_alpha=0, x_alpha=50, shape=phase.shape)
phase = (phase + p_shift) % (2 * math.pi)  # wrap the phase into [0, 2*pi)
recon = torch.polar(mag, phase)
recon = torch.fft.ifft2(recon, dim=(0, 1)).real
recon = torch.clamp(recon, 0, 255)
cv2.imshow("recon", recon.numpy().astype(np.uint8))
cv2.waitKey(0)
```
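A self-contained sketch of the whole pipeline under the same assumptions follows. It multiplies the complex spectrum by the phase factor directly instead of splitting it into magnitude and phase, which is mathematically equivalent. `fft_shift_image` is a hypothetical helper. The shift is circular, so content that leaves one edge re-enters at the other. For integer shifts it should reproduce `torch.roll` for both odd and even image sizes.

```python
import math

import torch


def fft_shift_image(img, y_shift=0.0, x_shift=0.0):
    """Circularly shift an H x W x C image by (y_shift, x_shift) pixels
    (down and to the right for positive values) via the Fourier shift theorem."""
    img = torch.as_tensor(img, dtype=torch.float64)
    H, W, _ = img.shape
    spectrum = torch.fft.fft2(img, dim=(0, 1))  # per-channel 2-D FFT, unshifted order

    # Signed frequencies (cycles/pixel) in the same order as the FFT output,
    # shaped so they broadcast over H x W x C.
    fy = torch.fft.fftfreq(H, dtype=torch.float64)[:, None, None]
    fx = torch.fft.fftfreq(W, dtype=torch.float64)[None, :, None]

    # Unit-magnitude factor exp(i * angle) with angle = -2*pi*(fy*y0 + fx*x0).
    angle = -2 * math.pi * (fy * y_shift + fx * x_shift)
    factor = torch.polar(torch.ones_like(angle), angle)

    shifted = torch.fft.ifft2(spectrum * factor, dim=(0, 1))
    # For integer shifts the imaginary part is only round-off; keep the real part.
    return shifted.real
```

With the image from the question, `fft_shift_image(cv2.imread("test.png"), x_shift=50)` would be expected to return the original moved 50 pixels to the right. Displaying it with `cv2.imshow` still needs clamping to [0, 255] and a cast to `uint8`.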