微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

3 维张量上的串联张量重塑

如何解决3 维张量上的串联张量重塑

我有 2 个张量,

它们目前的格式分别是 [13,2]。我试图将两者组合成一个维度为 [2,13,2] 的 3 维张量,以便它们堆叠在彼此的顶部,但是作为批次分开。

这里是一个格式为 [13,2] 的张量的例子:

tensor([[[-1.8588,0.3776],[ 0.1683,0.2457],[-1.2740,0.5683],[-1.7262,0.4350],[-1.0160,0.5940],[-1.3354,0.5565],[-0.7497,0.5792],[-0.2024,0.4251],[ 1.0791,-0.2770],[ 0.3032,0.1706],[ 0.8681,-0.1607]])

我想保持形状,但将它们分成两组在同一个张量中。以下是我所追求的格式示例:

tensor([[[-1.8588,-0.1607]],[[-1.8588,-0.1607]]])

有人对如何使用串联来做到这一点有任何想法吗?我曾尝试在使用 torch.cat((a,b.unsqueeze(0)),dim=-1) 时使用 .unsqueeze,但是它将格式更改为 [13,4,1] 这不是我所追求的格式。

下面的解决方案有效,但是,我的想法是我将通过循环继续堆叠到 y ,而不受形状的限制。很抱歉没有足够清楚地表达我的想法。

它们的大小都是 [13,2] 所以它会以 [1,2],[2,[3,[4,2] 等等...

解决方法

在这种情况下,您需要使用 torch.stack 而不是 torch.cat,至少它更方便:

x1 = torch.randn(13,2)
x2 = torch.randn(13,2)
y = torch.stack([x1,x2],0) # creates a new dimension 0
print(y.shape)
>>> (2,13,2)

您确实可以使用 unsqueezecat,但是您需要解压两个输入张量:

x1 = torch.randn(13,2).unsqueeze(0) # shape: (1,2).unsqueeze(0) # same
y = torch.cat([x1,0)
print(y.shape)
>>> (2,2)

这里有一个有用的线程来理解差异:difference between cat and stack

如果你需要堆叠更多张量,其实也不是很难,stack 可以处理任意数量的张量:

# This list of tensors is what you will build in your loop
tensors = [torch.randn(13,2) for i in range(10)]
# Then at the end of the loop,you stack them all together
y = torch.stack(tensors,0)
print(y.shape)
>>> (10,2)

或者,如果您不想使用列表:

# first,build the y tensor to which the other ones will be appended
y = torch.empty(0,2)
# Then the loop,and don't forget to unsqueeze
for i in range(10):
    x = torch.randn(13,2).unsqueeze(0)
    y = torch.cat([y,x],0)

print(y.shape)
>>> (10,2)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。