如何解决Pytorch N - Beats 模型抛出错误:'str' 对象没有属性 '__name__'
我正在尝试在 colab 中复制 pytorch 的 N - Beats 模型。我将相同的代码从 https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/ar.html 复制到 colab notebook。训练单元出现错误。
import os
import warnings
warnings.filterwarnings("ignore")
os.chdir("../../..")
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
import torch
from pytorch_forecasting import Baseline,NBeats,TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.data.examples import generate_ar_data
from pytorch_forecasting.metrics import SMAPE
data = generate_ar_data(seasonality=10.0,timesteps=400,n_series=100,seed
= 42)
data["static"] = 2
data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx,"D")
data.head()
# create dataset and dataloaders
max_encoder_length = 60
max_prediction_length = 20
training_cutoff = data["time_idx"].max() - max_prediction_length
context_length = max_encoder_length
prediction_length = max_prediction_length
training = TimeSeriesDataSet(
data[lambda x: x.time_idx <= training_cutoff],time_idx="time_idx",target="value",categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},group_ids=["series"],# only unknown variable is "value" - and N-Beats can also not take any additional variables
time_varying_unknown_reals=["value"],max_encoder_length=context_length,max_prediction_length=prediction_length,)
validation = TimeSeriesDataSet.from_dataset(training,data,min_prediction_idx=training_cutoff + 1)
batch_size = 128
train_dataloader = training.to_dataloader(train=True,batch_size=batch_size,num_workers=0)
val_dataloader = validation.to_dataloader(train=False,num_workers=0)
错误是:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-67-db4b0ef13391> in <module>()
25 net,26 train_dataloader=train_dataloader,---> 27 val_dataloaders=val_dataloader,28 )
30 frames
/usr/local/lib/python3.7/dist-packages/yaml/representer.py in represent_object(self,data)
329 if dictitems is not None:
330 dictitems = dict(dictitems)
--> 331 if function.__name__ == '__newobj__':
332 function = args[0]
333 args = args[1:]
AttributeError: 'str' object has no attribute '__name__'
解决方法
将 pytorch-lightning 从 1.2.1 降级到 1.1.8 为我解决了这个问题。
,@PVJ 的回答对我有用。为了完整起见,您可以通过以下方式降级 pytorch_lightning
:
pip install --upgrade pytorch_lightning==1.1.8
,
最近遇到了类似的问题,发现将 pandas 降级到 1.2.5 解决了这个问题
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。