将Tensorflow代码转换为Pytorch-性能指标大不相同

如何解决将Tensorflow代码转换为Pytorch-性能指标大不相同

我已经将用于时间序列分析的张量流代码转换为pytorch,并且性能差异非常大,实际上pytorch层根本无法解决季节性问题。感觉我一定想念一些重要的东西。

请帮助查找pytorch代码所欠缺的地方,因为学习不符合标准。我注意到,损失值在遇到季节变化而没有学习时会出现较大的跳跃。在相同的层,节点和所有其他事物的情况下,我想象性能会很接近。

# tensorflow code
window_size = 20
batch_size = 32
shuffle_buffer_size = 1000

def windowed_dataset(series,window_size,batch_size,shuffle_buffer):
  dataset = tf.data.Dataset.from_tensor_slices(series)
  dataset = dataset.window(window_size + 1,shift=1,drop_remainder=True)
  dataset = dataset.flat_map(lambda window: window.batch(window_size + 1))
  dataset = dataset.shuffle(shuffle_buffer).map(lambda window: (window[:-1],window[-1]))
  dataset = dataset.batch(batch_size).prefetch(1)
  return dataset

dataset = windowed_dataset(x_train,shuffle_buffer_size)


model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(100,input_shape=[window_size],activation="relu"),tf.keras.layers.Dense(10,tf.keras.layers.Dense(1)
])

model.compile(loss="mse",optimizer=tf.keras.optimizers.SGD(lr=1e-6,momentum=0.9))
model.fit(dataset,epochs=100,verbose=0)
forecast = []
for time in range(len(series) - window_size):
  forecast.append(model.predict(series[time:time + window_size][np.newaxis]))

forecast = forecast[split_time-window_size:]
results = np.array(forecast)[:,0]


plt.figure(figsize=(10,6))

plot_series(time_valid,x_valid)
plot_series(time_valid,results)

tf.keras.metrics.mean_absolute_error(x_valid,results).numpy()

Tensorflow predictions

# pytorch code
window_size = 20
batch_size = 32
shuffle_buffer_size = 1000

class tsdataset(Dataset):
  def __init__(self,series,window_size):
    self.series = series
    self.window_size = window_size
    self.dataset,self.labels = self.preprocess()

  def preprocess(self):
    series = self.series
    final,labels = [],[]
    for i in range(len(series)-self.window_size):
      final.append(np.array(series[i:i+window_size]))
      labels.append(np.array(series[i+window_size]))
    return torch.from_numpy(np.array(final)),torch.from_numpy(np.array(labels))
    
  def __getitem__(self,index):
    # print(self.dataset[index],self.labels[index],index)
    return self.dataset[index],self.labels[index]
  
  def __len__(self):
    return len(self.dataset)

train_dataset = tsdataset(x_train,window_size)
train_dataloader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)

class tspredictor(nn.Module):
  def __init__(self,out1,out2,out3):
    super(tspredictor,self).__init__()
    self.l1 = nn.Linear(window_size,out1)
    self.l2 = nn.Linear(out1,out2)
    self.l3 = nn.Linear(out2,out3)

  def forward(self,seq):
    l1 = F.relu(self.l1(seq))
    l2 = F.relu(self.l2(l1))
    l3 = self.l3(l2)
    return l3

model = tspredictor(20,100,10,1)
loss_function = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),lr=1e-6,momentum=0.9)

for epoch in range(100): 
    for t,l in train_dataloader:
      model.zero_grad()
      tag_scores = model(t)
      loss = loss_function(tag_scores,l)
      loss.backward()
      optimizer.step()
    # print("Epoch is {},loss is {}".format(epoch,loss.data))

forecast = []
for time in range(len(series) - window_size):
    prediction = model(torch.from_numpy(series[time:time + window_size][np.newaxis]))
    forecast.append(prediction)

forecast = forecast[split_time-window_size:]

results = np.array(forecast)

plt.figure(figsize=(10,6))
plot_series(time_valid,results)

Pytorch predictions

要生成数据,可以使用:

def plot_series(time,format="-",start=0,end=None):
    plt.plot(time[start:end],series[start:end],format)
    plt.xlabel("Time")
    plt.ylabel("Value")
    plt.grid(False)

def trend(time,slope=0):
    return slope * time

def seasonal_pattern(season_time):
    """Just an arbitrary pattern,you can change it if you wish"""
    return np.where(season_time < 0.1,np.cos(season_time * 6 * np.pi),2 / np.exp(9 * season_time))

def seasonality(time,period,amplitude=1,phase=0):
    """Repeats the same pattern at each period"""
    season_time = ((time + phase) % period) / period
    return amplitude * seasonal_pattern(season_time)

def noise(time,noise_level=1,seed=None):
    rnd = np.random.RandomState(seed)
    return rnd.randn(len(time)) * noise_level

time = np.arange(10 * 365 + 1,dtype="float32")
baseline = 10
series = trend(time,0.1)  
baseline = 10
amplitude = 40
slope = 0.005
noise_level = 3

# Create the series
series = baseline + trend(time,slope) + seasonality(time,period=365,amplitude=amplitude)
# Update with noise
series += noise(time,noise_level,seed=51)

split_time = 3000
time_train = time[:split_time]
x_train = series[:split_time]
time_valid = time[split_time:]
x_valid = series[split_time:]

解决方法

丢失功能存在广播问题。将损失更改为以下一项即可解决:

loss = loss_function(tag_scores,l.view(-1,1))

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams[&#39;font.sans-serif&#39;] = [&#39;SimHei&#39;] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -&gt; systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping(&quot;/hires&quot;) public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate&lt;String
使用vite构建项目报错 C:\Users\ychen\work&gt;npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-
参考1 参考2 解决方案 # 点击安装源 协议选择 http:// 路径填写 mirrors.aliyun.com/centos/8.3.2011/BaseOS/x86_64/os URL类型 软件库URL 其他路径 # 版本 7 mirrors.aliyun.com/centos/7/os/x86
报错1 [root@slave1 data_mocker]# kafka-console-consumer.sh --bootstrap-server slave1:9092 --topic topic_db [2023-12-19 18:31:12,770] WARN [Consumer clie
错误1 # 重写数据 hive (edu)&gt; insert overwrite table dwd_trade_cart_add_inc &gt; select data.id, &gt; data.user_id, &gt; data.course_id, &gt; date_format(
错误1 hive (edu)&gt; insert into huanhuan values(1,&#39;haoge&#39;); Query ID = root_20240110071417_fe1517ad-3607-41f4-bdcf-d00b98ac443e Total jobs = 1
报错1:执行到如下就不执行了,没有显示Successfully registered new MBean. [root@slave1 bin]# /usr/local/software/flume-1.9.0/bin/flume-ng agent -n a1 -c /usr/local/softwa
虚拟及没有启动任何服务器查看jps会显示jps,如果没有显示任何东西 [root@slave2 ~]# jps 9647 Jps 解决方案 # 进入/tmp查看 [root@slave1 dfs]# cd /tmp [root@slave1 tmp]# ll 总用量 48 drwxr-xr-x. 2
报错1 hive&gt; show databases; OK Failed with exception java.io.IOException:java.lang.RuntimeException: Error in configuring object Time taken: 0.474 se
报错1 [root@localhost ~]# vim -bash: vim: 未找到命令 安装vim yum -y install vim* # 查看是否安装成功 [root@hadoop01 hadoop]# rpm -qa |grep vim vim-X11-7.4.629-8.el7_9.x
修改hadoop配置 vi /usr/local/software/hadoop-2.9.2/etc/hadoop/yarn-site.xml # 添加如下 &lt;configuration&gt; &lt;property&gt; &lt;name&gt;yarn.nodemanager.res