微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何将 statsmodels 公式和参数保存为文本供以后使用进行预测?

如何解决如何将 statsmodels 公式和参数保存为文本供以后使用进行预测?

python 中的 Statsmodels 提供了一种很好的方法来对各种 R 样式公式进行线性拟合,并根据结果进行预测。进行预测所需的一切都应仅包含在公式和拟合参数列表中。我只想保存那些(例如,在文本配置文件或其他脚本中,所以没有泡菜等)。

那么,仅给定原始公式和拟合参数,是否有任何简单的方法可以重建模型并进行预测?

我也愿意接受 statsmodels 的替代品。

示例:

import numpy as np                                                                                                                                                                                      
import statsmodels.formula.api as sm                                                                                                                                                                    
import pandas as pd                                                                                                                                                                                     
import matplotlib.pyplot as plt                                                                                                                                                                         
                                                                                                                                                                                                        
# Example data                                                                                                                                                                                          
x = np.mgrid[-1:1:10000j]                                                                                                                                                                               
y = np.sin(x*5)                                                                                                                                                                                         
z = 1+2*x+3*y+4*x*y+5*x**2+6*y**2+(np.random.rand(len(x))-.5)                                                                                                                                           
df = pd.DataFrame(data = np.c_[x,y,z],columns = ['x','y','z'])                                                                                                                                         
                                                                                                                                                                                                        
# Fit the data                                                                                                                                                                                          
formula = 'z ~ 1+x*y+I(x**2)+I(y**2)'                                                                                                                                                                   
result = sm.ols(formula,data=df).fit()                                                                                                                                                                 
print(result.summary())                                                                                                                                                                                 
                                                                                                                                                                                                        
# Save formula and params in text freindly way                                                                                                                                                          
saved_formula = 'z ~ ' + '+'.join(result.params.index).replace('Intercept','1') # (is this necessary?)                                                                                                  
saved_params = result.params.values                                                                                                                                                                     
print(saved_formula)                                                                                                                                                                                    
print(saved_params)                                                                                                                                                                                     
                                                                                                                                                                                                        
# Load the formula/params                                                                                                                                                                               
model = result.predict # <-- REPLACE - generate new model using saved_formula,saved_params                                                                                                             
                                                                                                                                                                                                        
# Apply fit to new data                                                                                                                                                                                 
dfnew = pd.DataFrame(data = np.c_[x+1,y],'y'])                                                                                                                                          
znew = model(dfnew)                                                                                                                                                                                     
                                                                                                                                                                                                        
# Plot - just for fun                                                                                                                                                                                   
plt.figure(1);plt.clf()                                                                                                                                                                                 
plt.plot(df.x,z,label='data')                                                                                                                                                                         
plt.plot(df.x,model(df),label='model')                                                                                                                                                                
plt.show()                                                                                                                                                                                              
plt.legend()   

输出(稍后保存/使用的文本):

z ~ 1+x+y+x:y+I(x ** 2)+I(y ** 2)
[1.00604349 2.00219577 3.00889718 3.99726889 5.00210454 5.99458244]

编辑: 我尝试使用假数据创建虚拟模型以获取具有“预测”功能(例如 result.predict())的对象,然后切换拟合参数(例如 result.params = saved_paramsresult.pvalues = saved_params),但不幸的是那没有用。无论模型实际使用什么参数来进行预测似乎都无法公开/可编辑?

Plot of fit data

解决方法

我一直在深入挖掘,看起来 statsmodels 使用 patsy 作为函数语言并生成数组来进行矩阵数学运算。因此,可以像这样使用 patsy 手动重新实现“预测”函数:

# Save formula and params in text freindly way                                                                                        
#saved_formula = 'z ~ ' + '+'.join(result.params.index).replace('Intercept','1') # (is this necessary?)                               
saved_formula = formula                                                                                                               
saved_params = result.params.to_dict()                                                                                                
                                                                                                                                      
# Load the formula/params                                                                                                             
import patsy                                                                                                                          
def predict(data,formula) :                                                                                                          
    formula_rhs = formula.split('~')[1].strip()                                                                                       
    x_data = patsy.dmatrix(formula_rhs,data=df)                                                                                      
    ordered_terms = x_data.design_info.term_names                                                                                     
    ordered_params = list(map(saved_params.get,ordered_terms))                                                                       
    return x_data @ ordered_params                                                                                                  
                                                                                                                                      
model = lambda x : predict(x,saved_formula) 

这可以简化一点,但我担心重构矩阵中的参数顺序可能与原始矩阵不同。不确定这是否真的是一个问题,但我认为我所做的不是。我更改了要保存为字典的参数。生成新的数据矩阵,然后在我知道列的顺序后,我按照该顺序从保存的参数 dict 中提取参数。

预测函数的最后一行是预测步骤。我需要其余部分的唯一原因是使用可配置的线性模型(即不仅仅是模型的参数)。如果你知道你有一个特定的集合模型,你可以简化这个。

使用该表单中的参数可以轻松地往返于文本。例如

import json
saved_params_string = json.dumps(saved_params)
json.loads(saved_params_string)

然后可以将 formulasaved_params_string 字符串保存到配置文件或从配置文件加载或根据需要进行硬编码。

使用 jsonyaml 可以很容易地将实际保存到配置文件中 - 我可能会使用后者。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。