如何解决PyTorch C ++扩展:如何索引张量并更新它?
我正在创建PyTorch C ++扩展,经过大量研究,我不知道如何索引张量和更新其值。我发现了如何使用data_ptr()
方法对张量的条目进行迭代,但这不适用于我的用例。
给出的是一个矩阵M
,一个索引对P
的列表(块)列表以及一个函数f: dtype(M)^2 -> dtype(M)^2
,该函数接受两个值并吐出两个新值。
我正在尝试实现以下伪代码:
for each block B in P:
for each row R in M:
for each index-pair (i,j) in B:
M[R,i],M[R,j] = f(M[R,j])
毕竟,这段代码将使用CUDA在GPU上运行,但是由于我对此没有任何经验,所以我想先编写一个纯C ++程序然后进行转换。
任何人都可以建议这样做或如何将算法转换为等效功能吗?
解决方法
我想做的事可以使用
tensor.accessor<scalar_dtype,num_dimensions>()
方法。如果在GPU上执行,请使用scalars.packed_accessor64<scalar_dtype,num_dimensions,torch::RestrictPtrTraits>()
要么
scalars.packed_accessor32<scalar_dtype,torch::RestrictPtrTraits>()
(取决于张量的大小)。
auto num_rows = scalars.size(0);
matrix = torch::rand({10,8});
auto a = matrix.accessor<float,2>();
for (auto i = 0; i < num_rows; ++i) {
auto x = a[i][some_index];
auto new_x = some_function(x);
a[i][some_index] = new_x;
}
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。