I have C++ code which performs an fftshift on a given meshgrid:
// Side length of the square coordinate grid.
const std::size_t size = 5;
// Coordinate grids built from two copies of the range [0, size).
auto ar = xt::meshgrid(xt::arange<double>(0, size), xt::arange<double>(0, size));
// Offset subtracted from every coordinate; integer division gives 3 for size = 5.
int translate = (size + 1) / 2;
xt::xarray<double> x = std::get<0>(ar) - translate;
xt::xarray<double> y = std::get<1>(ar) - translate;
// Combine both grids into a single array (the printed result has shape 2 x size x size).
xt::xarray<double> xy_ = xt::stack(xt::xtuple(x, y));
// NOTE(review): the printed result differs from scipy's N-d fftshift; the
// xtensor-fftw helper is believed to handle 1-D input only -- confirm against
// the xtensor-fftw documentation.
auto p = xt::fftw::fftshift(xy_);
std::cout << p << std::endl;
Which gives the following shifted matrix:
{{{-3., -2., -1., 0., 1.},
{-3., -2., -1., 0., 1.},
{-3., -2., -1., 0., 1.},
{-3., -2., -1., 0., 1.},
{-3., -2., -1., 0., 1.}},
{{-1., -1., -1., -1., -1.},
{ 0., 0., 0., 0., 0.},
{ 1., 1., 1., 1., 1.},
{-3., -3., -3., -3., -3.},
{-2., -2., -2., -2., -2.}}}
whereas in Python the same fftshift() results in:
# Grid side length; 5 matches the 5-wide rows of the printed result.
size = 5
# Integer coordinate grids of shape (2, size, size), offset by (size + 1) / 2
# truncated to an int. The result must be bound to `mat` for the next line.
mat = np.mgrid[:size, :size] - int( (size + 1)/2 )
# No axes argument is given, so every axis (including the leading one) is shifted.
fftshifted_mat = scipy.fftpack.fftshift(mat)
print(fftshifted_mat)
[[[ 0 1 -3 -2 -1]
[ 0 1 -3 -2 -1]
[ 0 1 -3 -2 -1]
[ 0 1 -3 -2 -1]
[ 0 1 -3 -2 -1]]
[[ 0 0 0 0 0]
[ 1 1 1 1 1]
[-3 -3 -3 -3 -3]
[-2 -2 -2 -2 -2]
[-1 -1 -1 -1 -1]]]
How can I make the C++ fftshift output matrix exactly equal to scipy's output matrix?
I tried using xt::roll, xt::transpose + xt::swap, and manual circular-shift combinations, but none of them worked.
Update: I tried using xt::roll along every axis:
// Roll xy_ along each of its axes in turn, replacing xy_ with the rolled result.
for (std::size_t axis = 0; axis < xy_.shape().size(); ++axis){
std::size_t dim_size = xy_.shape()[axis];
// NOTE(review): with integer division, (dim_size - 1) / 2 equals dim_size / 2
// only when dim_size is odd; for an even extent it is one smaller (an extent
// of 2 gives a shift of 0, so that axis is left untouched). numpy/scipy
// fftshift is documented to shift each axis by n // 2 -- confirm, but this
// would explain agreement for odd sizes only.
std::size_t shift = (dim_size - 1) / 2;
xy_ = xt::roll(xy_, shift, axis);
}
However, for some reason this only produces the same matrix as scipy.fft.fftshift when size = 5 or size = 125. I am not sure why.
Update 2: As per @chris's answer, I added a manual shift using roll. It seems to replicate scipy's fftshift, but it seems to be quite slow.
/// @brief Circularly shifts every axis of @p array by half its extent, in place.
///
/// Each axis of extent n is rolled forward by n / 2 (integer division).
///
/// @tparam T     Element type of the array.
/// @param array  Array to shift; it is overwritten with the shifted result.
template <typename T>
void fftshift_roll(xt::xarray<T>& array)
{
    const std::size_t ndims = array.dimension();
    for (std::size_t axis = 0; axis < ndims; ++axis) {
        // Rolling never changes the shape, so the extent can be read per axis
        // here instead of being cached up front.
        const std::ptrdiff_t shift = static_cast<std::ptrdiff_t>(array.shape(axis)) / 2;
        if (shift == 0) {
            // Extent 0 or 1: rolling by zero leaves the data unchanged, so
            // skip the copy altogether.
            continue;
        }
        // Assign the rolled result directly. The former
        // xt::view(rolled, xt::all(), xt::all()) step hard-coded two slices
        // (assuming at least two dimensions) and added an extra pass over the
        // data for every axis.
        array = xt::roll(array, shift, static_cast<std::ptrdiff_t>(axis));
    }
}