如何解决Trax的AttentionQKV
由Trax实现的AttentionQKV层如下:AttentionQKV
def AttentionQKV(d_feature,n_heads=1,dropout=0.0,mode='train'):
"""Returns a layer that maps (q,k,v,mask) to (activations,mask).
See `Attention` above for further context/details.
Args:
d_feature: Depth/dimensionality of feature embedding.
n_heads: Number of attention heads.
dropout: Probababilistic rate for internal dropout applied to attention
activations (based on query-key pairs) before dotting them with values.
mode: One of `'train'`,`'eval'`,or `'predict'`.
"""
return cb.Serial(
cb.Parallel(
core.Dense(d_feature),core.Dense(d_feature),),PureAttention( # pylint: disable=no-value-for-parameter
n_heads=n_heads,dropout=dropout,mode=mode),)
特别是,三个平行的密集层的目的是什么?该层的输入是q,k,v,掩码。为什么q,k,v穿过一个密集层?
解决方法
此代码段是Attention is all you need论文第5页顶部的等式的实现,该等式在2017年推出了Transformer模型。计算如图2所示:
隐藏状态投射到 h 注意头中,这些注意头并行执行缩放的点积注意。投影可以解释为与头部相关的信息的提取。然后,每个负责人都根据不同的(学习的)标准进行概率检索。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。