微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

每次调用函数都会弹出KeyError

如何解决每次调用函数都会弹出KeyError

您好,我正在尝试使用 Omniglot 数据集实现 Siamese 神经网络以进行一次性图像识别。实现的初始步骤需要生成具有相同/不同类的配对样本,为此我使用了 Ben Myara's github 中的 ma​​ke_pair 函数,并稍作修改。但是每次调用函数都会弹出keyError,所以想知道是什么原因导致这个错误,下面是我的实现:

import requests
import io
def load_numpy_arr_from_url(url):
"""
Loads a numpy array from surfdrive. 

Input:
url: Download link of dataset 

Outputs:
dataset: numpy array with input features or labels
"""

response = requests.get(url)
response.raise_for_status()

return np.load(io.BytesIO(response.content)) 



# Downloading may take a while..
train_x    =load_numpy_arr_from_url('https://surfdrive.surf.nl/files/index.PHP/s/tvQmLyY7MhVsADb/download')
#Transform bool type to integer
train_data = train_x* 1
train_y = load_numpy_arr_from_url('https://surfdrive.surf.nl/files/index.PHP/s/z234AHrQqx9RVGH/download')

import torch
def make_pairs(data,labels,num=1000):
    digits = {}
    for i,j in enumerate(labels):
        if not j in digits:
            digits[j] = []
        digits[j].append(i)

    pairs,labels_ = [],[]
    for i in range(num):
        if np.random.rand() >= .5: # same digit
            digit = random.choice(range(len(labels+1)))
            d1,d2 = random.choice(digits[digit],size=2,replace=False)
            labels_.append(1)
        else:
            digit1,digit2 = np.random.choice(range(len(labels+1)),replace=False)
            d1,d2 = random.choice(digits[digit1]),np.random.choice(digits[digit2])
            labels_.append(0)
        pairs.append(torch.from_numpy(np.concatenate([data[d1,:],data[d2,:]])).view(1,56,28))
  

  
    return torch.cat(pairs),torch.LongTensor(labels_)

当我尝试使用以下命令调用函数时发生错误

make_pairs(train_data,train_y,5)

这是我得到的回溯错误

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-30-7d53181e46ef> in <module>()
     25 
     26     return torch.cat(pairs),torch.LongTensor(labels_)
---> 27 make_pairs(train_data,5)
     28 #print(a)

<ipython-input-30-7d53181e46ef> in make_pairs(data,num)
     14         if np.random.rand() >= .5: # same digit
     15             digit = random.choice(range(len(labels+1)))
---> 16             print(random.choice(digits[digit],replace=False))
     17             d1,replace=False)
     18             labels_.append(1)

KeyError: 12803

此外,我还尝试在没有 for 循环的情况下实现部分功能,并且一切似乎都在那里正常工作:

import numpy as np
digits = {}
for i,j in enumerate(train_y):
    if not j in digits:
        digits[j] = []
    digits[j].append(i)
pairs,[]
digit = np.random.choice(range(len(train_y)+1)
d1,d2 = np.random.choice(digits[digit],replace=False)
labels_.append(1)
print(torch.from_numpy(np.concatenate([train_data[d1,train_data[d2,28))

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。