
Force pybind11's stubgen to produce a Sequence type from an std::vector


In our project, we make heavy use of C++ code bound into Python, with stubgen generating the typed signatures of the functions. Many of our Python APIs take a Sequence[T], which is then passed to C++ pybind functions (usually taking a const std::vector<T>&).

E.g., from a class like this:

  py::class_<Dummy>(m, "Dummy")
      .def(py::init<const std::vector<std::string>&>(), py::arg("names"));

Stubgen then generates APIs like this:

class Dummy:
    def __init__(self, names: List[str]) -> None: ...

Is there any way to convince stubgen to use a Sequence[T] instead? Changing the C++ signature to match the requirements is possible as well.

Edit: Added the C++ pybind call to satisfy @DanMašek's request below.


Solution

  • From what I can gather, the stub generator parses the signatures from doc-strings that are by default automatically generated by Pybind11. So, we should focus on Pybind11 itself.
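
    To make the docstring point concrete, here is a minimal sketch of the kind of parsing a stub generator has to do. The docstring text is hard-coded (copied from the help output shown further down) so that the snippet runs without the compiled module, and parse_signature is a hypothetical helper, not part of pybind11 or of any stub generator:

```python
import re

# Docstring line as pybind11 generates it for the baseline module below.
# Hard-coded here (assumption) so the sketch runs without the so07 module.
DOC = "test_in(names: List[str]) -> int\n"

# Simplified pattern: name, argument list, return annotation.
SIGNATURE_RE = re.compile(r"^(?P<name>\w+)\((?P<args>.*)\)\s*->\s*(?P<ret>.+)$")


def parse_signature(doc: str) -> dict:
    """Hypothetical helper: split the first docstring line into its parts."""
    first_line = doc.strip().splitlines()[0]
    match = SIGNATURE_RE.match(first_line)
    if match is None:
        raise ValueError(f"no signature found in {first_line!r}")
    return match.groupdict()


sig = parse_signature(DOC)
print(sig["name"], "|", sig["args"], "|", sig["ret"])
```

    Whatever type name Pybind11 writes into that line (List[str] here) is what ends up in the stub, which is why the rest of this answer concentrates on changing the name at its source.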

    NB: I'm assuming that you're satisfied with how your current implementation behaves, and that all you care about is the type annotations. I'm using Pybind11 version 2.9.2 with Python 3.7.

    Baseline

    Based on the information you've provided, I've created the following minimal example module:

    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>
    #include <string>
    #include <vector>
    
    namespace py = pybind11;
    
    size_t test_in(std::vector<std::string> const& names)
    {
        return names.size();
    }
    
    std::vector<std::string> test_out(size_t count)
    {
        return std::vector<std::string>(count, "X");
    }
    
    PYBIND11_MODULE(so07, m)
    {
        m.def("test_in", test_in, py::arg("names"));
        m.def("test_out", test_out, py::arg("count"));
    }
    

    And checked it out using this simple Python script:

    import so07
    print(so07.test_in(["a","b","c"]))
    print(so07.test_out(3))
    print(type(so07.test_out(3)))
    help(so07.test_in)
    help(so07.test_out)
    

    The script produces the following output:

    3
    ['X', 'X', 'X']
    <class 'list'>
    Help on built-in function test_in in module so07:
    
    test_in(...) method of builtins.PyCapsule instance
        test_in(names: List[str]) -> int
    
    Help on built-in function test_out in module so07:
    
    test_out(...) method of builtins.PyCapsule instance
        test_out(count: int) -> List[str]
    

    Finally, the following type annotations are produced by running the stub generator on the module:

    from __future__ import annotations
    __all__ = ['test_in', 'test_out']
    def test_in(names: list[str]) -> int:
        ...
    def test_out(count: int) -> list[str]:
        ...
    

    Investigation

    Based on the output of help shown earlier, the type annotations come from Pybind11 itself. The specific names used for each argument or result are part of the type caster used for the specific C++ type. In this case, it's this line from pybind11/stl.h:

    PYBIND11_TYPE_CASTER(Type, const_name("List[") + value_conv::name + const_name("]"));
    

    The way the PYBIND11_TYPE_CASTER macro is defined, the name becomes a static constexpr member variable of the type caster, so there doesn't seem to be any way to change it without replacing the caster itself.

    Solution

    Since the pybind11/stl.h header is only included explicitly when you want automatic conversion of the common STL collections, the most practical solution seems to be to not use that header, and instead provide your own type casters. This is rather easy, since you just have to swipe the relevant pieces of the pybind11/stl.h header and change that string to whatever you want. Note that the two can't be mixed for the same type: including pybind11/stl.h alongside your own specialization for std::vector would define type_caster<std::vector<...>> twice.

    For example, to support std::vectors, I based this sequence_caster on the original list_caster provided by Pybind11:

    #include <pybind11/pybind11.h>
    
    #include <string>
    #include <type_traits>
    #include <utility>
    #include <vector>
    
    namespace py = pybind11;
    
    /* ============= BEGIN CONTENT SWIPED FROM pybind11/stl.h =================== */
    
    namespace PYBIND11_NAMESPACE {
    namespace detail {
    
    /// Extracts a const lvalue reference or rvalue reference for U based on the type of T (e.g. for
    /// forwarding a container element).  Typically used indirectly via forward_like(), below.
    template <typename T, typename U>
    using forwarded_type = conditional_t<std::is_lvalue_reference<T>::value,
        remove_reference_t<U> &,
        remove_reference_t<U> &&>;
    
    /// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically
    /// used for forwarding a container's elements.
    template <typename T, typename U>
    forwarded_type<T, U> forward_like(U &&u)
    {
        return std::forward<detail::forwarded_type<T, U>>(std::forward<U>(u));
    }
    
    template <typename Type, typename Value>
    struct sequence_caster
    {
        using value_conv = make_caster<Value>;
    
        bool load(handle src, bool convert)
        {
            if (!isinstance<sequence>(src) || isinstance<bytes>(src) || isinstance<str>(src)) {
                return false;
            }
            auto s = reinterpret_borrow<sequence>(src);
            value.clear();
            reserve_maybe(s, &value);
            for (auto it : s) {
                value_conv conv;
                if (!conv.load(it, convert)) {
                    return false;
                }
                value.push_back(cast_op<Value &&>(std::move(conv)));
            }
            return true;
        }
    
    private:
        template <
            typename T = Type,
            enable_if_t<std::is_same<decltype(std::declval<T>().reserve(0)), void>::value, int> = 0>
        void reserve_maybe(const sequence &s, Type *)
        {
            value.reserve(s.size());
        }
        void reserve_maybe(const sequence &, void *) {}
    
    public:
        template <typename T>
        static handle cast(T &&src, return_value_policy policy, handle parent)
        {
            if (!std::is_lvalue_reference<T>::value) {
                policy = return_value_policy_override<Value>::policy(policy);
            }
            list l(src.size());
            ssize_t index = 0;
            for (auto &&value : src) {
                auto value_ = reinterpret_steal<object>(
                    value_conv::cast(forward_like<T>(value), policy, parent));
                if (!value_) {
                    return handle();
                }
                PyList_SET_ITEM(l.ptr(), index++, value_.release().ptr()); // steals a reference
            }
            return l.release();
        }
    
        PYBIND11_TYPE_CASTER(Type, const_name("Sequence[") + value_conv::name + const_name("]"));
    };
    
    template <typename Type, typename Alloc>
    struct type_caster<std::vector<Type, Alloc>>
        : sequence_caster<std::vector<Type, Alloc>, Type> {};
    
    }} // namespace PYBIND11_NAMESPACE::detail
    
    /* ============= END CONTENT SWIPED FROM pybind11/stl.h =================== */
    
    size_t test_in(std::vector<std::string> const& names)
    {
        return names.size();
    }
    
    std::vector<std::string> test_out(size_t count)
    {
        return std::vector<std::string>(count, "X");
    }
    
    PYBIND11_MODULE(so07, m)
    {
        m.def("test_in", test_in, py::arg("names"));
        m.def("test_out", test_out, py::arg("count"));
    }
    

    Now, the Python test script I used earlier produces this output:

    3
    ['X', 'X', 'X']
    <class 'list'>
    Help on built-in function test_in in module so07:
    
    test_in(...) method of builtins.PyCapsule instance
        test_in(names: Sequence[str]) -> int
    
    Help on built-in function test_out in module so07:
    
    test_out(...) method of builtins.PyCapsule instance
        test_out(count: int) -> Sequence[str]
    

    Notice how the help now shows Sequence[str] instead of List[str].

    Furthermore, the following annotations are generated by stubgen:

    from __future__ import annotations
    import typing
    __all__ = ['test_in', 'test_out']
    def test_in(names: typing.Sequence[str]) -> int:
        ...
    def test_out(count: int) -> typing.Sequence[str]:
        ...
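
    As a rough illustration of what the Sequence annotation promises to callers, the top-level check in load() above can be approximated in pure Python. This is only a sketch: accepted_by_sequence_caster is a hypothetical helper, and collections.abc.Sequence is an assumed stand-in for pybind11's isinstance<sequence>, which relies on PySequence_Check and can be more permissive:

```python
from collections.abc import Sequence


def accepted_by_sequence_caster(obj) -> bool:
    """Hypothetical helper approximating the first check in load()."""
    # Assumption: collections.abc.Sequence stands in for pybind11's
    # isinstance<sequence>; the real check uses PySequence_Check.
    if not isinstance(obj, Sequence) or isinstance(obj, (bytes, str)):
        return False
    return True


# Lists and tuples pass; str, bytes and sets are rejected.
for candidate in (["a", "b"], ("a", "b"), "ab", b"ab", {"a", "b"}):
    print(type(candidate).__name__, accepted_by_sequence_caster(candidate))
```

    Under these assumptions a tuple is accepted just like a list, which matches the Sequence[str] annotation better than List[str] did. Only the top-level container check is modelled here; element conversion is left to the value caster.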