如何解决当我使用BERT预测所有令牌时自我索引超出范围
def load_predict_BERT(self,masked_text):
"""
Look for the [MASK] tokens and then attempts to predict the original value of the masked words
:param masked_text: str
Text containing [MASK] tokens for each word to predict
:return: predictions:
:return: MASKIDS: list
"""
# Load,train and predict using pre-trained model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenized_text = tokenizer.tokenize(masked_text)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
MASKIDS = [i for i,e in enumerate(tokenized_text) if e == '[MASK]']
# Create the segments tensors
segs = [i for i,e in enumerate(tokenized_text) if e == "."]
segments_ids = []
prev = -1
for k,s in enumerate(segs):
segments_ids = segments_ids + [k] * (s - prev)
prev = s
segments_ids = segments_ids + [len(segs)] * (len(tokenized_text) - len(segments_ids))
segments_tensors = torch.tensor([segments_ids])
# Prepare Torch inputs
tokens_tensor = torch.tensor([indexed_tokens])
# Load pre-trained model
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.resize_token_embeddings(len(tokenizer))
# Predict all tokens
with torch.no_grad():
predictions = model(tokens_tensor,segments_tensors)
return predictions,MASKIDS,tokenizer
我收到此错误,但无法修复。我是这种方法的初学者,但我希望有人可以帮助我。
Traceback (most recent call last):
File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/IPython/core/interactiveshell.py",line 3417,in run_code
exec(code_obj,self.user_global_ns,self.user_ns)
File "<ipython-input-9-c1e97a1e3497>",line 3,in <module>
text = doc.extract_text(lang='la_Lat')
File "<ipython-input-2-669f80886eff>",line 79,in extract_text
prediction,tokenizer = self.load_predict_BERT(masked_text=masked_text)
File "<ipython-input-2-669f80886eff>",line 170,in load_predict_BERT
predictions = model(tokens_tensor,segments_tensors)
File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/module.py",line 722,in _call_impl
result = self.forward(*input,**kwargs)
File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/pytorch_pretrained_bert/modeling.py",line 861,in forward
sequence_output,_ = self.bert(input_ids,token_type_ids,attention_mask,File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/module.py",line 730,in forward
embedding_output = self.embeddings(input_ids,token_type_ids)
File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/module.py",line 269,in forward
token_type_embeddings = self.token_type_embeddings(token_type_ids)
File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/module.py",**kwargs)
File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/sparse.py",line 124,in forward
return F.embedding(
File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/functional.py",line 1814,in embedding
return torch.embedding(weight,input,padding_idx,scale_grad_by_freq,sparse)
IndexError: index out of range in self
我知道令牌的大小有问题,但我不知道为什么。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。