如何解决scipy.optimize.curve_fit ValueError:具有多个元素的数组的真值不明确
我正在尝试使用 scipy.optimize.curve_fit 将 sigmoidal 曲线拟合到我的数据集,但出现以下错误:
Traceback (most recent call last):
File "<input>",line 1,in <module>
File "Y:\WORK\code\venv\lib\site-packages\scipy\optimize\minpack.py",line 784,in curve_fit
res = leastsq(func,p0,Dfun=jac,full_output=1,**kwargs)
File "Y:\WORK\code\venv\lib\site-packages\scipy\optimize\minpack.py",line 410,in leastsq
shape,dtype = _check_func('leastsq','func',func,x0,args,n)
File "Y:\WORK\code\venv\lib\site-packages\scipy\optimize\minpack.py",line 24,in _check_func
res = atleast_1d(thefunc(*((x0[:numinputs],) + args)))
File "Y:\WORK\code\venv\lib\site-packages\scipy\optimize\minpack.py",line 484,in func_wrapped
return func(xdata,*params) - ydata
File "<input>",line 7,in func
File "Y:\WORK\code\venv\lib\site-packages\scipy\integrate\quadpack.py",line 348,in quad
flip,a,b = b < a,min(a,b),max(a,b)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
我试图拟合的 sigmoidal 函数包含一个从负无穷大到自变量函数的积分。这是我的代码:
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy.integrate import quad
import numpy as np
import math
x_data = np.array([ 2.5,7.5,12.5,17.5,22.5,27.5,32.5,37.5,42.5,47.5,52.5,57.5,62.5,67.5,72.5,77.5])
y_data = np.array([0.05,0.09,0.13,0.15,0.2,0.35,0.45,0.53,0.68,0.8,0.9,0.92,0.99,1,0.95,0.97])
#Constructing the sigmoidal function
def integrand(x):
return np.exp(-np.square(x)/2)
def func(x,b):
t = (x-a)/(b*a)
integral = quad(integrand,-math.inf,t)[0]
return math.sqrt(1/(2*math.pi))*integral
# Initial guess for parameters a,b
initialGuess = [35,0.25]
# Perform the curve-fit
popt,pcov = curve_fit(func,x_data,y_data,initialGuess)
问题似乎来自于积分部分。在其他类似的帖子中,他们的函数包含一个布尔参数,如果自变量超过一个值,则提供另一个函数。但在我的情况下不是。我很困惑...
解决方法
感谢 @hpaulj
:quad
函数只接受标量,而我的 t
是一个数组。我已将功能更改如下,错误消失了
def func(x,a,b):
sigmoid_arr = np.array([])
for i in x:
t = (i- a)/(b*a)
integral = quad(integrand,-math.inf,t)[0]
sigmoid = math.sqrt(1/(2*math.pi))*integral
sigmoid_arr = np.append(sigmoid_arr,sigmoid)
return sigmoid_arr
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。