如何解决Numba namedtuple签名
我正在尝试为Numba中的namedtuple指定返回类型,但我无法这样做。有人可以帮忙吗?考虑以下最少的代码:
import numba as nb
from collections import namedtuple
NT = namedtuple('NT',['sum','sum2'])
@nb.njit((nb.types.NamedTuple([nb.float64,nb.float64],NT))(nb.int64,nb.float64[:,:]),fastmath=True)
def arrsum_njit(nn,xx):
arraysum = 0.0
out = NT(sum=arraysum,sum2=arraysum)
return out
我得到了错误
No conversion from NT(float64 x 2) to NT(float64,float64) for '$20return_value.7',defined at None
File "numbanamedtuple.py",line 10:
def arrsum_njit(nn,xx):
<source elided>
out = NT(sum=arraysum,sum2=arraysum)
return out
^
During: typing of assignment at numbanamedtuple.py (10)
File "numbanamedtuple.py",sum2=arraysum)
return out
解决方法
问题是“过度优化”的numba编译器(错误)。在元组中添加其他类型的变量,以告知编译器使用异构元组(内部类)。
import numba as nb
from collections import namedtuple
NT = namedtuple('NT',['sum','sum2','dummy'])
@nb.njit((nb.types.NamedTuple([nb.float64,nb.float64,nb.int64],NT))(nb.int64,nb.float64[:,:]),fastmath=True)
def arrsum_njit(nn,xx):
arraysum = 0.0
out = NT(sum=arraysum,sum2=arraysum,dummy=1)
return out
更新: 经过测试:
- Numba 0.51.2 / Windows
- Numba 0.48.0 / Google colab-Linux Ubuntu 18.04.5 LTS
改用 NamedUniTuple
。这是同构命名元组的 numba 规范类型。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。