如何解决AttributeError: 'DataModuleClass' 对象没有属性 'training_dataset' 2.使用 for 循环
我正在尝试通过编写一个非常简单的 PyTorch Lightning
来学习 DataModuleClass
。在 prepare_data()
和 setup()
之后,我试图检查这些函数是否正常工作。因此,我试图从 training
获取 validation
和 setup()
数据集。但是我收到一个错误
AttributeError: 'DataModuleClass' object has no attribute 'training_dataset'
代码
def prepare_data(self):
x = np.random.uniform(0,10,10)
e = np.random.normal(0,self.sigma,len(x))
# Making target or labels
y = x + e
# Marging x and e for 2 features
X = np.transpose(np.array([x,e]))
# Converting numpy array to Tensor
self.x_train_tensor = torch.from_numpy(X).float().to(device)
self.y_train_tensor = torch.from_numpy(y).float().to(device)
training_dataset = TensorDataset(self.x_train_tensor,self.y_train_tensor)
self.training_dataset = training_dataset
def setup(self):
data = self.training_dataset
self.train_data,self.val_data = random_split(data,[8,2])
return self.train_data,self.val_data
def train_dataloader(self):
return DataLoader(self.train_data)
def val_dataloader(self):
return DataLoader(self.val_data)
obj = DataModuleClass()
print(obj.setup())
你能告诉我为什么我会收到这个错误吗?
解决方法
从代码对我的看法来看。
self.training_dataset
的变量 DataModuleClass
在 prepare_data
中初始化,setup
在第一行需要它。
但是您拨打了 setup
而没有拨打 training_dataset
。
如果每次创建 prepare_data
对象时都希望调用 DataModuleClass
,那么最好将 prepare_data
放在 __init__
中。喜欢
def __init__(self,other_params):
..... all your code previously in __init__
self.prepare_data() # put this in the last line of this function
但是如果您不需要,那么您需要在 prepare_data
setup
obj = DataModuleClass()
obj.prepare_data()
print(obj.setup())
或者将 prepare_data
放在 setup
本身中。
def setup(self):
self.prepare_data()
data = self.training_dataset
self.train_data,self.val_data = random_split(data,[8,2])
return self.train_data,self.val_data
编辑 1:查看 self.train_data
和 self.val_data
的实际值
从 setup
返回的对象是 torch.utils.data.dataset.Subset
。
基本上有两种方法可以获取张量。
1.像列表一样对待它们
train_data,val_data = obj.setup()
print(train_data[0])
2.使用 for 循环
train_data,val_data = obj.setup()
for data in train_data:
print(data)
注意
我不确定您是否会获得张量或 TensorDataset
。如果是后者,则再次使用相同的技巧,例如
train_data,val_data = obj.setup()
train_tensor_data = train_data[0]
print(train_tensor_data[0])
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。