
In a PyTorch C++ extension, how can I access a single element of a tensor and convert it to a standard C++ data type?


I am writing a C++ extension for PyTorch in which I need to access the elements of a tensor by index and convert each element to a standard C++ type. Here is a short example: suppose I have a 2-D tensor a and need to read a[i][j] as a float.

#include <torch/extension.h>

float get(torch::Tensor a, int i, int j) {
    return a[i][j];
}

I put the code above in a file called tensortest.cpp. In a second file, setup.py, I have:

from setuptools import setup, Extension
from torch.utils import cpp_extension

setup(name='tensortest',
      ext_modules=[cpp_extension.CppExtension('tensortest_cpp', ['tensortest.cpp'])],
      cmdclass={'build_ext': cpp_extension.BuildExtension})

When I run python setup.py install, the build fails with the following error:

running install
running bdist_egg
running egg_info
creating tensortest.egg-info
writing tensortest.egg-info/PKG-INFO
writing dependency_links to tensortest.egg-info/dependency_links.txt
writing top-level names to tensortest.egg-info/top_level.txt
writing manifest file 'tensortest.egg-info/SOURCES.txt'
/home/trisst/.local/lib/python3.8/site-packages/torch/utils/cpp_extension.py:335: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.
  warnings.warn(msg.format('we could not find ninja.'))
reading manifest file 'tensortest.egg-info/SOURCES.txt'
writing manifest file 'tensortest.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'tensortest_cpp' extension
creating build
creating build/temp.linux-x86_64-3.8
x86_64-linux-gnu-gcc -pthread -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/home/user/.local/lib/python3.8/site-packages/torch/include -I/home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/user/.local/lib/python3.8/site-packages/torch/include/TH -I/home/user/.local/lib/python3.8/site-packages/torch/include/THC -I/usr/include/python3.8 -c tensortest.cpp -o build/temp.linux-x86_64-3.8/tensortest.o -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=tensortest_cpp -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14
In file included from /home/user/.local/lib/python3.8/site-packages/torch/include/ATen/Parallel.h:149,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/utils.h:3,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/nn/cloneable.h:5,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/nn.h:3,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:7,
                 from /home/user/.local/lib/python3.8/site-packages/torch/include/torch/extension.h:4,
                 from tensortest.cpp:1:
/home/user/.local/lib/python3.8/site-packages/torch/include/ATen/ParallelOpenMP.h:84: warning: ignoring #pragma omp parallel [-Wunknown-pragmas]
   84 | #pragma omp parallel for if ((end - begin) >= grain_size)
      | 
tensortest.cpp: In function ‘float get(at::Tensor, int, int)’:
tensortest.cpp:4:15: error: cannot convert ‘at::Tensor’ to ‘float’ in return
    4 |  return a[i][j];
      |               ^
error: command 'x86_64-linux-gnu-gcc' failed with exit status 1

What can I do?


Solution

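In libtorch, a[i][j] does not return a number. Each operator[] call is a select() that returns another at::Tensor, so a[i][j] is a zero-dimensional tensor, and the compiler has no conversion from at::Tensor to float. Call item<float>() on it to extract the value: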

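    #include <torch/extension.h>
    
    // Copy the single value held by the 0-dim tensor a[i][j] into a C++ float
    float get(torch::Tensor a, int i, int j)
    {
        return a[i][j].item<float>();
    }
    
    // Register get() so that tensortest_cpp.get(tensor, i, j) is callable from Python
    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        m.def("get", &get, "Return a[i][j] as a float");
    }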