3

When using np.fft.fft2 on images, the result is the same size as the input. For real images, the real-to-complex FT has a symmetry where ft[i,j] == ft[-i,-j].conj(), as explained in this answer. For this reason, some frameworks, such as PyTorch or scikit-cuda, return an FT of shape (height // 2 + 1, width // 2 + 1). Now, given such a redundancy-free/one-sided FT, how can I use numpy index magic to map it to the full FT output by numpy?


Background: I need this for translating some numpy code.

oarfish
  • 4,116
  • 4
  • 37
  • 66
  • Are you sure about the size, `(height // 2 +1, width // 2 + 1)`? That doesn't seem like enough info to reconstruct the full FFT. – Bi Rico Jul 22 '18 at 15:51
  • Not so sure, but that's what PyTorch and scikit-cuda return. – oarfish Jul 22 '18 at 15:58
  • The example at the end of the documentation for [`fft`](https://pytorch.org/docs/stable/torch.html?highlight=fft#torch.fft) seems to return an array the same size as the input. – Bi Rico Jul 22 '18 at 16:03
  • Exactly half of a multi-dimensional DFT (FFT) is redundant, if the input is real-valued. This is true for 1D, 2D, 3D, etc. – Cris Luengo Jul 22 '18 at 19:58
  • That is to say, if PyTorch returns one quarter of the 2D transform, it returns only half the information in the image. It’s not complete. – Cris Luengo Jul 22 '18 at 20:03
  • @Bi Rico The complex-complex transform does, yes, but not the real-to-complex one (rfft). – oarfish Jul 22 '18 at 21:13
  • @CrisLuengo It seems that PyTorch at least returns half of the elements only in the last dimension, the first one is full. I'm not sure how to work with this, I am used only to full-sized fourier transforms. – oarfish Jul 23 '18 at 08:00

2 Answers

1

If you are using torch.rfft, then you can set onesided=False to get the full transform back.

That documentation doesn't say anything about how the output is formatted. The best guess is to assume it returns the first half of the elements along the last dimension, meaning that ft[i,j], with i in the half-open range [0, in.shape[0]), j in the half-open range [0, in.shape[1]), and in the input image, can be read as follows:

cutoff = in.shape[1] // 2 + 1
if j >= cutoff:
   return ft[-i, in.shape[1] - j].conj()
else:
   return ft[i, j]
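This guessed mapping can be sanity-checked against NumPy's own one-sided transform, since np.fft.rfft2 uses the same layout (a sketch; the function name read_full is mine):

```python
import numpy as np

def read_full(ft_half, shape, i, j):
    # Guessed layout: the one-sided array stores columns 0 .. w // 2;
    # columns beyond that are recovered via the conjugate symmetry
    # ft[i, j] == ft[-i, -j].conj() of a real input's DFT.
    h, w = shape
    cutoff = w // 2 + 1
    if j >= cutoff:
        return ft_half[-i, w - j].conj()
    return ft_half[i, j]

rng = np.random.default_rng(0)
img = rng.random((4, 6))
half = np.fft.rfft2(img)   # one-sided, shape (4, 4)
full = np.fft.fft2(img)    # full transform, shape (4, 6)

ok = all(np.isclose(read_full(half, img.shape, i, j), full[i, j])
         for i in range(4) for j in range(6))
print(ok)  # True
```

If a library's one-sided output matches NumPy's rfft2 layout, the same element-wise reading applies to it as well.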

If you use skcuda.fft.fft, the documentation is equally uninformative, and therefore I'd make the same guess as above.


To obtain a full DFT out of the half-plane DFT returned by these functions, use the following code:

import numpy as np

size = 6
halfsize = size // 2 + 1
# Stand-in for a half-plane DFT; a real application would get this from
# e.g. np.fft.rfft2 of a (size, size) real image.
half_ft = np.random.randn(size, halfsize) + 1j * np.random.randn(size, halfsize)

if size % 2:
    # odd size input: mirror every column except the DC column
    other_half = half_ft[:, -1:0:-1]
else:
    # even size input: also skip the Nyquist column, which has no mirror
    other_half = half_ft[:, -2:0:-1]

# Flip the rows as well, keeping row 0 (DC) in place, and conjugate
other_half = np.conj(other_half[np.hstack((0, np.arange(size - 1, 0, -1))), :])
full_ft = np.hstack((half_ft, other_half))

That is, we flip the array along both dimensions (this is the 2D case, adjust as needed for other dimensionalities), but the first row and column (DC components) are not repeated, and for even-sized input, the last row and column are not repeated either.
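As a quick check (my addition), feeding the output of np.fft.rfft2 through this reconstruction reproduces np.fft.fft2 for both odd and even sizes:

```python
import numpy as np

rng = np.random.default_rng(42)
for size in (5, 6):                        # odd and even input sizes
    image = rng.random((size, size))
    half_ft = np.fft.rfft2(image)          # shape (size, size // 2 + 1)

    if size % 2:
        other_half = half_ft[:, -1:0:-1]   # odd: mirror all but the DC column
    else:
        other_half = half_ft[:, -2:0:-1]   # even: also skip the Nyquist column

    rows = np.hstack((0, np.arange(size - 1, 0, -1)))  # flip rows, keep DC
    full_ft = np.hstack((half_ft, np.conj(other_half[rows, :])))

    print(np.allclose(full_ft, np.fft.fft2(image)))  # True for both sizes
```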

Cris Luengo
  • 55,762
  • 10
  • 62
  • 120
  • What I'm specifically interested in is doing this without loops, since that will be slow I guess. – oarfish Jul 31 '18 at 14:44
  • @oarfish: Does the `onesided=False` parameter not work then? – Cris Luengo Jul 31 '18 at 15:07
  • The question is not directly related to pytorch. That one framework may have a parameter to do it (I seem to remember it didn't do exactly what I wanted, but nevermind that), but I'm looking for the general algorithm to do it as fast as possible. – oarfish Jul 31 '18 at 15:37
  • @oarfish: OK, see updated answer. It's just a matter of replicating the relevant portion of the DFT result. – Cris Luengo Jul 31 '18 at 17:23
0

I finally succeeded in using np.meshgrid properly to fill in the missing data. We can build index grids covering the entire row range and the missing part of the column range, and fill exactly those indices with the conjugate-symmetric values.

import numpy as np
np.random.seed(0)

N     = 10
image = np.random.rand(N, N)
h, w  = image.shape

ft           = np.fft.rfft2(image)   # one-sided, shape (h, w // 2 + 1)
ft_reference = np.fft.fft2(image)    # full transform, for comparison
ft_full      = np.zeros_like(image, dtype=np.complex128)
ft_full[:ft.shape[0], :ft.shape[1]] = ft

# Index grids for all rows and the missing columns; the negative indices
# pick out the mirrored positions, per ft[i, j] == ft[-i, -j].conj()
X, Y          = np.meshgrid(range(h), range(w // 2 + 1, w), indexing='ij')
ft_full[X, Y] = ft_full[-X, -Y].conj()
print(np.allclose(ft_full, ft_reference))  # True
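The same idea can be wrapped in a reusable helper; here is a sketch (the function name one_sided_to_full is mine) that also handles odd-sized and non-square inputs:

```python
import numpy as np

def one_sided_to_full(ft_half, shape):
    """Expand a one-sided rfft2 result to the full fft2 layout.

    Vectorised: the missing columns are filled in a single fancy-indexing
    assignment, with no Python-level loop over elements.
    """
    h, w = shape
    ft_full = np.zeros(shape, dtype=np.complex128)
    ft_full[:, :ft_half.shape[1]] = ft_half
    X, Y = np.meshgrid(range(h), range(w // 2 + 1, w), indexing='ij')
    ft_full[X, Y] = ft_full[-X, -Y].conj()
    return ft_full

rng = np.random.default_rng(1)
for shape in [(4, 6), (5, 5), (6, 5)]:   # even, odd, and non-square
    image = rng.random(shape)
    full = one_sided_to_full(np.fft.rfft2(image), shape)
    print(np.allclose(full, np.fft.fft2(image)))  # True each time
```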
oarfish
  • 4,116
  • 4
  • 37
  • 66