微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

torch.argmax()使用

torch.argmax(input) → LongTensor
参数:

  • input (Tensor) – 输入的Tensor矩阵

  • dim (int) – dim表示不同维度。特别的在dim=0表示二维矩阵中的列,dim=1在二维矩阵中的行。广泛的来说,我们不管一个矩阵是几维的,比如一个矩阵维度如下:(d0,d1,…,dn−1) ,那么dim=0就表示对应到d0 也就是第一个维度,dim=1表示对应到也就是第二个维度,以此类推。

举一些例子说明:

import torch

x = torch.asarray([3, 2, 5, 1])
y = torch.argmax(x)  # 对应于x中最大元素的索引值
print(x, y)

torch.argmax( ) 不使用dim

在这里插入图片描述


返回最大值索引,也就是5的索引位置2.

import torch

x = torch.asarray([[3, 2, 5, 1], [3, 11, 6, 2]])
y = torch.argmax(x)  # 对应于x中最大元素的索引值
print(x, y)

在这里插入图片描述


函数认将输入矩阵排变成一个一维向量,然后找出这个一维向量里面最大值的索引。

torch.argmax( )使用参数dim

对于dim这个参数可以这样理解:
下边代码例子输入x为torch.Size([2, 4])dim=0时把2变成1,返回每列最大索引,dim=1时把4变为1,返回每行最大索引。

函数返回其他所有维在这个维度上面张量最大值的索引。

import torch

x = torch.asarray([[3, 2, 5, 1], [3, 11, 6, 2]])
y = torch.argmax(x, dim=0)  # 对应于x中最大元素的索引值
print(y)

在这里插入图片描述

import torch

x = torch.asarray([[3, 2, 5, 1], [3, 11, 6, 2]])
y = torch.argmax(x, dim=1)  # 对应于x中最大元素的索引值
print(y)

在这里插入图片描述

原文地址:https://www.jb51.cc/wenti/3282073.html

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐