I have a standard Gaussian function that looks like this:
def gauss_fnc(x, amp, cen, sigma):
    """Evaluate a Gaussian profile at ``x``.

    Args:
        x: Position(s) at which to evaluate the profile.
        amp: Peak height, reached at ``x == cen``.
        cen: Position of the peak.
        sigma: Width parameter of the profile.

    Returns:
        ``amp * exp(-(x - cen)**2 / (2 * sigma**2))``.
    """
    squared_offset = (x - cen) ** 2
    two_variance = 2 * sigma ** 2
    return amp * np.exp(-squared_offset / two_variance)
I also have a `fit_gaussian` function that uses SciPy's `curve_fit` to fit `gauss_fnc` to my data:
from scipy.optimize import curve_fit
def fit_gaussian(x, y):
    """Fit ``gauss_fnc`` to the samples ``(x, y)`` with ``curve_fit``.

    The starting point for the optimiser is derived from the data itself:
    the amplitude is the largest sample, and the centre and width are the
    ``y``-weighted mean and standard deviation of ``x``.

    Args:
        x: Sample positions (any array-like, e.g. list, range or ndarray).
        y: Sample values, same length as ``x``.

    Returns:
        A tuple ``(values, sigma, opt, cov)`` where ``values`` is the fitted
        curve evaluated at ``x``, ``sigma`` is the moment-based *initial*
        width estimate (the fitted width is ``opt[2]``), ``opt`` holds the
        fitted ``(amp, cen, sigma)`` and ``cov`` is the covariance returned
        by ``curve_fit``.
    """
    # Coerce once so that plain Python sequences work with the element-wise
    # arithmetic below (list * list would otherwise raise a TypeError).
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    total = np.sum(y)
    mean = np.sum(x * y) / total
    sigma = np.sqrt(np.sum(y * (x - mean) ** 2) / total)

    opt, cov = curve_fit(gauss_fnc, x, y, p0=[np.max(y), mean, sigma])
    values = gauss_fnc(x, *opt)
    return values, sigma, opt, cov
I can confirm that this works well when the data resembles an ordinary Gaussian; see this example:
However, if the signal is too peaked or too narrow, the fit does not work as expected. Here is an example of a peaked Gaussian:
Here is an example of a flat-top (super-Gaussian) signal:
Currently, the flatter the signal becomes, the more information is lost, because the fitted Gaussian cuts off the edges. How can I improve the model function or the curve fitting so that peaked and flat-top signals can be fitted as well, as in this picture:
Edit:
I have added a minimal working example so you can try this out:
from PyQt5.QtWidgets import (QApplication, QMainWindow)
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from scipy.optimize import curve_fit
import numpy as np
from PyQt5.QtWidgets import QWidget, QGridLayout
def gauss_fnc(x, amp, cen, sigma):
    """Return the value of a Gaussian with the given parameters at ``x``.

    Args:
        x: Position(s) at which to evaluate the curve.
        amp: Height of the curve at its centre.
        cen: Centre of the curve.
        sigma: Width parameter of the curve.

    Returns:
        ``amp * exp(-(x - cen)**2 / (2 * sigma**2))``.
    """
    exponent = -(x - cen) ** 2 / (2 * sigma ** 2)
    profile = np.exp(exponent)
    return amp * profile
def fit_gauss(x, y):
    """Fit ``gauss_fnc`` to the samples ``(x, y)`` with ``curve_fit``.

    The initial guess passed to the optimiser is computed from the data:
    amplitude = largest sample, centre = ``y``-weighted mean of ``x``,
    width = ``y``-weighted standard deviation of ``x``.

    Args:
        x: Sample positions (any array-like, e.g. list, range or ndarray).
        y: Sample values, same length as ``x``.

    Returns:
        A tuple ``(vals, sigma, opt, cov)`` where ``vals`` is the fitted
        curve evaluated at ``x``, ``sigma`` is the moment-based *initial*
        width estimate (the fitted width is ``opt[2]``), ``opt`` holds the
        fitted ``(amp, cen, sigma)`` and ``cov`` is the covariance returned
        by ``curve_fit``.
    """
    # Coerce once so that plain Python sequences work with the element-wise
    # arithmetic below (list * list would otherwise raise a TypeError).
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    total = np.sum(y)
    mean = np.sum(x * y) / total
    sigma = np.sqrt(np.sum(y * (x - mean) ** 2) / total)

    opt, cov = curve_fit(gauss_fnc, x, y, p0=[np.max(y), mean, sigma])
    vals = gauss_fnc(x, *opt)
    return vals, sigma, opt, cov
class MainWindow(QMainWindow):
    """Main window with three stacked plots, each showing a sample signal
    (black) and the Gaussian fitted to it (blue)."""

    def __init__(self):
        super().__init__()
        self.results = list()
        self.setWindowTitle('Gauss fitting')
        self.setGeometry(50, 50, 1280, 1024)
        self.setupLayout()
        # Three test signals: a bell-like shape, a single-sample spike and a
        # flat-topped shape.
        self.raw_data1 = np.array([1, 1, 1, 1, 3, 5, 7, 8, 9, 10, 11, 10, 9, 8, 7, 5, 3, 1, 1, 1, 1], dtype=int)
        self.raw_data2 = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 200, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int)
        self.raw_data3 = np.array([1, 1, 1, 1, 1, 3, 5, 9, 10, 10, 10, 10, 10, 9, 5, 3, 1, 1, 1, 1, 1], dtype=int)
        self.plot()

    @staticmethod
    def _make_canvas():
        """Create one figure canvas and its frameless axes with right-hand y ticks."""
        canvas = FigureCanvas(Figure(figsize=(5, 4), dpi=100))
        axes = canvas.figure.add_subplot(111, frameon=False)
        axes.get_xaxis().set_visible(True)
        axes.get_yaxis().set_visible(True)
        axes.yaxis.tick_right()
        axes.yaxis.set_label_position("right")
        return canvas, axes

    def setupLayout(self):
        """Build the three canvases and stack them vertically in the window."""
        # Create figures
        self.fig1, self.fig1AX = self._make_canvas()
        self.fig2, self.fig2AX = self._make_canvas()
        self.fig3, self.fig3AX = self._make_canvas()

        self.widget = QWidget(self)
        grid = QGridLayout()
        for row, canvas in enumerate((self.fig1, self.fig2, self.fig3)):
            grid.addWidget(canvas, row, 0, 1, 1)
        self.widget.setLayout(grid)
        self.setCentralWidget(self.widget)

    @staticmethod
    def _plot_fit(canvas, axes, data):
        """Draw ``data`` and the Gaussian fitted to it on ``axes``, then redraw ``canvas``."""
        xs = np.arange(len(data))
        axes.clear()
        axes.plot(xs, data, 'k-')
        try:
            fitted, _sigma, _opt, _cov = fit_gauss(xs, data)
        except RuntimeError:
            # curve_fit raises RuntimeError when the minimisation does not
            # converge; still show the raw data instead of aborting the GUI.
            fitted = None
        if fitted is not None:
            axes.plot(xs, fitted, 'b-', linewidth=2)
        axes.margins(0, 0)
        canvas.figure.tight_layout()
        canvas.draw()

    def plot(self):
        """Fit every data set separately and draw each one on its own figure."""
        # Each data set gets its own fit. Previously the fit of raw_data1 was
        # reused for figures 2 and 3, so those plots did not show the fit of
        # the data they displayed.
        self._plot_fit(self.fig1, self.fig1AX, self.raw_data1)
        self._plot_fit(self.fig2, self.fig2AX, self.raw_data2)
        self._plot_fit(self.fig3, self.fig3AX, self.raw_data3)
if __name__ == '__main__':
    # Start the Qt application, show the main window and hand control to the
    # event loop. The window reference is kept in a variable for as long as
    # the loop runs.
    qt_app = QApplication([])
    main_window = MainWindow()
    main_window.show()
    qt_app.exec_()
The last picture is from here.