微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

为什么在SpaCy文本分类中得到相同的预测? 环境火车数据初始化培训预测

如何解决为什么在SpaCy文本分类中得到相同的预测? 环境火车数据初始化培训预测

我正在Google Colab Pro上工作,这是第一次使用SpaCy进行文本分类转换,并将其与我的Tensorflow模型进行比较。我有10个标签或类别:'fear','positive','disgust','anticipation','anger','sadness','joy','trust','surprise' and 'negative'。火车数据语法如下:

TRAIN_DATA = [('word',{'cats': {'label_1': 0,'label_2': 1,... }}),...]

但是,当我尝试进行预测时,无论要预测什么,我总是得到相同的类别和准确性。

环境

NVIDIA-SMI 450.57 | Driver Version: 418.67 | CUDA Version: 10.1  
Runtime 27.39 GB of available RAM

火车数据

[...
('morals',{'cats': {'anger': 0,'anticipation': 0,'disgust': 0,'fear': 0,'joy': 0,'negative': 1,'positive': 0,'sadness': 0,'surprise': 0,'trust': 0}}),('moral','trust': 0}})
...]

初始化

并将标签类别添加到文本分类

nlp = spacy.load("en_core_web_sm")
category = nlp.create_pipe("textcat",config={"exclusive_classes": True})
nlp.add_pipe(category)

category.add_label("trust")
category.add_label("fear")
category.add_label("disgust")
category.add_label("surprise")
category.add_label("anticipation")
category.add_label("anger")
category.add_label("joy")

培训

# get names of other pipes to disable them during training
n_iter = 10
pipe_exceptions = ["textcat","trf_wordpiecer","trf_tok2vec"]
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
with nlp.disable_pipes(*other_pipes):  # only train textcat
    optimizer = nlp.begin_training()
    
    for i in tqdm(range(n_iter)):
        losses = {}
        random.shuffle(TRAIN_DATA)
    
        for batch in tqdm(minibatch(TRAIN_DATA,size=50)):
            #texts = [nlp(text) for text,entities in batch]
            texts,annotations = zip(*batch)
            #annotations = [{"cats": entities} for text,entities in batch]
            try:
                nlp.update(texts,annotations,sgd=optimizer,losses=losses)
            except:
                pass
        print('\n{}. Losses: {}'.format(i,losses))

100%|██████████| 10/10 [09:26<00:00,56.69s/it]
9. Losses: {'textcat': 0.2062857248383807}

预测

def spacy_prediction(text):
    doc = nlp(u'{}'.format(text))  
    result = doc.cats
    index = np.argmax(result.values())
    res_value = list(result.values())[index] * 100
    res_label = list(result)[index]
    print('Prediction: {},Value: {} %'.format(res_label,round(res_value)))

spacy_prediction('joy')

output: Prediction: trust,Value: 14 %
expected output: Prediction: joy,Value: x%

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。