微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

为什么嵌入维度必须可以被 MultiheadAttention 中的头数整除?

如何解决为什么嵌入维度必须可以被 MultiheadAttention 中的头数整除?

我正在学习 Transformer。这是 MultiheadAttention 的 pytorch 文档。在他们的 implementation 中,我看到有一个约束:

 assert self.head_dim * num_heads == self.embed_dim,"embed_dim must be divisible by num_heads"

为什么需要约束:embed_dim must be divisible by num_heads? 如果我们回到等式

MultiHead(Q,K,V)=Concat(head1​,…,headh​)WOwhereheadi​=Attention(QWiQ​,KWiK​,VWiV​)

假设: QKVn x emded_dim 矩阵;所有的权重矩阵 W 都是 emded_dim x head_dim,

那么,concat [head_i,...,head_h] 将是一个 n x (num_heads*head_dim) 矩阵;

W^O 大小为 (num_heads*head_dim) x embed_dim

[head_i,head_h] * W^O 将成为 n x embed_dim 输出

我不知道为什么我们需要 embed_dim must be divisible by num_heads

假设我们有 num_heads=10000,结果是一样的,因为矩阵-矩阵乘积会吸收这些信息。

解决方法

当您有一个 seq_len x emb_dim 序列(即 20 x 8)并且您想使用 num_heads=2 时,该序列将沿 emb_dim 维度拆分。因此,您会得到两个 20 x 4 序列。您希望每个头部都具有相同的形状,如果 emb_dim 不能被 num_heads 整除,这将不起作用。以序列 20 x 9num_heads=2 为例。然后你会得到 20 x 420 x 5,它们不是同一个维度。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。