IndexError "index out of range in self" when I use BERT to predict all tokens

When I try to predict all the tokens in a text with BERT, using the following function:


    import torch
    from pytorch_pretrained_bert import BertForMaskedLM, BertTokenizer


    def load_predict_BERT(self, masked_text):
        """
        Look for the [MASK] tokens and then attempt to predict the original value of the masked words

        :param masked_text: str
            Text containing a [MASK] token for each word to predict
        :return: predictions: model output (masked-LM prediction scores)
        :return: MASKIDS: list
            Positions of the [MASK] tokens in the tokenized text
        :return: tokenizer: BertTokenizer
        """
        # Load the pre-trained tokenizer and convert the text to token ids
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        tokenized_text = tokenizer.tokenize(masked_text)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        MASKIDS = [i for i, e in enumerate(tokenized_text) if e == '[MASK]']
        # Create the segment ids: the k-th sentence (split on ".") gets id k
        segs = [i for i, e in enumerate(tokenized_text) if e == "."]
        segments_ids = []
        prev = -1
        for k, s in enumerate(segs):
            segments_ids = segments_ids + [k] * (s - prev)
            prev = s
        # Tokens after the last "." get id len(segs)
        segments_ids = segments_ids + [len(segs)] * (len(tokenized_text) - len(segments_ids))
        segments_tensors = torch.tensor([segments_ids])
        # Prepare Torch inputs
        tokens_tensor = torch.tensor([indexed_tokens])
        # Load pre-trained model
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')
        model.resize_token_embeddings(len(tokenizer))
        # Predict all tokens
        with torch.no_grad():
            predictions = model(tokens_tensor, segments_tensors)
        return predictions, MASKIDS, tokenizer
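The segment-id loop in the function can be run on its own to see what it produces. Below, it is copied into a standalone function so its output can be inspected without loading BERT. The name `segment_ids_from_periods` is hypothetical, and plain strings stand in for WordPiece tokens:

```python
def segment_ids_from_periods(tokens):
    """Same segment-id loop as in load_predict_BERT, extracted for inspection.

    Every '.' token closes a sentence; the k-th sentence gets segment id k.
    """
    segs = [i for i, e in enumerate(tokens) if e == "."]
    segments_ids = []
    prev = -1
    for k, s in enumerate(segs):
        # tokens prev+1 .. s (inclusive) belong to sentence k
        segments_ids = segments_ids + [k] * (s - prev)
        prev = s
    # any tokens after the last '.' get the next id, len(segs)
    segments_ids = segments_ids + [len(segs)] * (len(tokens) - len(segments_ids))
    return segments_ids


tokens = ["first", ".", "second", ".", "third"]
print(segment_ids_from_periods(tokens))  # [0, 0, 1, 1, 2]
```

Each "." starts a new segment id. A text with three sentences therefore yields an id of 2, and the largest id grows by one with every additional sentence.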

I get this error and can't fix it. I'm a beginner with this approach, but I hope someone can help me.

Traceback (most recent call last):
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3417, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-9-c1e97a1e3497>", line 3, in <module>
    text = doc.extract_text(lang='la_Lat')
  File "<ipython-input-2-669f80886eff>", line 79, in extract_text
    prediction, tokenizer = self.load_predict_BERT(masked_text=masked_text)
  File "<ipython-input-2-669f80886eff>", line 170, in load_predict_BERT
    predictions = model(tokens_tensor, segments_tensors)
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/pytorch_pretrained_bert/modeling.py", line 861, in forward
    sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 730, in forward
    embedding_output = self.embeddings(input_ids, token_type_ids)
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 269, in forward
    token_type_embeddings = self.token_type_embeddings(token_type_ids)
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/module.py", **kwargs)
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 124, in forward
    return F.embedding(
  File "/Users/AppleDreeaMz/opt/anaconda3/envs/OCRNLP_env/lib/python3.8/site-packages/torch/nn/functional.py", line 1814, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self

I know the problem has something to do with the size of the tokens, but I don't understand why.
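One likely reading of the traceback is that the failing lookup is `self.token_type_embeddings(token_type_ids)`, which involves the segment ids, not the token ids. In `bert-base-uncased` the token-type embedding table has only 2 rows (`type_vocab_size = 2` in its config). Any segment id of 2 or more indexes past the end of that table and raises `IndexError: index out of range in self`. The same message can also come from the position embeddings if the tokenized text is longer than 512 tokens.

The following is a minimal, torch-free sketch, assuming the `bert-base-uncased` sizes (vocabulary 30522, 2 token types, 512 positions). It checks the inputs against those table sizes and shows one simple fix, putting every token in segment 0. The helper names `find_input_problems` and `single_segment_ids` are hypothetical, and the token ids are arbitrary:

```python
# Assumed bert-base-uncased config values:
# vocab_size=30522, type_vocab_size=2, max_position_embeddings=512.
VOCAB_SIZE = 30522
TYPE_VOCAB_SIZE = 2
MAX_POSITIONS = 512


def find_input_problems(input_ids, token_type_ids,
                        vocab_size=VOCAB_SIZE,
                        type_vocab_size=TYPE_VOCAB_SIZE,
                        max_positions=MAX_POSITIONS):
    """Return reasons why the inputs would overflow one of BERT's embedding tables.

    Each check mirrors one embedding lookup: positions, tokens, token types.
    """
    problems = []
    if len(input_ids) > max_positions:
        problems.append(f"sequence length {len(input_ids)} > {max_positions} positions")
    if input_ids and max(input_ids) >= vocab_size:
        problems.append(f"token id {max(input_ids)} >= vocab size {vocab_size}")
    if token_type_ids and max(token_type_ids) >= type_vocab_size:
        problems.append(f"segment id {max(token_type_ids)} >= type_vocab_size {type_vocab_size}")
    return problems


def single_segment_ids(n_tokens):
    """Hypothetical fix: treat the whole text as one segment (id 0)."""
    return [0] * n_tokens


ids = [7, 8, 9, 10, 11, 12, 13]          # arbitrary token ids, for illustration only
bad_segments = [0, 0, 0, 1, 1, 2, 2]     # a three-sentence split, as the "." loop produces
print(find_input_problems(ids, bad_segments))  # ['segment id 2 >= type_vocab_size 2']
print(find_input_problems(ids, single_segment_ids(len(ids))))  # []
```

In the function from the question, this would amount to building `segments_ids = [0] * len(tokenized_text)`. Alternatively, you could call `model(tokens_tensor)` without a segment tensor, because both `pytorch_pretrained_bert` and `transformers` fall back to all-zero token types when `token_type_ids` is omitted. For this checkpoint, sentence-pair segments can use at most the two ids 0 and 1. Texts longer than 512 tokens would additionally need truncating or splitting into chunks.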

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。