微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何计算张量流自动编码器生成的批次的每个示例的潜在表示之间的余弦距离?

如何解决如何计算张量流自动编码器生成的批次的每个示例的潜在表示之间的余弦距离?

我想在keras / tensorflow中添加一个自定义损失项,它是由自动编码器生成的当前批处理的所有深度编码/潜在表示之间的平均余弦距离。因此,我在编码末尾添加一个额外的输出,即潜在的表示形式。然后,我要应用计算批次的每个示例与另一个示例之间的余弦距离。如果我的批处理大小为32,则需要将其计算为32 * 32。

我从以下自定义损失函数开始:

def my_loss_fn2_1(y_true,y_pred):
     cosine_loss = tf.keras.losses.Cosinesimilarity(axis=1)
     loss = cosine_loss(y_pred,y_pred)
return loss

其中y_pred是深度为[?,10000]的深度嵌入,其中?是批次尺寸,是因为?我无法复制该批次的示例(32),因此我可以计算该批次之间的距离。此额外的损失项应作为调节器。因此,主要问题是计算彼此之间的距离。我怎样才能做到这一点?使用tf.repeat,tf.tile,tf.gather ...时,在编译模型时遇到问题,因为批处理维度为none /?因此,我无法重复这些示例。还有其他方法解决方案可以做到这一点吗?

谢谢!

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?