torch.argmax(input) → LongTensor
参数:
-
input (Tensor) – 输入的Tensor矩阵
-
dim (int) – dim表示不同维度。特别的在dim=0表示二维矩阵中的列,dim=1在二维矩阵中的行。广泛的来说,我们不管一个矩阵是几维的,比如一个矩阵维度如下:(d0,d1,…,dn−1) ,那么dim=0就表示对应到d0 也就是第一个维度,dim=1表示对应到也就是第二个维度,以此类推。
举一些例子说明:
import torch
x = torch.asarray([3, 2, 5, 1])
y = torch.argmax(x) # 对应于x中最大元素的索引值
print(x, y)
torch.argmax( ) 不使用dim
返回最大值索引,也就是5的索引位置2.
import torch
x = torch.asarray([[3, 2, 5, 1], [3, 11, 6, 2]])
y = torch.argmax(x) # 对应于x中最大元素的索引值
print(x, y)
该函数默认将输入矩阵排变成一个一维向量,然后找出这个一维向量里面最大值的索引。
torch.argmax( )使用参数dim
对于dim
这个参数可以这样理解:
下边代码例子输入x为torch.Size([2, 4])
,dim=0
时把2变成1,返回每列最大索引,dim=1
时把4变为1,返回每行最大索引。
import torch
x = torch.asarray([[3, 2, 5, 1], [3, 11, 6, 2]])
y = torch.argmax(x, dim=0) # 对应于x中最大元素的索引值
print(y)
import torch
x = torch.asarray([[3, 2, 5, 1], [3, 11, 6, 2]])
y = torch.argmax(x, dim=1) # 对应于x中最大元素的索引值
print(y)
原文地址:https://www.jb51.cc/wenti/3282073.html
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。