I want to compute the convolution operation using FFT.
My approach is to apply the FFT to both input tensors, multiply the two results element-wise, and then apply the inverse FFT (IFFT) to the product to get the final output.
The code below is the pseudocode I wrote.
input tensors:
import torch
a = torch.tensor([0, 1, 2, 3])
b = torch.tensor([1, 2])
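pseudocode:
fft_a = torch.fft.fft(a)          # FFT of a (length 4)
fft_b = torch.fft.fft(b, n=4)     # zero-pad b to length 4, then FFT
out = fft_a * fft_b               # element-wise product in the frequency domain
out = torch.fft.ifft(out)         # inverse FFT back to the signal domain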
expected output:
print(out.real)
>>> tensor([2., 5., 8.])
actual output:
print(out.real)
>>> tensor([6., 1., 4., 7.])
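The output above is a circular (wrap-around) convolution: when both FFTs have length 4 = len(a), the last sample of the linear convolution wraps onto the first. A minimal sketch illustrating this, using float copies of the same inputs, compares the FFT result with a directly computed circular convolution and with an FFT zero-padded to `len(a) + len(b) - 1`. The names `n_circ`, `b_pad`, `direct`, `n_lin` and `linear` are illustrative only.

```python
import torch

a = torch.tensor([0., 1., 2., 3.])
b = torch.tensor([1., 2.])

# Same computation as in the question: FFT length 4 = len(a),
# which gives a circular (wrap-around) convolution.
n_circ = a.numel()
circ = torch.fft.ifft(torch.fft.fft(a, n=n_circ) * torch.fft.fft(b, n=n_circ)).real

# Direct circular convolution for comparison:
# c[k] = sum_j a[j] * b_pad[(k - j) mod n_circ]
b_pad = torch.cat([b, torch.zeros(n_circ - b.numel())])
direct = torch.stack([
    sum(a[j] * b_pad[(k - j) % n_circ] for j in range(n_circ))
    for k in range(n_circ)
])

# Zero-padding both inputs to len(a) + len(b) - 1 leaves room for every
# output sample, so nothing wraps around: this is the full linear convolution.
n_lin = a.numel() + b.numel() - 1
linear = torch.fft.ifft(torch.fft.fft(a, n=n_lin) * torch.fft.fft(b, n=n_lin)).real

print(circ)    # approximately [6, 1, 4, 7]
print(direct)  # [6, 1, 4, 7]
print(linear)  # approximately [0, 1, 4, 7, 6]
```

The zero-padded result `[0, 1, 4, 7, 6]` is the true linear convolution of `a` and `b`; note that it still differs from the expected `[2, 5, 8]`.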
How can I perform this convolution with the FFT so that it produces the expected output?
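The expected output `[2., 5., 8.]` matches what `torch.nn.functional.conv1d` returns for these inputs with no padding. Like most deep-learning "convolution" layers, `conv1d` actually computes a cross-correlation (the kernel is not flipped) and keeps only the positions where the kernel fully overlaps the input. A minimal sketch, assuming that is the intended definition: flip `b`, zero-pad both FFTs to `len(a) + len(b) - 1` so the product gives a linear rather than circular convolution, then keep the "valid" slice. The helper `fft_conv1d_valid` is a hypothetical name, and the sketch only covers single-channel 1-D inputs with stride 1.

```python
import torch
import torch.nn.functional as F

def fft_conv1d_valid(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Cross-correlate 1-D x with kernel w via FFT, keeping only the 'valid' part.

    Intended to match F.conv1d with stride 1, no padding and one channel.
    """
    n = x.numel() + w.numel() - 1            # length of the full linear convolution
    w_flipped = torch.flip(w, dims=[0])      # flipping turns convolution into cross-correlation
    full = torch.fft.ifft(torch.fft.fft(x, n=n) * torch.fft.fft(w_flipped, n=n)).real
    k = w.numel()
    return full[k - 1 : x.numel()]           # positions where w fully overlaps x

a = torch.tensor([0., 1., 2., 3.])
b = torch.tensor([1., 2.])

out = fft_conv1d_valid(a, b)
# Reference: conv1d expects (batch, channels, length) tensors.
ref = F.conv1d(a.view(1, 1, -1), b.view(1, 1, -1)).view(-1)

print(out)  # approximately [2, 5, 8]
print(ref)  # [2, 5, 8]
```

If a true (flipped-kernel) convolution is wanted instead, dropping the `torch.flip` call and the final slice returns the full linear convolution `[0, 1, 4, 7, 6]`.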