如何解决通过Dataloader枚举时出现KeyError-为什么?
我正在编写一个二进制分类模型,该模型由40个参与者的音频文件组成,并根据他们是否患有语音障碍对其进行分类。音频文件已分为5个第二部分,为避免主题偏见,我将训练/测试/验证集划分为一个主题仅出现在一个集合中(即,参与者ID02不在训练和测试集中都出现) 。当我尝试枚举下面代码中的DataLoader validLoader时,出现以下错误,但我不完全确定为什么会发生此错误。有人有什么建议吗?
KeyError Traceback (most recent call last)
<ipython-input-69-55be99283cf7> in <module>()
----> 1 for i,data in enumerate(valid_loader,0):
2 images,labels = data
3 print("Batch",i,"size:",len(images))
3 frames
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
361
362 def __next__(self):
--> 363 data = self._next_data()
364 self._num_yielded += 1
365 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
987 else:
988 del self._task_info[idx]
--> 989 return self._process_data(data)
990
991 def _try_put_index(self):
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self,data)
1012 self._try_put_index()
1013 if isinstance(data,ExceptionWrapper):
-> 1014 data.reraise()
1015 return data
1016
/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
393 # (https://bugs.python.org/issue2651),so we work around it.
394 msg = KeyErrorMessage(msg)
--> 395 raise self.exc_type(msg)
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py",line 185,in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py",line 44,in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py",in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "<ipython-input-44-245be0a1e978>",line 19,in __getitem__
x = Image.open(self.df['path'][index])
File "/usr/local/lib/python3.6/dist-packages/pandas/core/series.py",line 871,in __getitem__
result = self.index.get_value(self,key)
File "/usr/local/lib/python3.6/dist-packages/pandas/core/indexes/base.py",line 4405,in get_value
return self._engine.get_value(s,k,tz=getattr(series.dtype,"tz",None))
File "pandas/_libs/index.pyx",line 80,in pandas._libs.index.IndexEngine.get_value
File "pandas/_libs/index.pyx",line 90,line 138,in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/hashtable_class_helper.pxi",line 998,in pandas._libs.hashtable.Int64HashTable.get_item
File "pandas/_libs/hashtable_class_helper.pxi",line 1005,in pandas._libs.hashtable.Int64HashTable.get_item
KeyError: 36
有人可以告知为什么会这样吗?
from google.colab import drive
drive.mount('/content/drive')
import torch
import torchvision
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import utils
from torch.utils.data import Dataset
from sklearn.metrics import confusion_matrix
from skimage import io,transform,data
from skimage.color import rgb2gray
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import pandas as pd
import numpy as np
import csv
import os
import math
import cv2
root_dir = "/content/drive/My Drive/Read_Text/5_Second_Segments/"
class_names = [
"Parkinsons_Disease","Healthy_Control"
]
def get_meta(root_dir,dirs):
""" Fetches the meta data for all the images and assigns labels.
"""
paths,classes = [],[]
for i,dir_ in enumerate(dirs):
for entry in os.scandir(root_dir + dir_):
if (entry.is_file()):
paths.append(entry.path)
classes.append(i)
return paths,classes
paths,classes = get_meta(root_dir,class_names)
data = {
'path': paths,'class': classes
}
data_df = pd.DataFrame(data,columns=['path','class'])
data_df = data_df.sample(frac=1).reset_index(drop=True) # Shuffles the data
from pandas import option_context
print("Found",len(data_df),"images.")
with option_context('display.max_colwidth',400):
display(data_df.head(100))
class Audio(Dataset):
def __init__(self,df,transform=None):
"""
Args:
image_dir (string): Directory with all the images
df (DataFrame object): Dataframe containing the images,paths and classes
transform (callable,optional): Optional transform to be applied
on a sample.
"""
self.df = df
self.transform = transform
def __len__(self):
return len(self.df)
def __getitem__(self,index):
# Load image from path and get label
x = Image.open(self.df['path'][index])
try:
x = x.convert('RGB') # To deal with some grayscale images in the data
except:
pass
y = torch.tensor(int(self.df['class'][index]))
if self.transform:
x = self.transform(x)
return x,y
def compute_img_mean_std(image_paths):
"""
Author: @xinruizhuang. Computing the mean and std of three channel on the whole dataset,first we should normalize the image from 0-255 to 0-1
"""
img_h,img_w = 224,224
imgs = []
means,stdevs = [],[]
for i in tqdm(range(len(image_paths))):
img = cv2.imread(image_paths[i])
img = cv2.resize(img,(img_h,img_w))
imgs.append(img)
imgs = np.stack(imgs,axis=3)
print(imgs.shape)
imgs = imgs.astype(np.float32) / 255.
for i in range(3):
pixels = imgs[:,:,:].ravel() # resize to one row
means.append(np.mean(pixels))
stdevs.append(np.std(pixels))
means.reverse() # BGR --> RGB
stdevs.reverse()
print("normMean = {}".format(means))
print("normStd = {}".format(stdevs))
return means,stdevs
norm_mean,norm_std = compute_img_mean_std(paths)
data_transform = transforms.Compose([
transforms.Resize(256),transforms.CenterCrop(256),transforms.ToTensor(),transforms.Normalize(norm_mean,norm_std),])
unique_users = data_df['path'].str[-20:-16].unique()
train_users,test_users = np.split(np.random.permutation(unique_users),[int(0.8*len(unique_users))])
df_train = data_df[data_df['path'].str[-20:-16].isin(train_users)]
test_data_df = data_df[data_df['path'].str[-20:-16].isin(test_users)]
train_unique_users = df_train['path'].str[-20:-16].unique()
train_users,validate_users = np.split(np.random.permutation(train_unique_users),[int(0.875*len(train_unique_users))])
train_data_df = df_train[df_train['path'].str[-20:-16].isin(train_users)]
valid_data_df = df_train[df_train['path'].str[-20:-16].isin(validate_users)]
ins_dataset_train = Audio(
df=train_data_df,transform=data_transform,)
ins_dataset_valid = Audio(
df=valid_data_df,)
ins_dataset_test = Audio(
df=test_data_df,)
train_loader = torch.utils.data.DataLoader(
ins_dataset_train,batch_size=8,shuffle=True,num_workers=2
)
test_loader = torch.utils.data.DataLoader(
ins_dataset_test,batch_size=16,num_workers=2
)
valid_loader = torch.utils.data.DataLoader(
ins_dataset_valid,num_workers=2
)
//(This is where the error is occurring.)
for i,0):
images,labels = data
print("Batch",len(images))
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。