I generated a non-stationary signal and I want to apply several time-frequency transformation methods such as STFT, Continous Wavelet Transform(CWT), Wigner-Ville Distribution (WVD) to this signal and compare the performance of the results quantitatively. First I tried to generate a matrix as an ideal time-frequency representation to compare the result of several methods with this ideal matrix.
I found several metrics like MSE, cross correlation, renyi entropy and structural similarity. the first problem that occured was that for example for calculation of MSE and cross correlation, both matrices (the ideal one and the result of methods) should have the same size but they don't so i should have found a way to make them the same size. resizing the matrices is not a good idea because the info in the matrix should not change. I also tried down sampling and up sampling but the value of MSE does not make sense. for example for 2 matrices, which are visually similar and both normalized between 0 and 1, the MSE is around 0.7 which is not good. Generally i want to know which metrics are suitable for comparing those time-frequency transformation methods and in which way i can make the matrices the same size so i can get reasonable answers.
here is my code:
import numpy as np
from scipy import signal
import ssqueezepy as ssqpy
import scipy as sp
import math
from tftb.processing.reassigned import pseudo_wigner_ville
import matplotlib.pyplot as plt
from matrix_ideal import matrix_ideal
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import mean_squared_error
def normalize_matrix(matrix):
normalized = (matrix - np.min(matrix)) / ((np.max(matrix)) - np.min(matrix))
return normalized
def renyi_entropy(TF, gamma):
TF_norm = TF / np.sum(TF)
renyi_entropy = (1 / (1 - gamma)) * np.log2(np.sum(np.power(TF_norm, gamma)))
return renyi_entropy
def SET(x,hlength=None):
xrow = len(x)
N = xrow
if hlength is None:
hlength = round(xrow/8)
t = list(range(1,N+1))
ft = list(range(1,round(N/2)+1))
tcol = len(t)
hlength = hlength+1 - (hlength%2)
ht = np.linspace(-0.5,0.5,hlength,endpoint=True)
ht = np.transpose(ht)
h = np.exp((-math.pi) /0.32**2*np.power(ht,2))
dh = np.multiply(-2*math.pi/0.32**2*ht,h)
hrow = len(h)
Lh = (hrow - 1)/2
tfr1 = np.zeros((N, tcol))
tfr2 = np.zeros((N, tcol))
for icol in range(1,tcol+1):
ti = t[icol-1]
tau = np.array(range(-min([round(N/2)-1,int(Lh), int(ti)-1]),min([round(N/2)-1,int(Lh),xrow-int(ti)])+1))
indices = ((N+tau)%N)+1
p = ti + tau
vector = np.vectorize(np.int_)
q = p-1
rSig = x[vector(q)]
v = h[vector(Lh + tau)]
tfr1[indices-1, icol-1] = rSig * np.conj(h[vector(Lh + tau)])
w = dh[vector(Lh + tau)]
tfr2[indices-1, icol-1] = rSig * np.conj(dh[vector(Lh + tau)])
tfr1 = sp.fft.fft(tfr1,axis=0)
tfr2 = sp.fft.fft(tfr2,axis=0)
tfr1 = tfr1[0:round(N/2),:]
tfr2 = tfr2[0:round(N/2),:]
complex_num = np.zeros((round(N/2),N))
va = N/hlength
IF = np.zeros((round(N / 2),tcol))
tfr = np.zeros((round(N / 2), tcol))
E = np.mean(abs(x))
for i in range(0,round(N/2)):
for j in range(0,N):
if abs(tfr1[i,j]) > 0.8*E:
complex_num[i,j] = va * 1j * tfr2[i,j]/2/np.pi/tfr1[i,j]
if abs(-complex_num[i,j].real) < 0.5 :
IF[i,j] = 1
tfr = tfr1 / (sum(h) / 2)
Te = tfr * IF
return IF, Te, tfr
f1_L = 1500 # 50
f2_L = 3000 #2000 # 600
f3_L = 3500
f4_L = 4500
# Define the amplitude of the signal
A1 = 10
A2 = 5
A3 = 3
A4 = 8
factor_fs_down = 4
# Define the time points at which the signal will be sampled
fs_L = 42000//factor_fs_down # sampling frequency
dt = 1 / fs_L
T = 1 # duration of the signal in seconds
t_1_L = np.linspace(0.1, 0.3, num=8400//factor_fs_down, endpoint=True)
t_2_L = np.linspace(0.2, 0.4, num=8400//factor_fs_down, endpoint=True)
t_3_L = np.linspace(0.5, 0.7, num=8400//factor_fs_down, endpoint=True)
t_4_L = np.linspace(0.4, 0.9, num=21000//factor_fs_down, endpoint=True)
t_L = np.linspace(0, T, int(fs_L * T), endpoint=True)
ts_L = np.arange(fs_L) * dt
part_1_L = A1*np.sin(2 * np.pi * f1_L * t_1_L)
part_2_L = A2*np.sin(2 * np.pi * f2_L * t_2_L)
part_3_L = A3*np.sin(2 * np.pi * f3_L* t_3_L)
part_4_L = A4*np.sin(2 * np.pi * f4_L * t_4_L)
array_zero_new = np.zeros(25200//factor_fs_down)
array_zero_L = np.zeros(4200//factor_fs_down)
array_zero_between_L = np.zeros(8400//factor_fs_down)
# Generate the signal
signal_L_part_1 = np.concatenate((array_zero_L, part_1_L, np.zeros(29400//factor_fs_down)), axis=0)
signal_L_part_2 = np.concatenate((array_zero_between_L, part_2_L, array_zero_new), axis=0)
signal_L_part_3 = np.concatenate((np.zeros(21000//factor_fs_down),part_3_L, np.zeros(12600//factor_fs_down)), axis=0)
signal_L_part_4 = np.concatenate((np.zeros(16800//factor_fs_down),part_4_L, array_zero_L), axis=0)
x = signal_L_part_1 + signal_L_part_2 + signal_L_part_3 + signal_L_part_4
matrix_ideal_reference = np.zeros((int(fs_L/2),len(t_L)))
matrix_ideal_reference[f1_L,int(t_1_L[0]*fs_L):int(t_1_L[-1]*fs_L)] = A1
matrix_ideal_reference[f2_L,int(t_2_L[0]*fs_L):int(t_2_L[-1]*fs_L)] = A2
matrix_ideal_reference[f3_L,int(t_3_L[0]*fs_L):int(t_3_L[-1]*fs_L)] = A3
matrix_ideal_reference[f4_L,int(t_4_L[0]*fs_L):int(t_4_L[-1]*fs_L)] = A4
# Calculate the PSD of the signal using the Welch method
gamma = 3
#compute renyi for x
f_x, Pxx_x = signal.welch(np.abs(matrix_ideal_reference)**2, fs=fs_L, nperseg=512)
renyi_entropy_x = renyi_entropy(Pxx_x, gamma)
print('renyi entropy of x without noise:',renyi_entropy_x)
#compute WVD, PWVD
tfr_wvd, tfr_pwvd, _ = pseudo_wigner_ville(x, n_fbins=1024) #, timestamps=1/fs_L)
tfr_wvd_usefull = np.abs(tfr_wvd)[0:(tfr_wvd.shape[0]//2),:]
tfr_pwvd_usefull = np.abs(tfr_pwvd)[0:(tfr_pwvd.shape[0]//2),:]
f_wvd, Pxx_wvd = signal.welch(np.abs(tfr_wvd_usefull)**2, fs=fs_L, nperseg=512)
f_pwvd, Pxx_pwvd = signal.welch(np.abs(tfr_pwvd_usefull)**2, fs=fs_L, nperseg=512)
renyi_entropy_wvd = renyi_entropy(Pxx_wvd, gamma)
renyi_entropy_pwvd = renyi_entropy(Pxx_pwvd, gamma)
print('renyi entropy of wvd without noise:',renyi_entropy_wvd)
print('renyi entropy of pwvd without noise:',renyi_entropy_pwvd)
#compute Cwt, SSQ-CWT
Tx_ssq, Wx_cwt, ssq_freqs,scales,*_ = ssqpy.ssq_cwt(x, wavelet='morlet', scales='log-piecewise')
f_cwt, Pxx_cwt = signal.welch(np.abs(Wx_cwt)**2, fs=fs_L, nperseg=512)
f_ssq_cwt, Pxx_ssq_cwt = signal.welch(np.abs(Tx_ssq)**2, fs=fs_L, nperseg=512)
renyi_entropy_cwt = renyi_entropy(Pxx_cwt, gamma)
renyi_entropy_ssq_cwt = renyi_entropy(Pxx_ssq_cwt, gamma)
print('renyi entropy of cwt without noise:',renyi_entropy_cwt)
print('renyi entropy of ssq-cwt without noise:',renyi_entropy_ssq_cwt)
# compute SET
IF, Te, tfr = SET(x,128)
f_SET, Pxx_SET = signal.welch(np.abs(Te)**2, fs=fs_L, nperseg=512)
renyi_entropy_SET = renyi_entropy(Pxx_SET, gamma)
print('renyi entropy of SET without noise:',renyi_entropy_SET)
# compute STFT
f, t, Zxx = signal.stft(x, fs=fs_L, nperseg=512)
f_stft, Pxx_stft = signal.welch(np.abs(Zxx)**2, fs=fs_L, nperseg=512)
renyi_entropy_stft = renyi_entropy(Pxx_stft, gamma)
print('renyi entropy of stft without noise:',renyi_entropy_stft)
#compute SSQ-STFT
Tx_ssq_stft, Sx_stft, ssq_freqs_Tx, Sfs,*_ = ssqpy.ssq_stft(x, window='hann', n_fft=512, win_len=512, hop_len=256, fs=fs_L, flipud=True)
f_ssq_stft, Pxx_ssq_stft = signal.welch(np.abs(Tx_ssq_stft)**2, fs=fs_L, nperseg=512)
renyi_entropy_ssq_stft = renyi_entropy(Pxx_ssq_stft, gamma)
print('renyi entropy of ssq-stft without noise:',renyi_entropy_ssq_stft)
# define snr values
snr_values = [0,5, 10,15, 20, 25, 30]
hlength = 128
# Define STFT parameters
nperseg = 512 # Length of each STFT segment
noverlap = nperseg // 2 # Number of points of overlap between adjacent STFT segments
n_fbins = 1024
# Initialize empty lists to store the Renyi entropy and SNR values
ref_renyi_values = []
stft_renyi_values = []
ssq_stft_renyi_values = []
cwt_renyi_values = []
ssq_cwt_renyi_values = []
wvd_renyi_values = []
pwvd_renyi_values = []
SET_renyi_values = []
stft_ssim_values = []
ssq_stft_ssim_values = []
cwt_ssim_values = []
ssq_cwt_ssim_values = []
wvd_ssim_values = []
pwvd_ssim_values = []
SET_ssim_values = []
stft_MSE_values = []
ssq_stft_MSE_values = []
cwt_MSE_values = []
ssq_cwt_MSE_values = []
wvd_MSE_values = []
pwvd_MSE_values = []
SET_MSE_values = []
snr_list = []
check_p_i_values = []
# Loop over each SNR value and add noise to the signal, then apply methods and calculate Renyi entropy and SSIM
for snr in snr_values:
# Add noise to the signal
noise_power = np.var(x) / (10**(snr/10))
noise = np.sqrt(noise_power) * np.random.randn(len(x))
noisy_signal = x + noise
#compute WVD, PWVD
tfr_wvd, tfr_pwvd, _ = pseudo_wigner_ville(noisy_signal, n_fbins=1024) #, timestamps=1/fs_L)
tfr_wvd_usefull = np.abs(tfr_wvd)[0:(tfr_wvd.shape[0]//2),:]
tfr_pwvd_usefull = np.abs(tfr_pwvd)[0:(tfr_pwvd.shape[0]//2),:]
f_wvd, Pxx_wvd = signal.welch(np.abs(tfr_wvd_usefull)**2, fs=fs_L, nperseg=512)
f_pwvd, Pxx_pwvd = signal.welch(np.abs(tfr_pwvd_usefull)**2, fs=fs_L, nperseg=512)
renyi_entropy_wvd = renyi_entropy(Pxx_wvd, gamma)
renyi_entropy_pwvd = renyi_entropy(Pxx_pwvd, gamma)
print('renyi entropy of wvd {0} : {1}'.format(renyi_entropy_wvd,snr))
print('renyi entropy of pwvd {0} : {1}'.format(renyi_entropy_pwvd,snr))
tfr_wvd_usefull_normalized = normalize_matrix(tfr_wvd_usefull)
tfr_pwvd_usefull_normalized = normalize_matrix(tfr_pwvd_usefull)
matrix_ideal_wvd_normalized = normalize_matrix(matrix_ideal(tfr_wvd_usefull, f1_L, f2_L, f3_L, f4_L, fs_L, A1, A2, A3, A4, t_1_L, t_2_L, t_3_L, t_4_L))
matrix_ideal_pwvd_normalized = normalize_matrix(matrix_ideal(tfr_pwvd_usefull, f1_L, f2_L, f3_L, f4_L, fs_L, A1, A2, A3, A4, t_1_L, t_2_L, t_3_L, t_4_L))
ssim_wvd = ssim(tfr_wvd_usefull_normalized, np.flipud(matrix_ideal_wvd_normalized))
ssim_pwvd = ssim(tfr_pwvd_usefull_normalized, np.flipud(matrix_ideal_pwvd_normalized))
MSE_wvd = mean_squared_error(np.flipud(matrix_ideal_wvd_normalized), tfr_wvd_usefull_normalized)
MSE_pwvd = mean_squared_error(np.flipud(matrix_ideal_pwvd_normalized), tfr_pwvd_usefull_normalized)
#compute Cwt, SSQ-CWT
Tx_ssq, Wx_cwt, ssq_freqs,scales,*_ = ssqpy.ssq_cwt(noisy_signal, wavelet='morlet', scales='log-piecewise')
f_cwt, Pxx_cwt = signal.welch(np.abs(Wx_cwt)**2, fs=fs_L, nperseg=512)
f_ssq_cwt, Pxx_ssq_cwt = signal.welch(np.abs(Tx_ssq)**2, fs=fs_L, nperseg=512)
renyi_entropy_cwt = renyi_entropy(Pxx_cwt, gamma)
renyi_entropy_ssq_cwt = renyi_entropy(Pxx_ssq_cwt, gamma)
print('renyi entropy of cwt {0} : {1}'.format(renyi_entropy_cwt,snr))
print('renyi entropy of ssq-cwt {0} : {1}'.format(renyi_entropy_ssq_cwt,snr))
Wx_cwt_normalized = normalize_matrix(abs(Wx_cwt))
Tx_ssq_normalized = normalize_matrix(abs(Tx_ssq))
matrix_ideal_cwt_normalized = normalize_matrix(matrix_ideal(Wx_cwt, f1_L, f2_L, f3_L, f4_L, fs_L, A1, A2, A3, A4, t_1_L, t_2_L, t_3_L, t_4_L))
matrix_ideal_ssq_cwt_normalized = normalize_matrix(matrix_ideal(Tx_ssq, f1_L, f2_L, f3_L, f4_L, fs_L, A1, A2, A3, A4, t_1_L, t_2_L, t_3_L, t_4_L))
ssim_cwt = ssim(Wx_cwt_normalized, np.flipud(matrix_ideal_cwt_normalized))
ssim_ssq_cwt = ssim(Tx_ssq_normalized, np.flipud(matrix_ideal_ssq_cwt_normalized))
MSE_cwt = mean_squared_error(np.flipud(matrix_ideal_cwt_normalized), Wx_cwt_normalized)
MSE_ssq_cwt = mean_squared_error(np.flipud(matrix_ideal_ssq_cwt_normalized), Tx_ssq_normalized)
# compute SET
IF, Te, tfr = SET(noisy_signal,128)
f_SET, Pxx_SET = signal.welch(np.abs(Te)**2, fs=fs_L, nperseg=512)
renyi_entropy_SET = renyi_entropy(Pxx_SET, gamma)
print('renyi entropy of SET {0} : {1}'.format(renyi_entropy_SET,snr))
Te_normalized = normalize_matrix(abs(Te))
matrix_ideal_SET_normalized = normalize_matrix(matrix_ideal(Te, f1_L, f2_L, f3_L, f4_L, fs_L, A1, A2, A3, A4, t_1_L, t_2_L, t_3_L, t_4_L))
ssim_SET = ssim(Te_normalized, matrix_ideal_SET_normalized)
MSE_SET = mean_squared_error(matrix_ideal_SET_normalized, Te_normalized)
# compute STFT
f, t, Zxx = signal.stft(noisy_signal, fs=fs_L, nperseg=512)
f_stft, Pxx_stft = signal.welch(np.abs(Zxx)**2, fs=fs_L, nperseg=512)
renyi_entropy_stft = renyi_entropy(Pxx_stft, gamma)
print('renyi entropy of stft {0} : {1}'.format(renyi_entropy_stft,snr))
Zxx_normalized = normalize_matrix(abs(Zxx))
matrix_ideal_stft_normalized = normalize_matrix(matrix_ideal(Zxx, f1_L, f2_L, f3_L, f4_L, fs_L, A1, A2, A3, A4, t_1_L, t_2_L, t_3_L, t_4_L))
ssim_stft = ssim(Zxx_normalized, matrix_ideal_stft_normalized)
MSE_stft = mean_squared_error(matrix_ideal_stft_normalized, Zxx_normalized)
#compute SSQ-STFT
Tx_ssq_stft, Sx_stft, ssq_freqs_Tx, Sfs,*_ = ssqpy.ssq_stft(noisy_signal, window='hann', n_fft=512, win_len=512, hop_len=256, fs=fs_L, flipud=True)
f_ssq_stft, Pxx_ssq_stft = signal.welch(np.abs(Tx_ssq_stft)**2, fs=fs_L, nperseg=512)
renyi_entropy_ssq_stft = renyi_entropy(Pxx_ssq_stft, gamma)
print('renyi entropy of ssq-stft {0} : {1}'.format(renyi_entropy_ssq_stft,snr))
Tx_ssq_stft_normalized = normalize_matrix(abs(Tx_ssq_stft))
matrix_ideal_ssq_stft_normalized = normalize_matrix(matrix_ideal(Tx_ssq_stft, f1_L, f2_L, f3_L, f4_L, fs_L, A1, A2, A3, A4, t_1_L, t_2_L, t_3_L, t_4_L))
ssim_ssq_stft = ssim(Tx_ssq_stft_normalized, matrix_ideal_ssq_stft_normalized)
MSE_ssq_stft = mean_squared_error(matrix_ideal_ssq_stft_normalized, Tx_ssq_stft_normalized)
# creating the list of values
stft_renyi_values.append(renyi_entropy_stft)
stft_ssim_values.append(ssim_stft)
stft_MSE_values.append(MSE_stft)
ssq_stft_renyi_values.append(renyi_entropy_ssq_stft)
ssq_stft_ssim_values.append(ssim_ssq_stft)
ssq_stft_MSE_values.append(MSE_ssq_stft)
cwt_renyi_values.append(renyi_entropy_cwt)
cwt_ssim_values.append(ssim_cwt)
cwt_MSE_values.append(MSE_cwt)
ssq_cwt_renyi_values.append(renyi_entropy_ssq_cwt)
ssq_cwt_ssim_values.append(ssim_ssq_cwt)
ssq_cwt_MSE_values.append(MSE_ssq_cwt)
wvd_renyi_values.append(renyi_entropy_wvd)
wvd_ssim_values.append(ssim_wvd)
wvd_MSE_values.append(MSE_wvd)
pwvd_renyi_values.append(renyi_entropy_pwvd)
pwvd_ssim_values.append(ssim_pwvd)
pwvd_MSE_values.append(MSE_pwvd)
SET_renyi_values.append(renyi_entropy_SET)
SET_ssim_values.append(ssim_SET)
SET_MSE_values.append(MSE_SET)
snr_list.append(snr)
plt.figure(figsize=(10, 6))
plt.plot(snr_list, stft_renyi_values, marker='o',label='STFT')
#plt.title('Renyi entropy of STFT with different SNRs')
plt.xlabel('SNR (dB)')
plt.ylabel('Renyi entropy')
#plt.show()
plt.plot(snr_list, ssq_stft_renyi_values, marker='o',label='SSQ-STFT')
plt.plot(snr_list, cwt_renyi_values, marker='o',label='CWT')
plt.title('Renyi entropy of several methods with different SNRs')
plt.plot(snr_list, ssq_cwt_renyi_values, marker='o',label='SSQ-CWT')
#plt.title('Renyi entropy of SSQ-CWT with different SNRs')
plt.plot(snr_list, wvd_renyi_values, marker='o',label='WVD')
#plt.title('Renyi entropy of SSQ-STFT with different SNRs')
plt.plot(snr_list, pwvd_renyi_values, marker='o',label='Pseudo-WVD')
#plt.title('Renyi entropy of CWT with different SNRs')
plt.plot(snr_list, SET_renyi_values, marker='o',label='SET')
plt.legend()
plt.show()
plt.figure()
plt.plot(snr_list, stft_ssim_values, marker='o',label='STFT')
#plt.title('Renyi entropy of STFT with different SNRs')
plt.plot(snr_list, ssq_stft_ssim_values, marker='o',label='SSQ-STFT')
plt.xlabel('SNR (dB)')
plt.ylabel('Structural Similarity')
plt.title('structural similarity of several methods with different SNRs')
plt.plot(snr_list, cwt_ssim_values, marker='o',label='CWT')
plt.plot(snr_list, ssq_cwt_ssim_values, marker='o',label='SSQ-CWT')
#plt.title('Renyi entropy of SSQ-CWT with different SNRs')
plt.plot(snr_list, wvd_ssim_values, marker='o',label='WVD')
#plt.title('Renyi entropy of SSQ-STFT with different SNRs')
plt.plot(snr_list, pwvd_ssim_values, marker='o',label='Pseudo-WVD')
#plt.title('Renyi entropy of CWT with different SNRs')
plt.plot(snr_list, SET_ssim_values, marker='o',label='SET')
plt.legend()
plt.show()
plt.figure()
plt.plot(snr_list, stft_MSE_values, marker='o',label='STFT')
#plt.title('Renyi entropy of STFT with different SNRs')
plt.plot(snr_list, ssq_stft_MSE_values, marker='o',label='SSQ-STFT')
plt.xlabel('SNR (dB)')
plt.ylabel('MSE')
plt.title('MSE of several methods with different SNRs')
plt.plot(snr_list, cwt_MSE_values, marker='o',label='CWT')
plt.plot(snr_list, ssq_cwt_MSE_values, marker='o',label='SSQ-CWT')
#plt.title('Renyi entropy of SSQ-CWT with different SNRs')
plt.plot(snr_list, wvd_MSE_values, marker='o',label='WVD')
#plt.title('Renyi entropy of SSQ-STFT with different SNRs')
plt.plot(snr_list, pwvd_MSE_values, marker='o',label='Pseudo-WVD')
#plt.title('Renyi entropy of CWT with different SNRs')
plt.plot(snr_list, SET_MSE_values, marker='o',label='SET')
plt.legend()
plt.show()
if you plot each matrix_ideal and the result of each method, you can see that the value of MSE or SSIM are not correct.