
pybind11 with "auto" keyword for overloaded function


I would like help wrapping an overloaded member function whose return type is declared with "auto". For example, the two scale() functions on lines 699 and 708 of https://github.com/microsoft/SEAL/blob/master/native/src/seal/ciphertext.h:

        SEAL_NODISCARD inline auto &scale() noexcept
        {
            return scale_;
        }
        SEAL_NODISCARD inline auto &scale() const noexcept
        {
            return scale_;
        }

When I try to bind as follows,

    py::class_<Ciphertext>(m, "Ciphertext")
        .def("scale", (auto (Ciphertext::*)() const)&Ciphertext::scale, "returns a constant reference to the scale")

I get this error:

    ...mseal.cpp:223:18: error: invalid use of ‘auto’
       .def("scale", (auto (Ciphertext::*)() const)&Ciphertext::scale, "returns a constant reference to the scale")

I am using C++17 and Python 3, and I don't want to modify the SEAL C++ library. Thank you.


Solution

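EDIT: I just found out that pybind11 has a helper, py::overload_cast, that does the same thing and leads to much simpler and cleaner code. It takes the argument types as template parameters (none here), and the py::const_ tag selects the const overload; it requires a C++14-compatible compiler. Replace the PYBIND11_MODULE block in the original answer with: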

    PYBIND11_MODULE(mseal, m) {
        py::class_<Ciphertext>(m, "Ciphertext")
            .def(py::init<>())
            .def("scale",
                 py::overload_cast<>(&Ciphertext::scale, py::const_),
                 "returns a constant reference to the scale")
            .def("scale_nonconst",
                 py::overload_cast<>(&Ciphertext::scale),
                 "returns a reference to the scale");
    }
    

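Original answer: You need decltype to get at the deduced return type, and std::declval to disambiguate the overload inside decltype: calling scale() on std::declval<Ciphertext const&>() picks the const overload, while std::declval<Ciphertext&>() picks the non-const one. A full working example follows (C++14 minimum, because deduced auto return types require C++14), where I've also bound the non-const version to show that you have full control over which overload is selected: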

    #include <pybind11/pybind11.h>
    #include <pybind11/pytypes.h>
    
    #include <iostream>
    #include <utility>
    
    namespace py = pybind11;
    
    class Ciphertext {
    public:
        inline auto &scale() noexcept {
            std::cerr << "non-const called" << std::endl;
            return scale_;
        }
        inline auto &scale() const noexcept {
            std::cerr << "const called" << std::endl;
            return scale_;
        }
    
    private:
        int scale_ = 0;  // initialized so that reading it from Python is well-defined
    };
    
    
    PYBIND11_MODULE(mseal, m) {
        py::class_<Ciphertext>(m, "Ciphertext")
            .def(py::init<>())
            .def("scale",
                 static_cast<decltype(std::declval<Ciphertext const&>().scale()) (Ciphertext::*)() const>(&Ciphertext::scale),
                 "returns a constant reference to the scale")
            .def("scale_nonconst",
                 static_cast<decltype(std::declval<Ciphertext&>().scale()) (Ciphertext::*)()>(&Ciphertext::scale),
                 "returns a reference to the scale");
    }
    

When compiled into mseal.so, this works as expected:

    >>> import mseal
    >>> mseal.Ciphertext().scale()
    const called
    0
    >>> mseal.Ciphertext().scale_nonconst()
    non-const called
    0
    >>>