如何解决将Tensorflow代码转换为Pytorch-性能指标大不相同
我已经将用于时间序列分析的张量流代码转换为pytorch,并且性能差异非常大,实际上pytorch层根本无法解决季节性问题。感觉我一定想念一些重要的东西。
请帮助查找pytorch代码所欠缺的地方,因为学习不符合标准。我注意到,损失值在遇到季节变化而没有学习时会出现较大的跳跃。在相同的层,节点和所有其他事物的情况下,我想象性能会很接近。
# tensorflow code
window_size = 20
batch_size = 32
shuffle_buffer_size = 1000
def windowed_dataset(series,window_size,batch_size,shuffle_buffer):
dataset = tf.data.Dataset.from_tensor_slices(series)
dataset = dataset.window(window_size + 1,shift=1,drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(window_size + 1))
dataset = dataset.shuffle(shuffle_buffer).map(lambda window: (window[:-1],window[-1]))
dataset = dataset.batch(batch_size).prefetch(1)
return dataset
dataset = windowed_dataset(x_train,shuffle_buffer_size)
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(100,input_shape=[window_size],activation="relu"),tf.keras.layers.Dense(10,tf.keras.layers.Dense(1)
])
model.compile(loss="mse",optimizer=tf.keras.optimizers.SGD(lr=1e-6,momentum=0.9))
model.fit(dataset,epochs=100,verbose=0)
forecast = []
for time in range(len(series) - window_size):
forecast.append(model.predict(series[time:time + window_size][np.newaxis]))
forecast = forecast[split_time-window_size:]
results = np.array(forecast)[:,0]
plt.figure(figsize=(10,6))
plot_series(time_valid,x_valid)
plot_series(time_valid,results)
tf.keras.metrics.mean_absolute_error(x_valid,results).numpy()
# pytorch code
window_size = 20
batch_size = 32
shuffle_buffer_size = 1000
class tsdataset(Dataset):
def __init__(self,series,window_size):
self.series = series
self.window_size = window_size
self.dataset,self.labels = self.preprocess()
def preprocess(self):
series = self.series
final,labels = [],[]
for i in range(len(series)-self.window_size):
final.append(np.array(series[i:i+window_size]))
labels.append(np.array(series[i+window_size]))
return torch.from_numpy(np.array(final)),torch.from_numpy(np.array(labels))
def __getitem__(self,index):
# print(self.dataset[index],self.labels[index],index)
return self.dataset[index],self.labels[index]
def __len__(self):
return len(self.dataset)
train_dataset = tsdataset(x_train,window_size)
train_dataloader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
class tspredictor(nn.Module):
def __init__(self,out1,out2,out3):
super(tspredictor,self).__init__()
self.l1 = nn.Linear(window_size,out1)
self.l2 = nn.Linear(out1,out2)
self.l3 = nn.Linear(out2,out3)
def forward(self,seq):
l1 = F.relu(self.l1(seq))
l2 = F.relu(self.l2(l1))
l3 = self.l3(l2)
return l3
model = tspredictor(20,100,10,1)
loss_function = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),lr=1e-6,momentum=0.9)
for epoch in range(100):
for t,l in train_dataloader:
model.zero_grad()
tag_scores = model(t)
loss = loss_function(tag_scores,l)
loss.backward()
optimizer.step()
# print("Epoch is {},loss is {}".format(epoch,loss.data))
forecast = []
for time in range(len(series) - window_size):
prediction = model(torch.from_numpy(series[time:time + window_size][np.newaxis]))
forecast.append(prediction)
forecast = forecast[split_time-window_size:]
results = np.array(forecast)
plt.figure(figsize=(10,6))
plot_series(time_valid,results)
要生成数据,可以使用:
def plot_series(time,format="-",start=0,end=None):
plt.plot(time[start:end],series[start:end],format)
plt.xlabel("Time")
plt.ylabel("Value")
plt.grid(False)
def trend(time,slope=0):
return slope * time
def seasonal_pattern(season_time):
"""Just an arbitrary pattern,you can change it if you wish"""
return np.where(season_time < 0.1,np.cos(season_time * 6 * np.pi),2 / np.exp(9 * season_time))
def seasonality(time,period,amplitude=1,phase=0):
"""Repeats the same pattern at each period"""
season_time = ((time + phase) % period) / period
return amplitude * seasonal_pattern(season_time)
def noise(time,noise_level=1,seed=None):
rnd = np.random.RandomState(seed)
return rnd.randn(len(time)) * noise_level
time = np.arange(10 * 365 + 1,dtype="float32")
baseline = 10
series = trend(time,0.1)
baseline = 10
amplitude = 40
slope = 0.005
noise_level = 3
# Create the series
series = baseline + trend(time,slope) + seasonality(time,period=365,amplitude=amplitude)
# Update with noise
series += noise(time,noise_level,seed=51)
split_time = 3000
time_train = time[:split_time]
x_train = series[:split_time]
time_valid = time[split_time:]
x_valid = series[split_time:]
解决方法
丢失功能存在广播问题。将损失更改为以下一项即可解决:
loss = loss_function(tag_scores,l.view(-1,1))
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。