python pickle wrapper python-decorators nested-function

Alternative to nested functions for pickling


I have a piece of code that builds a function out of many smaller functions, with the outermost one accepting an argument x.

In other words, I have an input x and I need to apply various transformations to it that are decided at runtime.

This is done by repeatedly calling the function below, which wraps one function in another.

Here is the function:

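import typing
from typing import Union

def build_layer(curr_layer: typing.Callable, prev_layer: Union[typing.Callable, int]) -> typing.Callable:

    def _function(x):
        # A non-callable prev_layer is a placeholder meaning "use the raw input x".
        return curr_layer(prev_layer(x) if callable(prev_layer) else x)

    return _function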

Side note: when prev_layer is not callable, the input x is used in its place, so I pass dummy integers to mark where the input goes.

The problem: the returned function cannot be pickled, and I have not been able to figure out how to rewrite this code so that it is picklable.
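A minimal sketch that reproduces the failure, assuming the standard-library pickle module. The helper try_pickle and the use of abs as a stand-in layer are illustrative and not part of the original code. pickle stores functions by reference to their module and qualified name, so a function defined inside another function cannot be looked up again and is rejected.

```python
import pickle
from typing import Callable, Union


def build_layer(curr_layer: Callable, prev_layer: Union[Callable, int]) -> Callable:
    def _function(x):
        return curr_layer(prev_layer(x) if callable(prev_layer) else x)
    return _function


def try_pickle(obj) -> bool:
    """Hypothetical helper: True if obj survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        # The exact exception type for local objects differs between Python versions.
        return False


layer = build_layer(abs, 0)     # abs stands in for a real layer; 0 marks the input
nested_ok = try_pickle(layer)   # False: _function is local to build_layer
builtin_ok = try_pickle(abs)    # True: built-ins are pickled by reference
```

The closure itself works when called; only serialization fails, which is why the problem shows up when saving to disk or passing the object between processes.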

Note: I need this object to be persisted on disk, but it is also used in multiprocessing, where it is pickled for IPC (these functions are not used there, so technically they could be moved).

I also have a more complex version of this function that handles multiple inputs (using a fixed aggregation function, in this case torch.cat). I know the two can be merged into one generic function, and I will do that once I get it to work.

Here is the code for the second function:

import typing
import torch

def build_layer_multi_input(curr_layer: typing.Callable, prev_layers: list) -> typing.Callable:

    def _function(x):
        # Each non-callable entry in prev_layers is replaced by the raw input x.
        return curr_layer(torch.cat([layer(x) if callable(layer) else x for layer in prev_layers]))

    return _function

Solution

  • I resolved this by attaching the return value of these functions to a class instance as described in this thread.
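The linked thread is not reproduced here, so the following is only one possible reading of the class-based approach, not necessarily the exact code used. It replaces the nested function with a module-level class whose instances are callable. The names Layer, double, and add_one are hypothetical. It assumes every wrapped layer is itself picklable, for example a module-level function.

```python
import pickle
from typing import Callable, Union


class Layer:
    """Picklable stand-in for the closure returned by build_layer."""

    def __init__(self, curr_layer: Callable, prev_layer: Union[Callable, int]):
        self.curr_layer = curr_layer
        self.prev_layer = prev_layer

    def __call__(self, x):
        # Same logic as the nested _function: a non-callable prev_layer marks the input.
        inner = self.prev_layer(x) if callable(self.prev_layer) else x
        return self.curr_layer(inner)


def double(x):  # hypothetical module-level layer
    return x * 2


def add_one(x):  # hypothetical module-level layer
    return x + 1


# Build add_one(double(x)); the integer 0 marks where the input goes.
stack = Layer(add_one, Layer(double, 0))

# Round-trip through pickle, as happens when saving to disk or sending over IPC.
restored = pickle.loads(pickle.dumps(stack))
```

Because the state lives in instance attributes rather than in a closure, pickle only needs to find the Layer class by name and then restore its attributes. The multi-input variant could follow the same pattern by storing the prev_layers list on the instance and calling torch.cat inside __call__.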