I'm trying to write a convolution code entirely in the spectral domain. I take a spike series in time of n samples (the example below has only one spike, for simplicity) and compute its discrete Fourier transform with numpy.fft.fft. I create a Ricker wavelet of m samples (m << n) and compute its Fourier transform with numpy.fft.fft as well, but I specify that the output be n samples long. Both the spike series and the wavelet have the same sampling interval. The resulting convolved series is shifted: the peak of the wavelet is displaced along the time axis with respect to the spike. The size of this shift seems to depend on the wavelet length, m.
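The shift can be measured directly. Below is a minimal sketch, assuming (as in the script further down) an odd-length Ricker wavelet centered in its m-sample array and a single spike in the middle of an n-sample series. The helper names ricker and spectral_peak_shift are hypothetical. When n is larger than the input, numpy.fft.fft pads the input with zeros appended after its last sample, so the wavelet's peak ends up at index m//2 of the padded array. Multiplying the two spectra then performs a circular convolution, which places index 0 of the padded wavelet on the spike, not the wavelet's peak.

```python
import numpy as np

def ricker(m, freq=25.0, dt=0.002):
    """Ricker wavelet of m samples with its peak (t = 0) at index m // 2."""
    t = (np.arange(m) - m // 2) * dt
    a = (np.pi * freq * t) ** 2
    return (1 - 2 * a) * np.exp(-a)

def spectral_peak_shift(m, n=501):
    """Samples between the spike and the peak of the spectrally convolved trace."""
    spike_index = n // 2
    series = np.zeros(n)
    series[spike_index] = 1.0
    # fft(..., n=n) zero-pads at the END: [w_0, ..., w_{m-1}, 0, ..., 0]
    W = np.fft.fft(ricker(m), n=n)
    S = np.fft.fft(series)
    # product of DFTs = circular convolution of the two sequences
    trace = np.fft.ifft(W * S).real
    return int(np.argmax(trace)) - spike_index

for m in (21, 51, 101):
    print(m, spectral_peak_shift(m), m // 2)
```

Each line prints m, the measured shift, and m//2, and the last two agree. With rickersize=51 and timeiter=0.002 as in the script, that is 25 samples, or 0.05 s.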
I thought it had something to do with the parameters of numpy.fft.fft(a, n=None, axis=-1, norm=None), particularly the axis parameter, but I do not understand the documentation for that parameter at all.
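For reference, axis only selects which dimension of a multi-dimensional array is transformed. A 1-D input such as a list of samples has only one axis, so axis=0 and the default axis=-1 give identical results. The small sketch below illustrates this. The arrays are arbitrary test data, not taken from the script.

```python
import numpy as np

x = np.arange(8.0)                     # 1-D signal: only one axis exists
assert np.allclose(np.fft.fft(x, axis=0), np.fft.fft(x, axis=-1))

# 2-D array: two traces of 8 samples each, one trace per row
traces = np.vstack([x, x[::-1]])
along_time = np.fft.fft(traces, axis=1)     # one 8-point FFT per row (per trace)
across_traces = np.fft.fft(traces, axis=0)  # one 2-point FFT per column

print(along_time.shape, across_traces.shape)
print(np.allclose(along_time[0], np.fft.fft(x)))
```

Both results have shape (2, 8), but only the axis=1 transform treats each trace as a time series. For a single 1-D trace, the axis argument cannot cause or remove a time shift.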
Can anyone help me understand why I'm getting this shift? To be explicit: the peak of the wavelet in the convolved series must be at the same time sample as the spike in the input spike series.
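The requirement can be stated precisely by comparing against a time-domain reference. Below is a minimal sketch, assuming an odd-length wavelet whose peak sits at index m//2. np.convolve with mode='same' keeps the peak on the spike. Rotating the zero-padded wavelet with np.roll, so that its center sample moves to index 0 before taking the FFT, reproduces the same trace in the spectral domain. This is one possible realignment, not necessarily the only one, and the variable names are illustrative.

```python
import numpy as np

m, n, dt, freq = 51, 501, 0.002, 25.0
t = (np.arange(m) - m // 2) * dt
a = (np.pi * freq * t) ** 2
wavelet = (1 - 2 * a) * np.exp(-a)     # Ricker, peak at index m // 2

series = np.zeros(n)
series[n // 2] = 1.0                   # single spike in the middle

# Time-domain reference: mode='same' keeps the output aligned with the input samples
reference = np.convolve(series, wavelet, mode='same')

# Spectral domain: zero-pad to n samples, then rotate so the wavelet's center
# sample (index m // 2) moves to index 0; its left half wraps to the end of the array
padded = np.zeros(n)
padded[:m] = wavelet
centered = np.roll(padded, -(m // 2))
trace = np.fft.ifft(np.fft.fft(centered) * np.fft.fft(series)).real

print(np.argmax(reference) == n // 2, np.argmax(trace) == n // 2)
print(np.allclose(trace, reference))
```

Because the product of DFTs is a circular convolution, it matches the linear result only when the wavelet does not run past either end of the series. With a spike near an edge, part of the wavelet would reappear at the opposite end unless the series is padded further.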
My code follows:
################################################################################
#
# import libraries
#
import math
import numpy as np
import matplotlib.pyplot as plt
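# Define parameters and lists
#
freq = 25            # Ricker peak frequency (Hz)
rickersize = 51      # wavelet length in samples (m)
timeiter = 0.002     # sampling interval (s)
serieslength = 501   # spike-series length in samples (n)
Time = []; Ricker = []; TIMElong = []; Reflectivity = []
Series = []; IMPEDANCE = []; CONVOLUTION = []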
#
# Create ricker wavelet and its time sequence
#
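for i in range(0,rickersize):
    # time runs from -(rickersize//2)*timeiter to +(rickersize//2)*timeiter,
    # so the wavelet peak (time 0) falls at index rickersize//2
    time=(float(i-rickersize//2)*timeiter)
    ricker=(1-2*math.pi*math.pi*freq*freq*time*time)*math.exp(-1*math.pi*math.pi*freq*freq*time*time)
    Time.append(time)
    Ricker.append(ricker)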
#
# Do various FFT operations on the Ricker wavelet:
# Normal FFT, FFT of longer Ricker, Amplitude of the FFTs, their inverse FFTs and their frequency sequence
#
FFT=np.fft.fft(Ricker); FFTlong=np.fft.fft(Ricker,n=serieslength,axis=0,norm=None);
AMP=abs(FFT); AMPlong=abs(FFTlong);
RICKER=np.fft.ifft(FFT); RICKERlong=np.fft.ifft(FFTlong);
FREQ=np.fft.fftfreq(len(Ricker),d=timeiter); FREQlong=np.fft.fftfreq(len(RICKERlong),d=timeiter)
PHASE=np.angle(FFT); PHASElong=np.angle(FFTlong);
#
# Create a single spike in the otherwise empty (0) series of length 'serieslength' (=len(RICKERlong))
# this spike mimics a very simple seismic reflectivity series in time
#
for i in range(0,serieslength):
    time=(float(i)*timeiter)
    TIMElong.append(time)
    if i==int(serieslength/2):
        Series.append(1)
    else:
        Series.append(0)
#
# Do various FFT operations on the spike series
# Normal FFT, Amplitude of the FFT, its inverse FFT and frequency sequence
#
FFTSeries=np.fft.fft(Series)
AMPSeries=abs(FFTSeries)
SERIES=np.fft.ifft(FFTSeries)
FREQSeries=np.fft.fftfreq(len(Series),d=timeiter)
#
# Do convolution of the spike series with the (long) Ricker wavelet in the frequency domain and see result via inverse FFT
#
FFTConvolution=FFTlong*FFTSeries     # element-wise product of the two spectra
CON=np.fft.ifft(FFTConvolution)
CONVOLUTION=CON.real
#
# plotting routines
#
fig,axs = plt.subplots(nrows=1,ncols=3, figsize=(14,8))
axs[0].barh(TIMElong,Series,height=0.005, color='black')
axs[1].plot(Ricker,Time,color='black', linestyle='solid',linewidth=1)
axs[2].plot(CONVOLUTION,TIMElong,color='black', linestyle='solid',linewidth=1)
#
axs[0].set_aspect(aspect=8); axs[0].set_title('Reflectivity',fontsize=12); axs[0].yaxis.grid(); axs[0].xaxis.grid();
axs[0].set_xlim(-2,2); axs[0].set_ylim(min(TIMElong),max(TIMElong)); axs[0].invert_yaxis(); axs[0].tick_params(axis='both',which='major',labelsize=12);
#
axs[1].set_aspect(aspect=6.2); axs[1].set_title('Ricker',fontsize=12); axs[1].yaxis.grid(); axs[1].xaxis.grid();
axs[1].set_xlim(-1.0,1.02); axs[1].set_ylim(min(Time),max(Time)); axs[1].invert_yaxis(); axs[1].tick_params(axis='both',which='major',labelsize=12);
#
axs[2].set_aspect(aspect=8); axs[2].set_title('Convolution',fontsize=12); axs[2].yaxis.grid(); axs[2].xaxis.grid();
axs[2].set_xlim(-2,2); axs[2].set_ylim(min(TIMElong),max(TIMElong)); axs[2].invert_yaxis(); axs[2].tick_params(axis='both',which='major',labelsize=12);
#
fig.tight_layout()
plt.show()
####