如何使用Pytorch和/或Numpy在矩阵的多维数组中有效地找到最大值的指标 解决方案一解决方案二解决方案三解决方案四基准化解决方案五解决方案六

如何解决如何使用Pytorch和/或Numpy在矩阵的多维数组中有效地找到最大值的指标 解决方案一解决方案二解决方案三解决方案四基准化解决方案五解决方案六

背景

在机器学习中,通常要处理高维数据。例如,在卷积神经网络(CNN)中,每个输入图像的尺寸可以为256x256,并且每个图像可以具有3个颜色通道(红色,绿色和蓝色)。如果我们假设模型一次获取16张图像,则进入CNN的输入维数为[16,3,256,256]。每个单独的卷积层都希望数据以[batch_size,in_channels,in_y,in_x]的形式出现,并且所有这些数量通常会逐层更改(batch_size除外)。我们用于由[in_y,in_x]值组成的矩阵的术语是 feature map ,这个问题涉及在给定的每个特征图中查找最大值及其索引。层。

为什么要这样做?我想对每个要素地图应用一个蒙版,并且我要以每个要素地图中的最大值为中心应用该蒙版,为此,我需要知道每个最大值位于何处。此蒙版应用程序是在模型的训练和测试期间完成的,因此效率对于减少计算时间至关重要。有许多Pytorch和Numpy解决方案可用于查找单例最大值和索引,以及沿单个维度查找最大值或索引,但是(我可以找到)没有专用且 efficiency 内置函数一次查找沿2个或多个维度的最大值的索引。是的,我们可以嵌套在单个维度上运行的函数,但这是一些效率最低的方法。

我尝试过的东西

  • 我看过this Stackoverflow question,但是作者正在处理一种特殊情况的4D数组,该数组被微不足道地压缩为3D数组。接受的答案专门针对这种情况,指向TopK的答案是错误的,因为它不仅作用于单一维度,而且鉴于提出的问题,将有必要k=1,从而演变为常规{{1} }。
  • 我看过this Stackoverflow question,但是这个问题及其答案都集中在一个维度上。
  • 我看过this Stackoverflow question,但是我已经知道答案的方法了,因为我在自己的答案here中独立地制定了答案(我认为该方法效率很低)。
  • 我看过this Stackoverflow question,但可接受的答案是“在PyTorch中无法在多个维度上执行.min()或.max()”。尽管给出了一种解决方法,但它不能满足这个问题的关键部分,即效率。
  • 我还阅读了许多其他Stackoverflow问题和答案,以及Numpy文档,Pytorch文档以及Pytorch论坛上的帖子。
  • 我已经尝试了很多解决此问题的方法,以至于我创建了这个问题,以便我可以回答并回馈给社区,以及任何正在寻找解决此问题的方法的人。未来。

性能标准

如果我要问有关效率的问题,则需要详细说明期望。我试图为上面的问题找到一个省时的解决方案(空格是次要的),而无需编写C代码/扩展,并且该解决方案相当灵活(我不追求超级专门的方法)。该方法必须接受数据类型为float32或float64的torch.max Torch张量作为输入,并输出数据类型为int32或int64的[a,b,c,d]形式的数组或张量(因为我们将输出用作索引)。 解决方案应以以下典型解决方案为基准:

[a,2]

解决方法

方法

我们将利用Numpy社区和库,以及Pytorch张量和Numpy数组可以相互转换而无需在内存中复制或移动基础数组的事实(因此转换成本很低) )。来自Pytorch documentation

将火炬张量转换为Numpy数组,反之亦然,这很容易。火炬的Tensor和Numpy数组将共享其底层内存位置,而更改一个将更改另一个。

解决方案一

我们首先要使用Numba library编写一个函数,该函数将在首次使用时进行即时(JIT)编译,这意味着我们可以获得C速度而不必自己编写C代码。当然,可以得到JIT编辑的内容有一些警告,其中之一是我们使用Numpy函数。但这还算不错,因为要记住,从我们的火炬张量转换为Numpy成本低。我们创建的功能是:

@njit(cache=True)
def indexFunc(array,item):
    for idx,val in np.ndenumerate(array):
        if val == item:
            return idx

此功能来自另一个位于here的Stackoverflow答案(这是将我介绍给Numba的答案)。该函数采用N维Numpy数组,并查找给定item的首次出现。成功匹配后,它将立即返回找到的项目的索引。 @njit装饰器是@jit(nopython=True)的缩写,它告诉编译器我们希望它使用 no Python对象来编译函数,如果不能,则抛出错误这样做(当不使用Python对象时,Numba最快,而我们追求的是速度)。

有了此快速函数的支持,我们可以获得张量中最大值的索引,如下所示:

import numpy as np

x =  x.numpy()
maxVals = np.amax(x,axis=(2,3))
max_indices = np.zeros((n,p,2),dtype=np.int64)
for index in np.ndindex(x.shape[0],x.shape[1]):
    max_indices[index] = np.asarray(indexFunc(x[index],maxVals[index]),dtype=np.int64)
max_indices = torch.from_numpy(max_indices)

我们使用np.amax是因为它的axis参数可以接受一个元组,从而允许它返回4D输入中每个2D特征图的最大值。由于appending to numpy arrays is expensive,我们提前max_indices初始化了np.zeros,所以我们提前分配了所需的空间。这种方法比问题中的“典型解决方案”快了许多(一个数量级),但是它在JIT-ed函数外部也使用了for循环,因此我们可以进行改进。 ..

解决方案二

我们将使用以下解决方案:

@njit(cache=True)
def indexFunc(array,val in np.ndenumerate(array):
        if val == item:
            return idx
    raise RuntimeError

@njit(cache=True,parallel=True)
def indexFunc2(x,maxVals):
    max_indices = np.zeros((x.shape[0],x.shape[1],dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            max_indices[i,j] = np.asarray(indexFunc(x[i,j],maxVals[i,j]),dtype=np.int64)
    return max_indices

x = x.numpy()
maxVals = np.amax(x,3))
max_indices = torch.from_numpy(indexFunc2(x,maxVals))

我们可以使用Numba的for函数(其行为与prange相同,而不是通过range循环一次遍历功能映射,但告诉编译器我们希望循环并行化)和parallel=True装饰器参数。 Numba也parallelizes the np.zeros function。因为我们的函数是即时编译的,并且不使用Python对象,所以Numba可以利用我们系统中所有可用的线程!值得注意的是,raise RuntimeError中现在有一个indexFunc。我们需要包括它,否则Numba编译器将尝试推断该函数的返回类型,并推断它将是数组还是None。这与我们在indexFunc2中的用法并不吻合,因此编译器会抛出错误。当然,根据我们的设置我们知道indexFunc将始终返回一个数组,因此我们可以简单地在另一个逻辑分支中引发和出错。

此方法在功能上与解决方案一相同,但是使用nd.index将迭代使用for更改为两个prange循环。这种方法比解决方案一快约四倍。

解决方案三

解决方案二速度很快,但是仍然可以使用常规Python查找最大值。我们可以使用更全面的JIT编辑功能来加快速度吗?

@njit(cache=True)
def indexFunc(array,parallel=True)
def indexFunc3(x):
    maxVals = np.zeros((x.shape[0],x.shape[1]),dtype=np.float32)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxVals[i][j] = np.max(x[i][j])
    max_indices = np.zeros((x.shape[0],dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            x[i][j] == np.max(x[i][j])
            max_indices[i,dtype=np.int64)
    return max_indices

max_indices = torch.from_numpy(indexFunc3(x))

此解决方案中似乎还有很多事情要做,但是唯一的变化是,我们现在使操作并行化,而不是使用np.amax计算每个特征图的最大值。这种方法比解决方案二要快。

解决方案四

此解决方案是我能提出的最好的解决方案:

@njit(cache=True,parallel=True)
def indexFunc4(x):
    max_indices = np.zeros((x.shape[0],dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i][j])
            max_indices[i][j] = [maxTemp // x.shape[2],maxTemp % x.shape[2]] 
    return max_indices

max_indices = torch.from_numpy(indexFunc4(x))

这种方法更简洁,并且比解决方案三更快,最快达到33%,比典型解决方案快50倍。我们使用np.argmax来获取每个特征图最大值的索引,但是 np.argmax仅返回索引,就好像每个特征图都被展平了一样。也就是说,我们得到一个整数,告诉我们元素在要素图中的编号,而不是我们需要能够访问该元素的索引。数学[maxTemp // x.shape[2],maxTemp % x.shape[2]]是将单数转换为我们需要的[row,column]

基准化

所有方法都针对形状为[32,d,64,64]的随机输入进行了基准测试,其中d从5增加到245。对于每个d,收集了15个样本并将时间平均。相等性测试确保所有解决方案均提供相同的值。基准输出的一个示例是:

Solution Benchmarks

随着d的增加,基准时间的图是(省略了“典型解”,所以该图不会被压缩):

Benchmarking Plot

哇!这些峰值开始时是怎么回事?

解决方案五

Numba允许我们产生即时编译功能,但是直到我们第一次使用它们时才编译它们。然后,当我们再次调用该函数时,它将缓存结果。这意味着我们第一次调用JIT版本的函数时,会在编译函数时出现计算时间高峰。幸运的是,有一种解决方法-如果我们提前指定函数的返回类型和参数类型将是什么,该函数将被急切地编译,而不是即时编译。将这些知识应用于解决方案四,我们得到:

@njit('i8[:,:,:](f4[:,:])',cache=True,maxTemp % x.shape[2]] 
    return max_indices    

max_indices6 = torch.from_numpy(indexFunc4(x))

如果重新启动内核并重新运行基准测试,我们可以查看第一个结果d==5和第二个结果d==10,请注意,所有JIT版本的解决方案在d==5,因为它们必须被编译,但解决方案四除外,因为我们提前明确提供了函数签名:

enter image description here

我们去了!这是迄今为止我对这个问题的最佳解决方案。


编辑#1

解决方案六

已开发出一种改进的解决方案,它比以前发布的最佳解决方案快33%。此解决方案仅在输入数组为C连续的情况下才有效,但这不是一个很大的限制,因为numpy数组或割炬张量将是连续的,除非对其进行整形,并且如果需要,它们都具有使数组/张量连续的功能。 / p>

此解决方案与先前的最佳解决方案相同,但是用于指定输入和返回类型的函数装饰器已更改

@njit('i8[:,parallel=True)

@njit('i8[:,::1](f4[:,::1])',parallel=True)

唯一的区别是,每个数组类型中的最后一个:变为::1,这向numba njit编译器表明输入数组是C连续的,从而可以更好地进行优化。

那么完整的解决方案六是:

@njit('i8[:,parallel=True)
def indexFunc5(x):
    max_indices = np.zeros((x.shape[0],maxTemp % x.shape[2]] 
    return max_indices 

max_indices7 = torch.from_numpy(indexFunc5(x))

包含此新解决方案的基准测试可以确认加速:

Benchmark including solution 6

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams['font.sans-serif'] = ['SimHei'] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -> systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping("/hires") public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate<String
使用vite构建项目报错 C:\Users\ychen\work>npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-
参考1 参考2 解决方案 # 点击安装源 协议选择 http:// 路径填写 mirrors.aliyun.com/centos/8.3.2011/BaseOS/x86_64/os URL类型 软件库URL 其他路径 # 版本 7 mirrors.aliyun.com/centos/7/os/x86
报错1 [root@slave1 data_mocker]# kafka-console-consumer.sh --bootstrap-server slave1:9092 --topic topic_db [2023-12-19 18:31:12,770] WARN [Consumer clie
错误1 # 重写数据 hive (edu)> insert overwrite table dwd_trade_cart_add_inc > select data.id, > data.user_id, > data.course_id, > date_format(
错误1 hive (edu)> insert into huanhuan values(1,'haoge'); Query ID = root_20240110071417_fe1517ad-3607-41f4-bdcf-d00b98ac443e Total jobs = 1
报错1:执行到如下就不执行了,没有显示Successfully registered new MBean. [root@slave1 bin]# /usr/local/software/flume-1.9.0/bin/flume-ng agent -n a1 -c /usr/local/softwa
虚拟及没有启动任何服务器查看jps会显示jps,如果没有显示任何东西 [root@slave2 ~]# jps 9647 Jps 解决方案 # 进入/tmp查看 [root@slave1 dfs]# cd /tmp [root@slave1 tmp]# ll 总用量 48 drwxr-xr-x. 2
报错1 hive> show databases; OK Failed with exception java.io.IOException:java.lang.RuntimeException: Error in configuring object Time taken: 0.474 se
报错1 [root@localhost ~]# vim -bash: vim: 未找到命令 安装vim yum -y install vim* # 查看是否安装成功 [root@hadoop01 hadoop]# rpm -qa |grep vim vim-X11-7.4.629-8.el7_9.x
修改hadoop配置 vi /usr/local/software/hadoop-2.9.2/etc/hadoop/yarn-site.xml # 添加如下 <configuration> <property> <name>yarn.nodemanager.res