如何解决在numpy数组上迭代时,我无法调用存储在数组中的对象的方法
StackOverflow中问的第一个问题,因此,欢迎您提出如何更好地“询问”的提示。
这部分代码的基本目标: 许多球(no_balls)沿随机方向移动。
我正在尝试从python列表移动到numpy数组,以提高性能。这是简化的代码。
基本问题: 我的迭代器为我提供了ndarray类型的对象,而不是vpy.sphere,因此在我要遍历的对象上调用sphere.pos失败。 还是这不可能,因为Numpy是为数字构建的?性能的替代品?
import vpython as vpy
import numpy as np
#Create and Fill numpy array with random size balls
balls = np.empty([no_ball],dtype=vpy.sphere)
with np.nditer(balls,flags=['refs_ok'],op_flags=['readwrite']) as b_it:
debug_msg(len(b_it))
for b in b_it:
b[...] = (vpy.sphere( radius=random_in_range(ball_min_r,ball_max_r),opacity=0.8,color=random_RGB(),pos=vpy.vector(0,0),))
debug_msg('populated balls list')
#Main Loop
debug_msg('Starting Main Loop')
while True:
vpy.rate(30)
with np.nditer(balls,op_flags=['readwrite']) as b_it:
#Main Loop
debug_msg('Starting Main Loop')
while True:
vpy.rate(30)
#The actual loop manipulates the position but the problem is that I can't access the position of the sphere objects. Type returns nd.array for b
for b in b_it:
debug_msg(type(b[...]))
debug_msg(b[...].pos)
#Above outputs
<class 'numpy.ndarray'>
Traceback (most recent call last):
File "path",line 93,in <module>
debug_msg(b[...].pos)
AttributeError: 'numpy.ndarray' object has no attribute 'pos'
如何调用数组中对象的方法和成员。在一个旁注中,为什么我需要调用b [...]而不是b似乎已经过时了。
解决方法
一个简单的类:
In [149]: class Foo():
...: def __init__(self,i):
...: self.i = i
...: def __repr__(self):
...: return f'<FOO {self.i}>'
...:
In [150]: Foo(323)
Out[150]: <FOO 323>
此类对象的列表:
In [151]: alist = [Foo(i) for i in range(10)]
等效对象dtype数组:
In [152]: arr = np.array(alist)
In [153]: arr.dtype
Out[153]: dtype('O')
In [154]: arr
Out[154]:
array([<FOO 0>,<FOO 1>,<FOO 2>,<FOO 3>,<FOO 4>,<FOO 5>,<FOO 6>,<FOO 7>,<FOO 8>,<FOO 9>],dtype=object)
从列表中获取属性:
In [155]: [f.i for f in alist]
Out[155]: [0,1,2,3,4,5,6,7,8,9]
In [156]: timeit [f.i for f in alist]
826 ns ± 8.9 ns per loop (mean ± std. dev. of 7 runs,1000000 loops each)
并从数组中(较慢):
In [157]: timeit [f.i for f in arr]
1.66 µs ± 15.5 ns per loop (mean ± std. dev. of 7 runs,1000000 loops each)
使用nditer
-您对文档进行了足够的研究以使标记正确,但并没有意识到b
是一个数组,而不是Foo
:
In [158]: with np.nditer(arr,flags=['refs_ok'],op_flags=['readwrite']) as b_it:
...: for b in b_it:
...: print(b,b.dtype,b.shape,b.item())
...:
<FOO 0> object () <FOO 0>
<FOO 1> object () <FOO 1>
<FOO 2> object () <FOO 2>
<FOO 3> object () <FOO 3>
<FOO 4> object () <FOO 4>
<FOO 5> object () <FOO 5>
<FOO 6> object () <FOO 6>
<FOO 7> object () <FOO 7>
<FOO 8> object () <FOO 8>
<FOO 9> object () <FOO 9>
获取属性列表:
In [159]: res = []
...: with np.nditer(arr,op_flags=['readwrite']) as b_it:
...: for b in b_it:
...: res.append(b.item().i)
...:
...:
In [160]: res
Out[160]: [0,9]
时间差:
In [161]: %%timeit
...: res = []
...: with np.nditer(arr,op_flags=['readwrite']) as b_it:
...: for b in b_it:
...: res.append(b.item().i)
...:
7.25 µs ± 60.7 ns per loop (mean ± std. dev. of 7 runs,100000 loops each)
对对象数组的元素执行操作的一种更简洁的方法是使用frompyfunc
:
In [162]: f = np.frompyfunc(lambda b:b.i,1)
In [163]: f(arr)
Out[163]: array([0,9],dtype=object)
In [164]: timeit f(arr)
2.1 µs ± 8.58 ns per loop (mean ± std. dev. of 7 runs,100000 loops each)
仍然比迭代慢,尽管如果我们想要一个数组而不只是一个列表,它比:
In [165]: timeit np.array([f.i for f in arr])
5.79 µs ± 21.4 ns per loop (mean ± std. dev. of 7 runs,100000 loops each)
nditer
文档需要更强大的性能免责声明。在nditer
或c
代码中使用cython
既有用又快速,但是通过Python代码访问时,它不及其他明显的替代方法。额外的bells-n-whistles在某些情况下可能很有用,但我大多将其视为正确编译代码的桥梁,而不是其目的。
性能问题的核心是Foo
是一个Python类。因此,访问i
属性必须使用完整的Python引用系统。它不能使用任何快速编译的numpy
数字方法。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。