Using SVR to predict stock prices: a time-series problem

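I am trying to predict stock prices (adjusted close) using SVR. I can train the model on the training data, but I get an error with the test data. The training data, covering 2014 through 2018, is stored in the dataframe df, and the test data, covering 2019 to the present, is stored in the dataframe test_df. The code is as follows: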

import pandas as pd 
import pandas_datareader.data as web
import datetime
import numpy as np 
from matplotlib import style
import matplotlib.pyplot as plt  # needed for the plt.* calls further down

# Get the stock data using yahoo API:
style.use('ggplot')

# get 2014-2018 data to train our model
start = datetime.datetime(2014,1,1)
end = datetime.datetime(2018,12,30)
df = web.DataReader("TSLA",'yahoo',start,end) 

# get 2019 data to test our model on 
start = datetime.datetime(2019,1,1)  # datetime needs year, month and day
end = datetime.date.today()
test_df = web.DataReader("TSLA",'yahoo',start,end)  # same source and argument order as for df


# sort by date
df = df.sort_values('Date')
test_df = test_df.sort_values('Date')

# fix the date 
df.reset_index(inplace=True)
df.set_index("Date",inplace=True)
test_df.reset_index(inplace=True)
test_df.set_index("Date",inplace=True)

df.tail()

[Image: sample data]

# Converting dates

import matplotlib.dates as mdates

# change the dates into numbers for training 
dates_df = df.copy()
dates_df = dates_df.reset_index()

# Store the original dates for plotting the predictions
org_dates = dates_df['Date']

# convert to matplotlib date numbers (floats counting days)
dates_df['Date'] = dates_df['Date'].map(mdates.date2num)

dates_df.tail()

# Use sklearn support vector regression to predict our data:
from sklearn.svm import SVR

dates = dates_df['Date'].to_numpy()
prices = df['Adj Close'].to_numpy()

# SVR expects X as a 2-D column (n_samples, 1); y (prices) stays 1-D
dates = np.reshape(dates,(len(dates),1))

svr_rbf = SVR(kernel= 'rbf',C= 1e3,gamma= 0.1)
svr_rbf.fit(dates,prices)


plt.figure(figsize = (12,6))
plt.plot(org_dates,prices,color= 'black',label= 'Data')  # same x-axis (real dates) as the model curve
plt.plot(org_dates,svr_rbf.predict(dates),color= 'red',label= 'RBF model') 
plt.xlabel('Date')
plt.ylabel('Price')
plt.legend()
plt.show()

For the training data, everything works up to this point. How do I now predict on the test data (test_df)?

Solution

Following your conventions, it should look like this:

# change the test dates into numbers, the same way as for training 
test_dates_df = test_df.copy()
test_dates_df = test_dates_df.reset_index()

# Store the original dates for plotting the predictions
test_org_dates = test_dates_df['Date']

# convert to matplotlib date numbers (floats counting days)
test_dates_df['Date'] = test_dates_df['Date'].map(mdates.date2num)

test_dates = test_dates_df['Date'].to_numpy()
test_prices = test_df['Adj Close'].to_numpy()

# SVR expects X as a 2-D column (n_samples, 1); prices stay 1-D
test_dates = np.reshape(test_dates,(len(test_dates),1))

# Predict on unseen test data
y_hat_test = svr_rbf.predict(test_dates)

# Visualize predictions against real values
plt.figure(figsize = (12,6))
plt.plot(test_org_dates,test_prices,color= 'black',label= 'Data')  # same x-axis (real dates) as the predictions
plt.plot(test_org_dates,y_hat_test,color= 'red',label= 'RBF model (test)') 
plt.xlabel('Date')
plt.ylabel('Price')
plt.legend()
plt.show()
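
The same fit-then-predict flow can be checked end to end without downloading anything. What follows is only a minimal sketch on synthetic prices, not the Tesla data from the question: the random-walk series, the seed, and the names `train_X`, `test_X`, and `model` are made up for illustration, and plain day counts stand in for the output of `mdates.date2num`. It also prints a test RMSE and compares the last prediction with the model's intercept, because an RBF kernel decays toward zero far away from the training points.

```python
import numpy as np
from sklearn.svm import SVR

# Synthetic stand-in for the price history (assumption: NOT real TSLA data).
rng = np.random.default_rng(0)
n_train, n_test = 300, 60
days = np.arange(n_train + n_test, dtype=float)  # plays the role of date2num output
prices = 50 + np.cumsum(rng.normal(0, 1, size=days.size))

# Features must be 2-D (n_samples, 1); the target stays 1-D.
train_X = days[:n_train].reshape(-1, 1)
test_X = days[n_train:].reshape(-1, 1)
train_y = prices[:n_train]
test_y = prices[n_train:]

# Same hyperparameters as in the question.
model = SVR(kernel="rbf", C=1e3, gamma=0.1)
model.fit(train_X, train_y)

# Predict on the unseen days and measure the error.
y_hat_test = model.predict(test_X)
rmse = float(np.sqrt(np.mean((test_y - y_hat_test) ** 2)))
print("test RMSE:", rmse)

# Far from every training day the RBF kernel is ~0, so the
# prediction falls back to the intercept.
print("last prediction:", y_hat_test[-1], "intercept:", model.intercept_[0])
```

In this sketch the predictions for the last test days collapse to `model.intercept_`, so a flat red line on the test plot may simply reflect how the RBF kernel behaves outside the training range rather than a coding mistake.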
