
Python: Pickle class object that has functions/callables as attributes


I have a custom class that inherits from functools.partial:

from functools import partial
from typing import Callable

class CustomPartial(partial):
    def __new__(cls, func_name: str, func: Callable, *args, **kwargs):
        self = super(CustomPartial, cls).__new__(cls, func, *args, **kwargs)
        self.func_name = func_name
        return self

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)

This code is fine as is for serial processing, i.e. I can create objects of this class as needed and call them as normal functions.

The issue I'm running into, though, is when I try to use one of these CustomPartial objects as the function input for joblib Parallel processing. Based on the exception being thrown:

TypeError: CustomPartial.__new__() missing 1 required positional argument: 'func'

I've surmised that the issue occurs when the object is deserialized ("un-serialized") between processes.

The code below is a minimal working example of the issue. I've tried serializing with dill and implementing the __setstate__ / __getstate__ methods, but nothing seems to change the exception being thrown.

import dill
from typing import Callable
from functools import partial

class CustomPartial(partial):
    def __new__(cls, func_name: str, func: Callable, *args, **kwargs):
        self = super(CustomPartial, cls).__new__(cls, func, *args, **kwargs)
        self.func_name = func_name
        return self

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)

add = lambda x, y: x+y

add_ten = partial(add, y=10)
custom_partial = CustomPartial('add_ten', add_ten)

print(dill.loads(dill.dumps(add_ten)))
# functools.partial(<function <lambda> at 0x7f7647eefa30>, y=10)

try:
    print(dill.loads(dill.dumps(custom_partial)))
except TypeError as err:  # `Error` is not a defined name; the failure is a TypeError
    print(err)
    # CustomPartial.__new__() missing 1 required positional argument: 'func'

Any help / direction towards resolving this issue would be greatly appreciated :)


Solution

  • Edit: the solution is complex because partial uses __setstate__()

    I haven't tested it, but you probably need to override partial.__reduce__() in your CustomPartial class so that it matches your __new__() signature, which takes an extra argument.

    This is the partial.__reduce__() definition in Python 3.10:

    def __reduce__(self):
        return type(self), (self.func,), (self.func, self.args,
               self.keywords or None, self.__dict__ or None)
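
To see why that definition leads to the error in the question, here is a small sketch that inspects what the inherited __reduce__() hands back. It assumes operator.add as a stand-in for the question's lambda (so no dill is needed), and the variable names are only illustrative. The second item of the tuple holds a single argument, while CustomPartial.__new__() wants two:

```python
import operator
from functools import partial
from typing import Callable


class CustomPartial(partial):
    def __new__(cls, func_name: str, func: Callable, *args, **kwargs):
        self = super().__new__(cls, func, *args, **kwargs)
        self.func_name = func_name
        return self


custom_partial = CustomPartial("add_ten", operator.add, 10)

# The inherited partial.__reduce__() returns (callable, args, state).
cls, new_args, state = custom_partial.__reduce__()

# Only the wrapped function is passed back to the constructor.
print(new_args == (operator.add,))

# Unpickling effectively calls cls(*new_args), which is one argument short.
try:
    cls(*new_args)
    error_message = None
except TypeError as err:
    error_message = str(err)

print(error_message)
```

So the failure comes from the constructor call during unpickling, before __setstate__() is ever reached, which would explain why adding __getstate__ / __setstate__ alone did not change the exception.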
    

    You should include the extra argument/attribute in the second item of the returned tuple, which is passed as *args to __new__() when an object of this class is unpickled. Also, because partial uses __setstate__() to set its __dict__ attribute, you'll need to handle that, otherwise the func_name attribute will be erased. If you are on Python 3.8 or later and want to keep the original __setstate__() method, you can use the sixth field of the reduce value to pass a callable that controls how the state is applied.

    Try to add this to your class:

    def __reduce__(self):
        return (
            type(self),
            (self.func_name, self.func),
            (self.func, self.args, self.keywords or None, self.__dict__ or None),
            None,
            None,
            self._setstate
        )
    
    @staticmethod
    def _setstate(obj, state):
        func_name = obj.func_name
        obj.__setstate__(state)  # erases func_name
        obj.func_name = func_name
    

    Reference: https://docs.python.org/3/library/pickle.html#object.__reduce__
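
For completeness, here is how the pieces above might fit together in one runnable script. Treat it as a sketch under a few assumptions: it uses the standard pickle module with operator.add instead of dill with a lambda, it needs Python 3.8 or later for the sixth reduce field, and it leaves out the question's __call__ override, because that override forwards only the call-time arguments to self.func and would skip the bound 10.

```python
import operator
import pickle
from functools import partial
from typing import Callable


class CustomPartial(partial):
    def __new__(cls, func_name: str, func: Callable, *args, **kwargs):
        self = super().__new__(cls, func, *args, **kwargs)
        self.func_name = func_name
        return self

    def __reduce__(self):
        return (
            type(self),
            # Constructor arguments now match __new__(func_name, func).
            (self.func_name, self.func),
            # Same state tuple that partial.__setstate__() expects.
            (self.func, self.args, self.keywords or None, self.__dict__ or None),
            None,
            None,
            # Sixth field (Python 3.8+): callable invoked as setter(obj, state).
            self._setstate,
        )

    @staticmethod
    def _setstate(obj, state):
        func_name = obj.func_name
        obj.__setstate__(state)  # restores func, args, keywords and __dict__
        obj.func_name = func_name


if __name__ == "__main__":
    original = CustomPartial("add_ten", operator.add, 10)
    restored = pickle.loads(pickle.dumps(original))
    print(restored.func_name, restored.args, restored(5))
```

With dill the same __reduce__() should apply, since dill builds on the pickle protocol, but I have only reasoned through the standard pickle path here.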