微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Tensorflow的AdditiveAttention实现没有权重

如何解决Tensorflow的AdditiveAttention实现没有权重

我试图了解如何在Tensorflow / Keras中实现新的(自定义)图层。 Bahdanau的“添加剂注意”似乎很简单。机制的一部分是这样的:

enter image description here

这是implemented by tensorflow。但是,看一下代码,我似乎找不到tanh函数中应该使用的权重。这使我觉得我对Tensorflow中的图层了解不足。另一个结论是tensorflow在这里没有实现权重。这似乎不太可能。

我想解释一下Tensorflow如何将此机制实现为自定义层。

Tensorflow的AdditiveAttention的子类如下:

 def __init__(self,use_scale=True,**kwargs):
    super(AdditiveAttention,self).__init__(**kwargs)
    self.use_scale = use_scale

  def build(self,input_shape):
    v_shape = tensor_shape.TensorShape(input_shape[1])
    dim = v_shape[-1]
    if isinstance(dim,tensor_shape.Dimension):
      dim = dim.value
    if self.use_scale:
      self.scale = self.add_weight(
          name='scale',shape=[dim],initializer=init_ops.glorot_uniform_initializer(),dtype=self.dtype,trainable=True)
    else:
      self.scale = None
    super(AdditiveAttention,self).build(input_shape)

唯一的权重是self.scale。稍后,它在_calculate_scores(query,key)中与tanh函数一起使用:

math_ops.reduce_sum(scale * math_ops.tanh(q_reshaped + k_reshaped),axis=-1)

如Bahdanau的得分方​​程所示,可训练的权重应乘以查询q_reshaped)和键(k_reshaped)?

q_reshapedk_reshaped内容传递到call()函数中,如下所示:

def call(self,inputs,mask=None,training=None):
    self._validate_call_args(inputs=inputs,mask=mask)
    q = inputs[0]
    v = inputs[1]
    k = inputs[2] if len(inputs) > 2 else v
    q_mask = mask[0] if mask else None
    v_mask = mask[1] if mask else None
    scores = self._calculate_scores(query=q,key=k)
...

权重应在调用call()之后创建。 (call(),调用build())。所以在我看来,查询和键没有加权。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。