如何解决Flair的SequenceTagger
我正在研究从Flair NLP library继承FlairEmbedding
的自定义类。在此类中,我想使用PyTorch的{{1}}模块来实现模型量化。为此,我需要分几批训练模型以收集统计信息并选择适当的量化参数。该模型将在下游的序列标记器中使用,因此我使用的Flair的torch.quantization
类的参数与在下游任务中使用的参数相同。这是该类的样子:
SequenceTagger
此代码失败,并出现以下错误:
class CustomEmbeddings(FlairEmbeddings):
def __init__(
self,tag_dictionary,tag_type,corpus,mini_batch_size,train_with_dev,# Used for training
model,fine_tune,chars_per_chunk,with_whitespace,tokenized_lm # Base FlairEmbeddings arguments
):
super().__init__(model,tokenized_lm)
self.lm.qconfig = torch.quantization.default_config
torch.quantization.prepare(self.lm.qconfig,inplace=True)
# Small training to gather statistics
tagger = SequenceTagger(hidden_size=256,embeddings=self,tag_dictionary=tag_dictionary,tag_type=tag_type)
trainer = ModelTrainer(tagger,corpus)
------> trainer.train('model',mini_batch_size=mini_batch_size,max_epochs=10,train_with_dev=train_with_dev)
torch.quantization.convert(self.lm,inplace=True)
我不确定问题是否出在我的代码中,还是应该报告给PyTorch或Flair的问题跟踪器。 stacktrace使我认为是这两个库之间的交互失败了,而不是我的代码出现了,特别是因为PyTorch的量化模块仍处于beta中,但我可能会误解。任何有关可能出现的错误的输入将不胜感激。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。