理解 numpy np.tensordot

如何解决理解 numpy np.tensordot

arr1 = np.arange(8).reshape(4,2)
arr2 = np.arange(4,12).reshape(2,4)
ans=np.tensordot(arr1,arr2,axes=([1],[0]))
ans2=np.tensordot(arr1,axes=([0],[1]))
ans3 = np.tensordot(arr1,axes=([1,0],[0,1]))

我想了解这个 tensordot 函数是如何工作的。我知道它返回 tensordot 产品。

但是轴部分对我来说有点难以理解。我观察到的

对于 ans,它就像数组 arr1 中的列数和 arr2 中的行数构成了最终矩阵。

对于 ans2,它是 arr2 中的列数和 arr1 中的行数的另一种方式

我不明白轴=([1,1])。让我知道我对 ans 和 ans2 的理解是否正确

解决方法

据我从 tensordot 文档中了解到,您提供的是 ans、ans2 和 ans3 中的轴列表(ans 和 ans2 在列表中只有一个元素)。然后,此列表指定要对哪些轴求和。您对 ans 和 ans2 的假设是正确的,其中在 ans 中,您的第一个元素是 arr1 的 0 轴(arr1 中的行)和 arr2 中的 1 轴(arr2 中的列)。我不完全确定对 ans3 有什么期望,但我可能会尝试自己运行一些示例并查看一下。我希望这能让你更好地理解

链接:https://numpy.org/doc/stable/reference/generated/numpy.tensordot.html

,

你忘记显示数组了:

In [87]: arr1
Out[87]: 
array([[0,1],[2,3],[4,5],[6,7]])
In [88]: arr2
Out[88]: 
array([[ 4,5,6,7],[ 8,9,10,11]])
In [89]: ans
Out[89]: 
array([[  8,11],[ 32,37,42,47],[ 56,65,74,83],[ 80,93,106,119]])
In [90]: ans2
Out[90]: 
array([[ 76,124],[ 98,162]])
In [91]: ans3
Out[91]: array(238)

ans 只是普通的点矩阵乘积:

In [92]: np.dot(arr1,arr2)
Out[92]: 
array([[  8,119]])

dot 乘积总和在 ([1],[0])arr1 轴 1 和 arr2 的轴 0 上执行(常规跨列,向下行)。使用 2d 'sum cross ...' 短语可能会令人困惑。处理 1 维或 3 维数组时更清晰。这里匹配大小 2 维相加,留下 (4,4)。

ans2 反转它们,对 4 求和,产生 (2,2):

In [94]: np.dot(arr2,arr1)
Out[94]: 
array([[ 76,98],[124,162]])

tensordot 刚刚转置了 2 个数组并执行了常规的 dot

In [95]: np.dot(arr1.T,arr2.T)
Out[95]: 
array([[ 76,162]])

ans3 使用转置和重塑 (ravel),在两个轴上求和:

In [98]: np.dot(arr1.ravel(),arr2.T.ravel())
Out[98]: 238

通常,tensordot 使用转置和重塑的混合将问题简化为二维 np.dot 问题。然后它可以对结果进行整形和转置。

我发现 einsum 的尺寸控制更清晰:

In [99]: np.einsum('ij,jk->ik',arr1,arr2)
Out[99]: 
array([[  8,119]])
In [100]: np.einsum('ji,kj->ik',arr2)
Out[100]: 
array([[ 76,162]])
In [101]: np.einsum('ij,ji',arr2)
Out[101]: 238

随着 einsummatmul/@ 的发展,tensordot 变得不那么必要了。它更难理解,并且没有任何速度或灵活性优势。不要担心理解它。

ans3 是其他 2 个 ans 的迹线(对角线的和):

In [103]: np.trace(ans)
Out[103]: 238
In [104]: np.trace(ans2)
Out[104]: 238

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?