微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

用于浮点数或复杂输入的 rust-ndarray 函数

如何解决用于浮点数或复杂输入的 rust-ndarray 函数

问题:是否可以用 rust 编写一个函数来接受浮点数或复杂的 ndarrays 作为输入?

我是 Rust 的新手,我来自 python/numpy 领域,在那里浮点数组和复杂数组可以很好地结合在一起。所以当我写一个函数时,我不会担心一个或几个输入是否复杂。

所以,我想写一个这样的函数

use ndarray::{ArrayD,ArrayViewD};
use num_complex::Complex64;
use f64;

fn example(a: f64,x: ArrayViewD<'_,Complex64>,y: ArrayViewD<'_,f64>) -> ArrayD<Complex64> {
    &x * a + &y
}

但是使输入通用。我猜应该是这样的:

fn example<T: ???>(a: T,T>,T>) -> ArrayD<T> {
    &x * a + &y
}

但我不确定需要什么特征。

我想蛮力方法是将一切都强制复杂化?可以工作,但需要大量转换,内存翻倍,而且一般感觉不太对。

我上面的示例是对 PyO3 示例的修改https://github.com/PyO3/rust-numpy/blob/main/examples/simple-extension/src/lib.rs


编辑

我仍然缺少一些东西。如果我尝试:

fn example<T: Num>(a: T,T>) -> ArrayD<T> {     
&x * a + y 
}

我明白了:

error[E0369]: cannot multiply `&ArrayBase<ViewRepr<&T>,Dim<IxDynImpl>>` by `T`
  --> src\main.rs:26:8
   |
26 |     &x * a + &y
   |     -- ^ - T
   |     |
   |     &ArrayBase<ViewRepr<&T>,Dim<IxDynImpl>>
   |
help: consider further restricting this bound
   |
25 | fn example<T: Num + std::ops::Mul<Output = T>>(a: T,T>) -> ArrayD<T> {
   |                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^

error[E0308]: mismatched types
  --> src\main.rs:26:14
   |
25 | fn example<T: Num>(a: T,T>) -> ArrayD<T> {
   |            - this type parameter
26 |     &x * a + &y
   |              ^^ expected type parameter `T`,found reference
   |
   = note: expected type parameter `T`
                   found reference `&ArrayBase<ViewRepr<&T>,Dim<IxDynImpl>>`
 

所以它想让我添加 std::ops::Mul<Output = T>。但是如果我添加它,我会得到一个类似的错误,它想让我再次添加它(递归?)。

我认为问题在于我的输入 axy 都可以是不同的类型,这与输出类型不同。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。