How do I access a single element of a tensor in a PyTorch C++ extension and convert it to a standard C++ data type?
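I am writing a C++ extension for PyTorch in which I need to access the elements of a tensor by index and convert each element to a standard C++ type. Here is a short example. Suppose I have a 2-D tensor a, and I need to access a[i][j] and convert it to float.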
#include <torch/extension.h>

float get(torch::Tensor a, int i, int j) {
    return a[i][j];
}
The code above goes in a file named tensortest.cpp. In another file, setup.py, I wrote:
from setuptools import setup, Extension
from torch.utils import cpp_extension

setup(
    name='tensortest',
    ext_modules=[cpp_extension.CppExtension('tensortest_cpp', ['tensortest.cpp'])],
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)
When I run python setup.py install, the compiler reports the following error:
running install
running bdist_egg
running egg_info
creating tensortest.egg-info
writing tensortest.egg-info/PKG-INFO
writing dependency_links to tensortest.egg-info/dependency_links.txt
writing top-level names to tensortest.egg-info/top_level.txt
writing manifest file 'tensortest.egg-info/SOURCES.txt'
/home/trisst/.local/lib/python3.8/site-packages/torch/utils/cpp_extension.py:335: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.
warnings.warn(msg.format('we could not find ninja.'))
reading manifest file 'tensortest.egg-info/SOURCES.txt'
writing manifest file 'tensortest.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'tensortest_cpp' extension
creating build
creating build/temp.linux-x86_64-3.8
x86_64-linux-gnu-gcc -pthread -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/home/user/.local/lib/python3.8/site-packages/torch/include -I/home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/user/.local/lib/python3.8/site-packages/torch/include/TH -I/home/user/.local/lib/python3.8/site-packages/torch/include/THC -I/usr/include/python3.8 -c tensortest.cpp -o build/temp.linux-x86_64-3.8/tensortest.o -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=tensortest_cpp -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14
In file included from /home/user/.local/lib/python3.8/site-packages/torch/include/ATen/Parallel.h:149,from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/utils.h:3,from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/nn/cloneable.h:5,from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/nn.h:3,from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:7,from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/extension.h:4,from tensortest.cpp:1:
/home/user/.local/lib/python3.8/site-packages/torch/include/ATen/ParallelOpenMP.h:84: warning: ignoring #pragma omp parallel [-Wunknown-pragmas]
84 | #pragma omp parallel for if ((end - begin) >= grain_size)
|
tensortest.cpp: In function ‘float get(at::Tensor,int,int)’:
tensortest.cpp:4:15: error: cannot convert ‘at::Tensor’ to ‘float’ in return
4 | return a[i][j];
| ^
error: command 'x86_64-linux-gnu-gcc' failed with exit status 1
What should I do?
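Solution
Indexing a tensor returns another at::Tensor (after two indices on a 2-D tensor, a zero-dimensional one), not a float, which is why the return statement fails to compile. Call item<float>() on the indexed element to extract the value: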
#include <torch/extension.h>

float get(torch::Tensor a, int i, int j)
{
    return a[i][j].item<float>();
}
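The compiler error in the question arises because two index operations still leave a tensor object, and no implicit conversion from that object to float exists. As a rough illustration that builds without PyTorch, the sketch below uses a hypothetical ToyTensor class. It is an assumption made for this example only, not PyTorch's at::Tensor and not how PyTorch implements indexing. Its operator[] returns another ToyTensor, and its item<T>() performs the explicit extraction, which mirrors the shape of the fix.

```cpp
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for a tensor; NOT PyTorch's at::Tensor.
class ToyTensor {
public:
    ToyTensor(std::vector<float> data, std::vector<std::size_t> shape)
        : data_(std::move(data)), shape_(std::move(shape)) {}

    // Indexing strips the leading dimension and returns another ToyTensor,
    // never a bare float.
    ToyTensor operator[](std::size_t i) const {
        if (shape_.empty() || i >= shape_[0]) {
            throw std::out_of_range("index out of range");
        }
        const std::size_t stride = data_.size() / shape_[0];
        std::vector<float> sub(data_.begin() + i * stride,
                               data_.begin() + (i + 1) * stride);
        std::vector<std::size_t> sub_shape(shape_.begin() + 1, shape_.end());
        return ToyTensor(std::move(sub), std::move(sub_shape));
    }

    // Explicit extraction; valid only when exactly one element remains.
    template <typename T>
    T item() const {
        if (data_.size() != 1) {
            throw std::logic_error("item() requires exactly one element");
        }
        return static_cast<T>(data_[0]);
    }

private:
    std::vector<float> data_;          // row-major element storage
    std::vector<std::size_t> shape_;   // size of each remaining dimension
};

float get(const ToyTensor& a, int i, int j) {
    // return a[i][j];  // would not compile: ToyTensor does not convert to float
    return a[i][j].item<float>();
}
```

For a 2x3 ToyTensor holding the values 1 through 6 in row-major order, get(t, 1, 2) returns 6. In the real extension the same pattern is a[i][j].item<float>(), as in the answer above.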