微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

通过嵌入层向后传播

如何解决通过嵌入层向后传播

我有一个句子数据集,属于2个类别。我已经根据该数据训练了一个分类器,以将给定的句子分类为两个类别之一。接下来,我想训练一个 Pytorch 生成器(编码器-解码器)模型,以使用预先训练的分类器将给定的句子从1类转换为2类。因此,基本上,这是NLP中的样式转换。这是我模型的骨架:

Encoder:
  Embedding layer --> LSTM [outputs= o,h]

Decoder:
  Embedding layer --> LSTM --> Linear --> Relu --> log_softmax [output= probability for each word in vocab]

Classifier:
  Encoder --> Linear layer1 --> Linear2 --> sigmoid [output = class probability]

Generator:
  Encoder --> Decoder --> topK(1) [outputs = token for each word in generated sentence as floats]

我想做的是,使用来自预训练分类器模型的错误信号训练生成器。此模型是否有效是一个单独的问题(我当然很想听听这里经验丰富的成员的反馈)。但是,这里的主要问题是Generator会以浮点数的形式返回句子标记(单词)的数组,然后应将其传递到一个(冻结的)分类器模型,该模型包含嵌入层作为第一层,该层仅接受Long数据类型。但是将float转换为长张量会破坏梯度历史。根据我对类似嵌入相关问题的了解,通过类型转换无法保留渐变。那我有什么选择呢?任何解决方法?例如,精明的读者会注意到,即使“ topK / argmax”运算也会破坏渐变历史记录,为此,我计划在训练时使用线性层来查找argmax。对于“嵌入”问题也有类似的解决方案吗?我很确定人们会尝试类似的方法,但是我在seq2seq +分类器上找不到任何资源。

注意:我不会发布代码来保持帖子整洁,以便更好地理解。如果需要,我可以提供相关部分。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?