微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

无法理解pytorch张量广播

如何解决无法理解pytorch张量广播

我有以下代码

import torch
d = 2
n = 50
X = torch.randn(n,d)
z = torch.tensor([[-1.0],[2.0]])
y = X @ z
X.size()
z.size()
y.size()

输出为:

torch.Size([50,2])
torch.Size([2,1])
torch.Size([50,1])

我的问题是,为什么广播后结果y的大小是[50,1]而不是[50,2],我认为应该是[50,2],对吗?

解决方法

import textwrap def wrap(string,max_width): return textwrap.fill(string,max_width) if __name__ == '__main__': string,max_width = input(),int(input()) result = wrap(string,max_width) print(result) 不是广播而是乘法。

在python 3.5中,为矩阵引入了@运算符 PEP465之后的乘法。这例如被实现在@中 作为numpy运算符。

所以matmul的大小很好。

将大小为y的矩阵与大小为[50,2]的向量相乘,将输出大小为[2,1]的向量。

一个更清楚地显示它的示例是:

[50,1]

如您所见,第三个输出确实只是两个张量的乘积。

如果您想进行广播,我建议您参考https://medium.com/ai%C2%B3-theory-practice-business/understanding-broadcasting-in-pytorch-ca9e9533f05f

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。