Search code examples
Tags: python, c++, mpi, pybind11

pybind11: send MPI communicator from Python to CPP


I have a C++ class that I want to use from a Python script run with mpi4py, so that each MPI process creates its own instance of the class. On the C++ side, I'm using the Open MPI library (installed via Homebrew) and pybind11.

The C++ class is as follows:

#include <pybind11/pybind11.h>
#include <iostream>
#include <chrono>
#include <thread>
#include <vector>
#include <mpi.h>
// #define PyMPI_HAVE_MPI_Message 1
// #include <mpi4py/mpi4py.h>


namespace py = pybind11;

class SomeComputation
{
    float multiplier;
    std::vector<int> test;
    MPI_Comm comm_;

public:
    void Init()
    {
        int rank;
        MPI_Comm_rank(comm_, &rank);
        test.clear();
        test.resize(10, rank);
    }

    void set_comm(MPI_Comm comm){
        this->comm_ = comm;
    }

    SomeComputation(float multiplier_) : multiplier(multiplier_){}
    ~SomeComputation() { std::cout << "Destructor Called!\n"; }


    float compute(float input)
    {
        std::this_thread::sleep_for(std::chrono::milliseconds((int)input * 10));
        for (int i = 0; i != 10; ++i)
        {
            std::cout << test[i] << " ";
        }
        std::cout << std::endl;
        return multiplier * input;
    }
};

// Python bindings: exposes SomeComputation to Python as
// module_name.Cpp_computation.
PYBIND11_MODULE(module_name, handle)
{
    py::class_<SomeComputation>(handle, "Cpp_computation")
        .def(py::init<float>()) // constructor argument types are given as template args
        // NOTE: this is the binding that fails to compile. With Open MPI,
        // MPI_Comm is `struct ompi_communicator_t *`, a pointer to an
        // incomplete type. pybind11's generic caster needs typeid() of the
        // pointee, which causes the "invalid use of incomplete type" errors.
        // Binding it needs a custom type_caster (or a wrapper type) for MPI_Comm.
        .def("set_comm", &SomeComputation::set_comm)  
        .def("compute", &SomeComputation::compute)
        .def("cpp_init", &SomeComputation::Init);
}

and here's the python interface spawning the same C++:

# Driver script, intended to be launched under MPI (mpirun/mpiexec) so that
# every process builds its own Cpp_computation and calls compute() 5 times.
from build.module_name import * 
import time

from mpi4py import MPI


comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # currently unused


# NOTE: the communicator is NOT passed to C++ here. That would be
# m.set_comm(comm), which needs a working MPI_Comm binding first.
m = Cpp_computation(44.0)
m.cpp_init()
i = 0
while i < 5:
    print(m.compute(i))  # C++ prints its buffer, then returns 44.0 * i
    time.sleep(1)
    i+=1

I've already tried the approach from "Sharing an MPI communicator using pybind11", but I'm stuck on a long, unhelpful error (full message):

[...]
/Users/purusharth/Documents/hiwi/pympicontroller/pybind11/include/pybind11/pybind11.h:1398:22:   required from 'pybind11::class_<type_, options>& pybind11::class_<type_, options>::def(const char*, Func&&, const Extra& ...) [with Func = void (SomeComputation::*)(ompi_communicator_t*); Extra = {}; type_ = SomeComputation; options = {}]'
/Users/purusharth/Documents/hiwi/pympicontroller/main.cpp:79:7:   required from here
/opt/homebrew/Cellar/gcc/11.2.0_3/include/c++/11/type_traits:1372:38: error: invalid use of incomplete type 'struct ompi_communicator_t'
 1372 |     : public integral_constant<bool, __is_base_of(_Base, _Derived)>
      |                                      ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from /Users/purusharth/Documents/hiwi/pympicontroller/main.cpp:6:
/opt/homebrew/Cellar/open-mpi/4.1.2/include/mpi.h:419:16: note: forward declaration of 'struct ompi_communicator_t'
  419 | typedef struct ompi_communicator_t *MPI_Comm;
      |                ^~~~~~~~~~~~~~~~~~~

[...]

/Users/purusharth/Documents/hiwi/pympicontroller/pybind11/include/pybind11/pybind11.h:1398:22:   required from 'pybind11::class_<type_, options>& pybind11::class_<type_, options>::def(const char*, Func&&, const Extra& ...) [with Func = void (SomeComputation::*)(ompi_communicator_t*); Extra = {}; type_ = SomeComputation; options = {}]'
/Users/purusharth/Documents/hiwi/pympicontroller/main.cpp:79:7:   required from here
/Users/purusharth/Documents/hiwi/pympicontroller/pybind11/include/pybind11/detail/descr.h:40:19: error: invalid use of incomplete type 'struct ompi_communicator_t'
   40 |         return {{&typeid(Ts)..., nullptr}};
      |                   ^~~~~~~~~~
In file included from /Users/purusharth/Documents/hiwi/pympicontroller/main.cpp:6:
/opt/homebrew/Cellar/open-mpi/4.1.2/include/mpi.h:419:16: note: forward declaration of 'struct ompi_communicator_t'
  419 | typedef struct ompi_communicator_t *MPI_Comm;
      |                ^~~~~~~~~~~~~~~~~~~

[...]

                 from /Users/purusharth/Documents/hiwi/pympicontroller/main.cpp:1:
/Users/purusharth/Documents/hiwi/pympicontroller/pybind11/include/pybind11/detail/descr.h:40:42: error: could not convert '{{<expression error>, nullptr}}' from '<brace-enclosed initializer list>' to 'std::array<const std::type_info*, 3>'
   40 |         return {{&typeid(Ts)..., nullptr}};
      |                                          ^
      |                                          |
      |                                          <brace-enclosed initializer list>

[...]

In file included from /Users/purusharth/Documents/hiwi/pympicontroller/main.cpp:1:
/Users/purusharth/Documents/hiwi/pympicontroller/pybind11/include/pybind11/pybind11.h: In instantiation of 'void pybind11::cpp_function::initialize(Func&&, Return (*)(Args ...), const Extra& ...) [with Func = pybind11::cpp_function::cpp_function<void, SomeComputation, ompi_communicator_t*, pybind11::name, pybind11::is_method, pybind11::sibling>(void (SomeComputation::*)(ompi_communicator_t*), const pybind11::name&, const pybind11::is_method&, const pybind11::sibling&)::<lambda(SomeComputation*, ompi_communicator_t*)>; Return = void; Args = {SomeComputation*, ompi_communicator_t*}; Extra = {pybind11::name, pybind11::is_method, pybind11::sibling}]':
[..]
/Users/purusharth/Documents/hiwi/pympicontroller/pybind11/include/pybind11/pybind11.h:1398:22:   required from 'pybind11::class_<type_, options>& pybind11::class_<type_, options>::def(const char*, Func&&, const Extra& ...) [with Func = void (SomeComputation::*)(ompi_communicator_t*); Extra = {}; type_ = SomeComputation; options = {}]'
/Users/purusharth/Documents/hiwi/pympicontroller/main.cpp:79:7:   required from here
/Users/purusharth/Documents/hiwi/pympicontroller/pybind11/include/pybind11/pybind11.h:266:73:   in 'constexpr' expansion of 'pybind11::detail::descr<18, SomeComputation, ompi_communicator_t>::types()'
/Users/purusharth/Documents/hiwi/pympicontroller/pybind11/include/pybind11/pybind11.h:266:39: error: 'constexpr' call flows off the end of the function
  266 |         PYBIND11_DESCR_CONSTEXPR auto types = decltype(signature)::types();
      |                                       ^~~~~

The error points to the line `.def("set_comm", &SomeComputation::set_comm)`.

What is the cause of these errors, and how should they be resolved?

UPDATE: I added an answer below that uses a custom type caster, as explained in this answer. But is that the only way to go about it?


Solution

  • Based on this answer: https://stackoverflow.com/a/62449190/4593199

    I was able to pass the MPI communicator from Python to C++ by creating a custom pybind11 type caster for it.

    // pybind11 (and therefore Python.h) must come before any standard header.
    #include <pybind11/pybind11.h>
    #include <mpi.h>
    #include <mpi4py/mpi4py.h>
    
    #include <iostream>
    #include <stdexcept>
    #include <vector>
    
    namespace py = pybind11;
    
    struct mpi4py_comm {
      mpi4py_comm() = default;
      mpi4py_comm(MPI_Comm value) : value(value) {}
      operator MPI_Comm () { return value; }
    
      MPI_Comm value;
    };
    
    
    namespace pybind11 { namespace detail {
      template <> struct type_caster<mpi4py_comm> {
        public:
          PYBIND11_TYPE_CASTER(mpi4py_comm, _("mpi4py_comm"));
    
          // Python -> C++
          bool load(handle src, bool) {
            PyObject *py_src = src.ptr();
    
            // Check that we have been passed an mpi4py communicator
            if (PyObject_TypeCheck(py_src, &PyMPIComm_Type)) {
              // Convert to regular MPI communicator
              value.value = *PyMPIComm_Get(py_src);
            } else {
              return false;
            }
    
            return !PyErr_Occurred();
          }
    
          // C++ -> Python
          static handle cast(mpi4py_comm src,
                             return_value_policy /* policy */,
                             handle /* parent */)
          {
            // Create an mpi4py handle
            return PyMPIComm_New(src.value);
          }
      };
    }} // namespace pybind11::detail
    
    
    // Receive a communicator from Python and print this process's rank in it,
    // repeated ten times on one line.
    void print_comm(mpi4py_comm comm)
    {
        int rank = 0;
        // Query the communicator actually passed in (it converts implicitly
        // to MPI_Comm) rather than a hard-coded MPI_COMM_WORLD, so
        // sub-communicators report their own ranks.
        MPI_Comm_rank(comm, &rank);
    
        const std::vector<int> test(10, rank);
        for (int value : test) {
            std::cout << value << " ";
        }
        std::cout << std::endl;
    }
    
    
    class SomeComputation
    {
        float multiplier;
        std::vector<int> test;
        MPI_Comm comm_;
    
    public:
        void Init()
        {
            int rank;
            MPI_Comm_rank(comm_, &rank);
            test.clear();
            test.resize(10, rank);
        }
        SomeComputation(float multiplier_) : multiplier(multiplier_){}
        ~SomeComputation() { std::cout << "Destructor Called!\n"; }
    
        void set_comm(mpi4py_comm comm){
            this->comm_ = comm;
        }
    
        float compute(float input)
        {
            // std::this_thread::sleep_for(std::chrono::milliseconds((int)input * 10));
            for (int i = 0; i != 10; ++i)
            {
                std::cout << test[i] << " ";
            }
            std::cout << std::endl;
            return multiplier * input;
        }
    };
    
    
    // Demonstration helper: hands MPI_COMM_WORLD back to Python wrapped in an
    // mpi4py_comm, which the custom type caster turns into an mpi4py object.
    mpi4py_comm get_comm()
    {
        const MPI_Comm world = MPI_COMM_WORLD;
        return mpi4py_comm(world);
    }
    
    /// Python module "native": exposes the communicator round-trip helpers and
    /// the Cpp_computation class.
    PYBIND11_MODULE(native, m)
    {
      // Import the mpi4py C API; every mpi4py_comm conversion depends on it.
      // On failure import_mpi4py() returns < 0 with a Python exception set.
      // Rethrow that exception as-is so the real cause (e.g. mpi4py missing)
      // reaches the user instead of being replaced by a generic message.
      if (import_mpi4py() < 0) {
        throw py::error_already_set();
      }
    
      // Register the test functions.
      m.def("print_comm", &print_comm, "Do something with the mpi4py communicator.");
      m.def("get_comm", &get_comm, "Return some communicator.");
    
      py::class_<SomeComputation>(m, "Cpp_computation")
          .def(py::init<float>()) // constructor argument types are given as template args
          .def("set_comm", &SomeComputation::set_comm)
          .def("compute", &SomeComputation::compute)
          .def("cpp_init", &SomeComputation::Init);
    }
    

    This compiled and ran successfully. However, is there a more elegant way to go about it?