微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

从具有2D索引张量的3D张量检索元素

如何解决从具有2D索引张量的3D张量检索元素

我正在玩GPT2,我有2个张量:

O :形状为(B,S-1,V)的输出张量,其中B是批处理大小,S是时间步数,V是词汇量。这是生成模型的输出,并且在第二维上被软最大化。

L :2D张量形状(B,S-1),其中每个元素是每个样本的每个时间步的正确标记的索引。这基本上是标签

我想基于张量 L 从张量 O 提取相应正确令牌的预测概率,以使最终得到2D张量形状(B, S)。除了使用循环之外,还有一种有效的方法吗?

解决方法

作为参考,我的回答基于this Medium article
本质上,假设两个张量都只是规则的torch.gather(或可以转换为一个),您的答案就位于torch.Tensor中。

import torch

# Specify some arbitrary dimensions for now
B = 3
V = 6
S = 4

# Make example reproducible
torch.manual_seed(42)

# L necessarily has to be a torch.LongTensor,otherwise indexing will fail.
L = torch.randint(0,V,size=[B,S])

O = torch.rand([B,S,V])

# Now collect the results. L needs to have similar dimension,# except in the axis you want to collect along.
X = torch.gather(O,dim=2,index=L.unsqueeze(dim=2))

# Make sure X has no "unnecessary" dimension
X = X.squeeze(dim=2)

很难看出这是否产生了正确的正确结果,这就是为什么我包含了一个随机种子,可以使示例确定结果,并且您可以轻松地验证它是否获得了所需的结果。但是,为澄清起见,也可以使用一维较低的张量,对此,torch.gather的作用将变得更加清楚。

请注意,torch.gather还允许您从理论上索引同一行中的多个索引。意思是,如果取而代之的是一个具有多个正确值的多类示例,则可以类似地使用形状为L的张量[B,number_of_correct_samples]

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。