Tags: pandas, dask, pybind11

"cannot pickle 'PyCapsule' object" error when using a pybind11 function with dask


Consider the following example, adapted from the pybind11 example project (https://github.com/pybind/python_example):

The setup.py is:

import sys

# Available at setup time due to pyproject.toml
from pybind11 import get_cmake_dir
from pybind11.setup_helpers import Pybind11Extension, build_ext
from setuptools import setup

__version__ = "0.0.1"

# The main interface is through Pybind11Extension.
# * You can add cxx_std=11/14/17, and then build_ext can be removed.
# * You can set include_pybind11=false to add the include directory yourself,
#   say from a submodule.
#
# Note:
#   Sort input source files if you glob sources to ensure bit-for-bit
#   reproducible builds (https://github.com/pybind/python_example/pull/53)

ext_modules = [
    Pybind11Extension("python_example",
        ["src/main.cpp"],
        # Example: passing in the version to the compiled code
        define_macros = [('VERSION_INFO', __version__)],
        ),
]

setup(
    name="python_example",
    version=__version__,
    author="Sylvain Corlay",
    author_email="sylvain.corlay@gmail.com",
    url="https://github.com/pybind/python_example",
    description="A test project using pybind11",
    long_description="",
    ext_modules=ext_modules,
    extras_require={"test": "pytest"},
    # Currently, build_ext only provides an optional "highest supported C++
    # level" feature, but in the future it may provide more features.
    cmdclass={"build_ext": build_ext},
    zip_safe=False,
    python_requires=">=3.7",
)

The C++ part (src/main.cpp) is:

#include <pybind11/pybind11.h>

#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)

int add(int i, int j) {
    return i + j;
}

namespace py = pybind11;

PYBIND11_MODULE(python_example, m) {
    m.doc() = R"pbdoc(
        Pybind11 example plugin
        -----------------------

        .. currentmodule:: python_example

        .. autosummary::
        :toctree: _generate

        add
        subtract
    )pbdoc";

    m.def("add", &add, R"pbdoc(
        Add two numbers

        Some other explanation about the add function.
    )pbdoc");

    m.def("subtract", [](int i, int j) { return i - j; }, R"pbdoc(
        Subtract two numbers

        Some other explanation about the subtract function.
    )pbdoc");

#ifdef VERSION_INFO
    m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
#else
    m.attr("__version__") = "dev";
#endif
}

And the Python code that I want to run (example.py) is:

import numpy as np
import pandas as pd
import dask.dataframe as dd
from dask.diagnostics import ProgressBar
from python_example import add


def python_add(i: int, j: int) -> int:
    return i + j


def add_column_values_python(row: pd.Series) -> pd.Series:
    row['sum'] = python_add(row['i'], row['j'])
    return row


def add_column_values(row: pd.Series) -> pd.Series:
    row['sum'] = add(int(row['i']), int(row['j']))
    return row


def main():
    dataframe = pd.read_csv('./example.csv', index_col=[])
    dataframe['sum'] = np.nan

    with ProgressBar():
        d_dataframe = dd.from_pandas(dataframe, npartitions=16)
        dataframe = d_dataframe.map_partitions(
            lambda df: df.apply(add_column_values_python, axis=1)).compute(scheduler='processes')

    with ProgressBar():
        d_dataframe = dd.from_pandas(dataframe, npartitions=16)
        # apply(..., axis=1) returns whole rows, so the output has the same columns as the input
        dataframe = d_dataframe.map_partitions(
            lambda df: df.apply(add_column_values, axis=1), meta=dataframe).compute(scheduler='processes')


if __name__ == '__main__':
    main()

And the example.csv file looks like this:

i,j
1,2
3,4
5,6
7,8
9,10

But when I run this code I get the following error when using the C++ add version:

[########################################] | 100% Completed | 1.24 ss
[                                        ] | 0% Completed | 104.05 ms
Traceback (most recent call last):
  File "/Users/user/local/src/python_example/example.py", line 38, in <module>
    main()
  File "/Users/user/local/src/python_example/example.py", line 33, in main
    dataframe = d_dataframe.map_partitions(
  File "/Users/user/local/src/python_example/.venv/lib/python3.9/site-packages/dask/base.py", line 314, in compute
    (result,) = compute(self, traverse=False, **kwargs)
  File "/Users/user/local/src/python_example/.venv/lib/python3.9/site-packages/dask/base.py", line 599, in compute
    results = schedule(dsk, keys, **kwargs)
  File "/Users/user/local/src/python_example/.venv/lib/python3.9/site-packages/dask/multiprocessing.py", line 233, in get
    result = get_async(
  File "/Users/user/local/src/python_example/.venv/lib/python3.9/site-packages/dask/local.py", line 499, in get_async
    fire_tasks(chunksize)
  File "/Users/user/local/src/python_example/.venv/lib/python3.9/site-packages/dask/local.py", line 481, in fire_tasks
    dumps((dsk[key], data)),
  File "/Users/user/local/src/python_example/.venv/lib/python3.9/site-packages/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/Users/user/local/src/python_example/.venv/lib/python3.9/site-packages/cloudpickle/cloudpickle_fast.py", line 632, in dump
    return Pickler.dump(self, obj)
TypeError: cannot pickle 'PyCapsule' object

Is there a way to solve that, maybe by defining something in the C++ module definition?

Note that this example is only to illustrate the problem.


Solution

  • General

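    If you want to pass an object from one process to another in Python (with or without dask), you need a way to serialise it. The default method for this is pickle. Objects that live in C/C++ libraries are typically built around raw pointers that only mean something inside the process that created them, so pickle does not know what to do with them. In this example the culprit is the PyCapsule that pybind11 attaches to every function it exports (it is the function's __self__): because add_column_values is defined in __main__, cloudpickle serialises it by value, together with the global add that it uses. For your own C++ classes you can implement the pickle protocol by providing __getstate__/__setstate__ or __reduce__ methods (pybind11 provides py::pickle() to define the first pair from C++).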
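A minimal sketch of the __reduce__ route, assuming the extension module is importable in every worker process. ImportedFunction is a hypothetical helper (not part of pybind11 or dask) that stands in for the pybind11 function: only the module and attribute names are pickled, and the receiving process imports the real function again. math.hypot is used only so the snippet runs without the compiled python_example module.

```python
import importlib
import pickle


class ImportedFunction:
    """Hypothetical picklable proxy for a function that lives in an importable module.

    Only the module and attribute names are pickled; the real function (which
    may be unpicklable, e.g. one generated by pybind11) is looked up again on load.
    """

    def __init__(self, module_name: str, attr_name: str):
        self.module_name = module_name
        self.attr_name = attr_name
        module = importlib.import_module(module_name)
        self._func = getattr(module, attr_name)

    def __call__(self, *args, **kwargs):
        return self._func(*args, **kwargs)

    def __reduce__(self):
        # Unpickling calls ImportedFunction(module_name, attr_name) in the
        # receiving process instead of trying to pickle self._func.
        return (ImportedFunction, (self.module_name, self.attr_name))


# With the compiled module this would be ImportedFunction("python_example", "add").
hypot = ImportedFunction("math", "hypot")
restored = pickle.loads(pickle.dumps(hypot))
print(restored(3, 4))  # 5.0
```

The pickled payload is just the class plus two strings, so the PyCapsule never has to be serialised; the trade-off is that every process must be able to import the module.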

    Alternatively, dask has its own serialisation layer in which you can register dedicated serialisation/deserialisation functions for specific classes, but it is only used by the distributed scheduler, not by the multiprocessing scheduler (the former is better in every way; there is no good reason to use multiprocessing).
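A minimal sketch of such a registration, assuming the distributed package is installed. NativeAdder is a hypothetical class, and the threading.Lock it holds is only a stand-in for an unpicklable C++ handle. The serialize/deserialize calls at the end are made by hand purely to show the round trip; in a real cluster the distributed scheduler decides when to use the registered functions.

```python
import threading

from distributed.protocol import dask_deserialize, dask_serialize, deserialize, serialize


class NativeAdder:
    """Hypothetical wrapper around a native resource that pickle cannot handle.

    The threading.Lock is only a stand-in for a C/C++ handle: like a
    PyCapsule, it raises TypeError when pickled.
    """

    def __init__(self, offset: int = 0):
        self.offset = offset
        self._handle = threading.Lock()

    def add(self, i: int, j: int) -> int:
        with self._handle:
            return i + j + self.offset


@dask_serialize.register(NativeAdder)
def serialize_native_adder(obj: NativeAdder):
    # Ship only plain configuration data; the native handle stays behind.
    header = {}
    frames = [str(obj.offset).encode()]
    return header, frames


@dask_deserialize.register(NativeAdder)
def deserialize_native_adder(header, frames) -> NativeAdder:
    # Build a fresh object, and therefore a fresh handle, on the receiving side.
    return NativeAdder(int(bytes(frames[0]).decode()))


# Manual round trip through the "dask" serialisation family.
header, frames = serialize(NativeAdder(offset=10), serializers=["dask"])
restored = deserialize(header, frames)
print(restored.add(1, 2))  # 13
```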

  • Specific

    A couple of simpler options:

    • use the threaded scheduler (scheduler='threads'), so that no serialisation is needed (for full parallelism the C++ code ought to release the GIL, e.g. by passing py::call_guard<py::gil_scoped_release>() to m.def)
    • I think only the add function is the problem, so it is probably enough to move your import into the add_column_values function: each worker then imports its own copy of add instead of receiving it inside the pickled function.