微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

pytorch nn.EmbeddingBag 中的偏移量是什么意思?

如何解决pytorch nn.EmbeddingBag 中的偏移量是什么意思?

我知道偏移量在有两个数字时是什么意思,但是当超过两个数字时是什么意思,例如:

weight = torch.FloatTensor([[1,2,3],[4,5,6]])
embedding_sum = nn.EmbeddingBag.from_pretrained(weight,mode='sum')
print(list(embedding_sum.parameters()))
input = torch.LongTensor([0,1])
offsets = torch.LongTensor([0,1,1])

print(embedding_sum(input,offsets))

结果是:

[Parameter containing:
tensor([[1.,2.,3.],[4.,5.,6.]])]
tensor([[1.,6.],[0.,0.,0.],0.]])

谁能帮帮我?

解决方法

source code中所示,

return F.embedding(
    input,self.weight,self.padding_idx,self.max_norm,self.norm_type,self.scale_grad_by_freq,self.sparse) 

它使用了 functional embedding bag,它将 offsets 参数解释为

offsets(LongTensor,可选)– 仅在输入为 1D 时使用。 offsets 确定输入中每个包(序列)的起始索引位置。

EmbeddingBag docs 中:

如果输入是形状 (N) 的一维,它将被视为多个袋子(序列)的串联。 offsets 必须是一个一维张量,其中包含输入中每个包的起始索引位置。 因此,对于形状 (B) 的偏移量,输入将被视为具有 B 个袋子。 空袋子(即长度为 0)将返回由零填充的向量。

最后一条语句(“空包(即长度为 0)将返回由零填充的向量。”) 解释了结果张量中的零向量。

,
import torch
import torch.nn as nn

weight = torch.FloatTensor([[1,2,3],[4,5,6]])
embedding_sum = nn.EmbeddingBag.from_pretrained(weight,mode='sum')
print(embedding_sum.weight)

""" output
Parameter containing:
tensor([[1.,2.,3.],[4.,5.,6.]])
"""
input = torch.LongTensor([0,1])
offsets = torch.LongTensor([0,1,1])

根据这些偏移量,您将获得以下样本

"""
sample_1: input[0:1] # tensor([0])
sample_2: input[1:2] # tensor([1])
sample_3: input[2:1] # tensor([])
sample_4: input[1:]  # tensor([1])
"""

嵌入上面的示例

# tensor([0]) => lookup 0  => embedding_sum.weight[0] => [1.,3.]
# tensor([1]) => lookup 1  => embedding_sum.weight[1] => [4.,6.]
# tensor([])  => empty bag                            => [0.,0.,0.]
# tensor([1]) => lookup 1  => embedding_sum.weight[1] => [4.,6.]

print(embedding_sum(input,offsets))

""" output
tensor([[1.,6.],[0.,0.],6.]])
"""

再举一个例子:

input = torch.LongTensor([0,0])

根据这些偏移量,您将获得以下样本

"""
sample_1: input[0:1] # tensor([0])
sample_2: input[1:0] # tensor([])
sample_3: input[0:]  # tensor([0,1])
"""

嵌入上面的示例

# tensor([0])    => lookup 0 => embedding_sum.weight[0] => [1.,3.]
# tensor([])     => empty bag => [0.,0.]
# tensor([0,1]) => lookup 0 and 1 then reduce by sum 
#                => embedding_sum.weight[0] + embedding_sum.weight[1] => [5.,7.,9.]

print(embedding_sum(input,[5.,9.]])
"""

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。