微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用 PyO3 将列表列表作为参数从 Python 传递给 Rust

如何解决使用 PyO3 将列表列表作为参数从 Python 传递给 Rust

我正在尝试使用 Py03 将列表从 Python 传递到 Rust。我试图将其传递给的函数具有以下签名:

pub fn k_nearest_neighbours(k: usize,x: &[[f32; 2]],y: &[[f32; 3]]) -> Vec<Option<f32>> 

我正在为预先存在的库编写绑定,因此我无法更改原始代码。我目前的做事方式是这样的:

// This is example code == DOES NOT WORK
#[pyfunction] // make a new function within a new library with pyfunction macro
fn k_nearest_neighbours(k: usize,x: Vec<Vec<f32>>,y: Vec<f32>) -> Vec<Option<f32>> {
    // reformat input where necessary
    let x_slice = x.as_slice();
    // return original lib's function return
    classification::k_nearest_neighbours(k,x_slice,y)
}

x.as_slice() 函数几乎完成了我需要它做的事情,它给了我一个向量切片 &[Vec<f32>],而不是切片 &[[f32; 3]]

我希望能够运行此 Python 代码

from rust_code import k_nearest_neighbours as knn  # this is the Rust function compiled with PyO3

X = [[0.0,1.0],[2.0,3.0],[4.0,5.0],[0.06,7.0]]
train = [
        [0.0,0.0,0.0],[0.5,0.5,[3.0,3.0,]

k = 2
y_true = [0,1,1]
y_test = knn(k,X,train)
assert(y_true == y_test)

解决方法

查看 k_nearest_neighbours 的签名表明它期望 [f32; 2][f32; 3] 是数组,而不是切片(例如 &[f32]

数组在编译时有一个静态已知的大小,而切片是动态大小的。向量也是如此,您无法控制示例中内部向量的长度。因此,您最终会得到从输入向量到预期数组的错误转换。

您可以使用 TryFrom 将切片转换为数组,即:

use std::convert::TryFrom;
fn main() {
    let x = vec![vec![3.5,3.4,3.6]];
    let x: Result<Vec<[f32; 3]>,_> = x.into_iter().map(TryFrom::try_from).collect::<Result<Vec<_>,_>>();
}

综上所述,您的函数将需要在输入错误时返回错误,并且您需要创建一个包含可传递给函数的数组的新向量:

#[pyfunction] // make a new function within a new library with pyfunction macro
fn k_nearest_neighbours(k: usize,x: Vec<Vec<f32>>,y: Vec<f32>) -> PyResult<Vec<Option<f32>>> {
    let x = x.into_iter().map(TryFrom::try_from).collect::<Result<Vec<_>,_>>();
    let y = y.into_iter().map(TryFrom::try_from).collect::<Result<Vec<_>,_>>();
    // Error handling is missing here,you'll need to look into PyO3's documentation for that
    ...
    // return original lib's function return
    Ok(classification::k_nearest_neighbours(k,&x,&y))
}

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。