如何解决为什么嵌入维度必须可以被 MultiheadAttention 中的头数整除?
我正在学习 Transformer。这是 MultiheadAttention 的 pytorch 文档。在他们的 implementation 中,我看到有一个约束:
assert self.head_dim * num_heads == self.embed_dim,"embed_dim must be divisible by num_heads"
为什么需要约束:embed_dim must be divisible by num_heads?
如果我们回到等式
假设:
Q
、K
、V
是 n x emded_dim
矩阵;所有的权重矩阵 W
都是 emded_dim x head_dim
,
那么,concat [head_i,...,head_h]
将是一个 n x (num_heads*head_dim)
矩阵;
W^O
大小为 (num_heads*head_dim) x embed_dim
[head_i,head_h] * W^O
将成为 n x embed_dim
输出
我不知道为什么我们需要 embed_dim must be divisible by num_heads
。
假设我们有 num_heads=10000
,结果是一样的,因为矩阵-矩阵乘积会吸收这些信息。
解决方法
当您有一个 seq_len x emb_dim
序列(即 20 x 8
)并且您想使用 num_heads=2
时,该序列将沿 emb_dim
维度拆分。因此,您会得到两个 20 x 4
序列。您希望每个头部都具有相同的形状,如果 emb_dim
不能被 num_heads
整除,这将不起作用。以序列 20 x 9
和 num_heads=2
为例。然后你会得到 20 x 4
和 20 x 5
,它们不是同一个维度。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。