微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

pytorch:LSTM 输入和输出维度和训练循环

如何解决pytorch:LSTM 输入和输出维度和训练循环

我想编写一个程序来预测比特币的价值,但我不断收到错误,我不知道我做错了什么。

我准备好并处理了数据集并制作了我的神经元网络。但是我很难理解 LSTM 的输入和输出的形状并将我的特征和标签的形状与它们匹配,所以我认为这就是我搞砸的地方,但我不知道如何解决它。也可能是训练循环。

我遇到的错误

索引错误:索引 652 超出维度 0 和大小 1 的范围

我的代码

import torch
import numpy as np
from torch.autograd import Variable
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

df = pd.read_csv("C:/Users/HP/Documents/MASTER GL CLOUD/Python/coin_Bitcoin.csv",index_col = "Date",parse_dates=True)
df=df.iloc[:,6:7]

scaler =  MinMaxScaler(feature_range=(0,1))
data = scaler.fit_transform(np.array(df).reshape(-1,1))

train_data = data[:2000,:]
test_data = data[200:,:]

def creat_dataset(dataset,time_step):
    data_x,data_y = [],[]
    for i in range(len(dataset) - time_step -1):
        a = dataset[i:(i+time_step),0]
        data_x.append(a)
        data_y.append(dataset[i+time_step,0])
        
    return np.array(data_x),np.array(data_y)

time_step = 100
x_train,y_train = creat_dataset(train_data,time_step)
x_test,y_test = creat_dataset(test_data,time_step)

x_train = Variable(torch.Tensor(x_train))
y_train = Variable(torch.Tensor(y_train))

x_test = Variable(torch.Tensor(x_test))
y_test = Variable(torch.Tensor(y_test))


class MyDataset(Dataset):
    def __init__(self,x_train,y_train):
        self.len = x_train.shape[0]
        self.x_data = torch.reshape(x_train,(-1,x_train.shape[0],x_train.shape[2]))
        self.y_data = y_train
    def __getitem__(self,index):
        return self.x_data[index],self.y_data[index]
    def __len__(self):
        return self.len
    
train_data = MyDataset(x_train,y_train)
train_loader = DataLoader(dataset = train_data,batch_size = 256,shuffle = True)

test_data = MyDataset(x_test,y_test)
test_loader = DataLoader(dataset = test_data,shuffle = True)

class Prednet(nn.Module):
        def __init__(self,seq_len,n_layers=2,hidden_dim=200,input_dim=100,output_size=1):
            super(Prednet,self).__init__()
            self.hidden_dim = hidden_dim
            self.n_layers = n_layers
            self.hidden_cell = (torch.zeros(n_layers,hidden_dim),torch.zeros(n_layers,hidden_dim))
            self.lstm = nn.LSTM(input_dim,hidden_dim,n_layers,batch_first=True)
            self.fc =  nn.Linear(hidden_dim,output_size) 
            self.relu = nn.ReLU()       
      
        def forward(self,x ):
            lstm_out,_= self.lstm(x,self.hidden_cell)
            lstm_out = lstm_out[:,-1,:]
            out = self.relu(lstm_out)
            output = self.relu(self.fc(out))
            return output
        
prednet = Prednet(dataset.__len__())  

criterion = torch.nn.MSELoss(size_average = False)   
optimizer = torch.optim.SGD(prednet.parameters(),lr=0.01 ) 

for epoch in range(10):
    for i,data in enumerate(train_loader,0):
       inputs,labels = data
       outputs = prednet(inputs)
       optimizer.zero_grad()
       loss = criterion(outputs,labels)
       loss.backward()
       optimizer.step()
       print("Epoch: %d,loss: %1.5f" % (epoch,loss.item()))

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?