微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Tensorflow 梯度反向传播与 tf.repeat

如何解决Tensorflow 梯度反向传播与 tf.repeat

简介

我正在尝试在 TensorFlow 中实现 Set Transformer architecture

研究人员提出的模块之一是具有多头注意力的池化 (PMA) 模块。如果我的理解是正确的,这个模块使用一个可训练种子向量。该向量用作多头注意力模块的查询,键和值是要汇集的向量设置

authors implementation of this module(使用 PyTorch)中,我们可以看到这个向量在前向传递期间重复以满足批次维度:

def forward(self,X):
        return self.mab(self.S.repeat(X.size(0),1,1),X)

因此,我使用 TensorFlow(作为 tf.keras.layers.Layer)实现的前向传递是:

def call(self,inputs,**kwargs):
    q = inputs
    # !!! Is the repeat operation allowing the back-propagation of gradient ?
    s = tf.expand_dims(self.seed_vector,axis=0)
    s = tf.repeat(s,tf.shape(q)[0],axis=0)
    return self.mab((s,q))

self.mab 在两种实现中都是论文中定义的多头注意力块)

在我的实现中,self.seed_vector添加到层的可训练权重:

def build(self,input_shape):
    # [...]
    self.seed_vector = self.add_weight(
        shape=(1,input_shape[2]),initializer="random_normal",dtype=tf.float32,trainable=True
    )

我还没有找到一种更直接的方法来执行种子向量重复以满足批量维度约束,但代码可以编译并且模型能够被训练。

问题

我的问题是,当我使用 PMA 模块而不是简单的平均池化层时,我注意到在训练过程中模型训练/验证损失非常不稳定。因此,我想知道我对 PMA 的实施是否正确。更具体地说,由于必须使用模型权重来学习种子向量,我很好奇 tf.repeat 操作对与此可训练向量权重相关的误差梯度计算的影响. TensorFlow 在计算误差梯度时可以“取消批处理”这个向量并在反向传递中正确更新向量吗?

我在作者的 PyTorch 实现的训练代码中没有注意到与此部分相关的任何内容。但是 PyTorch 梯度计算和反向传播可能与 TensorFlow 中实现的不同,我在 TensorFlow 文档中找不到与 tf.repeat 操作的这种潜在影响相关的任何内容。如果有人对这件事有想法,我将不胜感激!

注意事项

我知道很多因素都可能导致意想不到的结果,当然,我正在调查这些因素。但我对模型的这一特定部分非常怀疑,因为它的引入似乎严重阻碍了训练行为。如果实现没问题,原因可能只是这个模块的使用对我的任务/数据来说不是一个好主意。我不是在寻求关于整个模型实现的帮助。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。