微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

为什么 SciPy 的 curve_fit 不能找到这个高阶高斯函数的协方差/给我有意义的参数?

如何解决为什么 SciPy 的 curve_fit 不能找到这个高阶高斯函数的协方差/给我有意义的参数?

这是一些最小的代码

from scipy.optimize import curve_fit
xdata = [16.530468600170202,16.86156794563677,17.19266729110334,17.523766636569913,17.854865982036483,18.18596532750305,18.51706467296962,18.848164018436194,19.179263363902763,19.510362709369332]
ydata = [394,1121,1173,1196,1140,1158,1160,1046,416]
#function i'm trying to fit the data to (higher order gaussian function)
def gauss(x,sig,n,c,x_o):
    return c*np.exp(-(x-x_o)**n/(2*sig**n))
popt = curve_fit(gauss,xdata,ydata)
#attempt at printing out parameters
print(popt)

当我尝试执行代码时,收到此错误消息:

ExCurveFit:9: RuntimeWarning: invalid value encountered in power
  return c*np.exp(-(x-x_o)**n/(2*sig**n))
C:\Users\dsneh\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\python38\site-packages\scipy\optimize\minpack.py:828: OptimizeWarning: Covariance of the parameters Could not be estimated
  warnings.warn('Covariance of the parameters Could not be estimated',(array([nan,nan,nan]),array([[inf,inf,inf],[inf,inf]]))

我已经看到我应该忽略第一个,但也许这是一个问题。第二个更令人担忧,我显然希望参数有更合理的值。我尝试添加参数作为猜测([2,3,1000,17] 供参考),但没有帮助。

解决方法

我相信您遇到了这个问题,因为 curve_fit 也在测试 n 的非整数值,在这种情况下,您的函数 gaussx<x_o 时返回复数值.

我相信通过暴力破解每个整数 n 并为每个 sig,c,x_o 找到最佳参数 n 会更容易。 例如,如果您考虑 n = 0,1,2,... 最多可能是 50,那么 n 的选项很少,而暴力破解实际上是一种不错的方法。 查看您的数据,最好仅将 n 视为 2 的倍数(除非您的其他数据看起来像 n 可能是一个奇数)。

此外,您可能应该为其他参数引入一些界限,例如拥有 sig>0 会更好,并且您可以安全地设置 c>0(也就是说,如果您的所有数据看起来像您在问题中包含的最少数据)。

这是我所做的:

p0 = [1,1200,18] # initial guess for sig,x_o
bounds = ([0.01,100,10],[1000,2500,200]) # (lower,upper) bounds for sig,x_o
min_err = np.inf # error to minimize to find the best value for n

for n in range(0,41,2): # n = 0,4,6,...,40

    def gauss(x,sig,x_o,n=n):
        return c*np.exp(-(x-x_o)**n/(2*sig**n))
    
    popt,pcov = curve_fit(gauss,xdata,ydata,p0=p0,bounds=bounds)

    err = np.sqrt(np.diag(pcov)).sum()
    if err < min_err:
        min_err,best_n,best_popt,best_pcov = err,n,popt,pcov

print(min_error,best_pcov,sep='\n')

import matplotlib.pyplot as plt
plt.plot(xdata,ydata)
plt.plot(xdata,gauss(xdata,*best_popt,n=best_n))
plt.show()

我得到 best_n = 10best_popt = [1.38,1173.52,18.02]

这是结果图:(蓝线是数据,橙色线是高斯拟合)

best fit

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。