Search code examples
Tags: c++, c++11, pybind11

pybind11 - Return a shared_ptr of std::vector


I have a member variable that stores a std::shared_ptr to a std::vector<uint32_t>. I want to create a Python binding for test_func2() so that I can access that vector from Python without making a copy. Here is the skeleton code:

#include <vector>
#include <memory>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

namespace py = pybind11;



/// Holds a buffer of tile values and hands it out to callers (e.g. Python)
/// through std::shared_ptr, so the buffer itself is never copied.
class TestLoader
{
private:
    // Null until test_func1() allocates the buffer.
    std::shared_ptr<std::vector<uint32_t>>  tileData;

public:
    // Defaulted in-class: they were previously declared but never defined,
    // which made any construction (e.g. via py::init) fail at link time.
    TestLoader() = default;
    ~TestLoader() = default;

    /// Allocates a new tile buffer and stores it in tileData.
    void test_func1();

    /// Returns another owner of the tile buffer (no element copy);
    /// the result is null if test_func1() has not been called yet.
    std::shared_ptr<std::vector<uint32_t>> test_func2()  const;
};

/// Replaces tileData with a freshly allocated buffer of 100000000 elements,
/// every element set to 1. Handles previously obtained from test_func2() are
/// separate shared_ptr owners and keep the old buffer alive.
void  TestLoader::test_func1() {
    // Number of elements allocated for the tile buffer.
    constexpr std::size_t kTileCount = 100000000;
    // The (count, value) constructor fills the vector in a single pass; the
    // previous zero-initialise-then-overwrite loop touched ~400 MB twice.
    // Parentheses, not braces: braces would pick the initializer_list
    // constructor and build a two-element vector {kTileCount, 1}.
    tileData = std::make_shared<std::vector<uint32_t>>(kTileCount, 1u);
}
std::shared_ptr<std::vector<uint32_t>>  TestLoader::test_func2() const{

    return tileData;
}

The binding code is as follows:

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;
PYBIND11_MODULE(fltest_lib, m) {
    py::class_<TestLoader,  std::shared_ptr<TestLoader>>(m, "TestLoader")
    // TestLoader declares only a default constructor, so bind that one;
    // py::init<const std::string &>() has no matching constructor to call.
    .def(py::init<>())
    .def("test_func1", &TestLoader::test_func1)
    // This binding triggers "Holder classes are only supported for custom
    // types": with pybind11/stl.h, std::vector<uint32_t> is converted by value
    // to a Python list rather than registered as a class, so it cannot be
    // returned through a std::shared_ptr holder.
    .def("test_func2", &TestLoader::test_func2, py::return_value_policy::reference_internal);
}

However, this does not compile, and I get a long error message. The key line is:

/home/samee/fl_test/lib/pybind11/include/pybind11/cast.h:653:61: error: static assertion failed: Holder classes are only supported for custom types
  653 |     static_assert(std::is_base_of<base, type_caster<type>>::value,
      |                                                             ^~~~~

Any help working around this would be appreciated.


Solution

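  • According to a related pybind11 GitHub issue, this fails because std::vector<uint32_t> is not a registered (custom) pybind11 type. pybind11/stl.h converts it to a Python list by copying, and a holder such as std::shared_ptr can only wrap a registered class. To avoid copies, declare the vector opaque with PYBIND11_MAKE_OPAQUE and register it with py::bind_vector. You can then either return the dereferenced vector by reference, or register the bound vector with std::shared_ptr as its holder and return the shared_ptr directly.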

    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>
    #include <pybind11/stl_bind.h>
    
    #include "test_loader.h"
    
    namespace py = pybind11;
    
    // Make std::vector<uint32_t> a registered (custom) pybind11 type instead of
    // letting pybind11/stl.h convert it to a Python list by copying. Only
    // custom types can be returned through a holder such as std::shared_ptr.
    PYBIND11_MAKE_OPAQUE(std::vector<uint32_t>);
    
    PYBIND11_MODULE(fltest_lib, m) {
      // Register the vector with std::shared_ptr as its holder, matching the
      // type TestLoader::test_func2() returns. Python then co-owns the buffer,
      // so it stays valid even if the TestLoader is destroyed or test_func1()
      // replaces tileData. A reference_internal view of *tileData would only
      // keep the TestLoader alive, not the old vector, and would dangle.
      // py::buffer_protocol() additionally lets numpy view the elements
      // without copying, e.g. numpy.asarray(loader.test_func2()).
      py::bind_vector<std::vector<uint32_t>, std::shared_ptr<std::vector<uint32_t>>>(
          m, "VectorUInt32", py::buffer_protocol());
      py::class_<TestLoader, std::shared_ptr<TestLoader>>(m, "TestLoader")
          .def(py::init<>())
          .def("test_func1", &TestLoader::test_func1)
          // Returned as-is: a null shared_ptr (test_func1() not yet called)
          // becomes None instead of dereferencing a null pointer.
          .def("test_func2", &TestLoader::test_func2);
    }