Tags: python, lambda, pytorch, pickle

PyTorch can't pickle lambda


I have a model that uses a custom LambdaLayer as follows:

import torch
from pytorch_lightning import LightningModule


class LambdaLayer(LightningModule):
    def __init__(self, fun):
        super(LambdaLayer, self).__init__()
        self.fun = fun

    def forward(self, x):
        return self.fun(x)


class TorchCatEmbedding(LightningModule):
    def __init__(self, start, end):
        super(TorchCatEmbedding, self).__init__()
        self.lb = LambdaLayer(lambda x: x[:, start:end])
        self.embedding = torch.nn.Embedding(50, 5)

    def forward(self, inputs):
        o = self.lb(inputs).to(torch.int32)
        o = self.embedding(o)
        return o.squeeze()

The model runs perfectly fine on CPU or 1 GPU. However, when running it with PyTorch Lightning over 2+ GPUs, this error happens:

AttributeError: Can't pickle local object 'TorchCatEmbedding.__init__.<locals>.<lambda>'

The purpose of using a lambda function here is that given an inputs tensor, I want to pass only inputs[:, start:end] to the embedding layer.

My questions:

  • is there an alternative to using a lambda in this case?
  • if not, what should be done to get the lambda function to work in this context?

Solution

  • So the problem isn't the lambda function per se; it's that pickle only works with functions defined at module level (pickle serializes a function as a reference to its module-level name). Unfortunately, that means you can't capture the start and end arguments with a closure either, even though that's what you'd normally reach for:

    def function_maker(start, end):
        def function(x):
            return x[:, start:end]
        return function
    

    But this will get you right back to where you started, as far as the pickling problem is concerned.
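A minimal reproduction of that failure mode (the names here are illustrative, not from the original post): pickling the inner function fails because pickle can't resolve it as a module-level attribute.

```python
import pickle


def function_maker(start, end):
    # The inner function is a local object, not a module-level name,
    # so pickle cannot serialize it.
    def function(x):
        return x[:, start:end]
    return function


f = function_maker(1, 3)
try:
    pickle.dumps(f)
except (AttributeError, pickle.PicklingError) as e:
    print(e)  # e.g. Can't pickle local object 'function_maker.<locals>.function'
```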

    So, try something like:

    class Slicer:
        def __init__(self, start, end):
            self.start = start
            self.end = end
        def __call__(self, x):
        def __call__(self, x):
            return x[:, self.start:self.end]
    

    Then you can use:

    LambdaLayer(Slicer(start, end))
    
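A quick sanity check of why this works: a `Slicer` instance is an ordinary object with plain attributes, so it round-trips through pickle and still slices correctly afterwards (NumPy is used here only to stand in for a tensor):

```python
import pickle

import numpy as np


class Slicer:
    """Picklable replacement for lambda x: x[:, start:end]."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __call__(self, x):
        return x[:, self.start:self.end]


# Round-trip through pickle, then apply to a 2-D array.
sl = pickle.loads(pickle.dumps(Slicer(1, 3)))
x = np.arange(12).reshape(3, 4)
print(sl(x).shape)  # (3, 2)
```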

    I'm not deeply familiar with PyTorch, though I'm surprised it doesn't offer the ability to use a different serialization backend. The pathos/dill project can pickle arbitrary functions, for example, and it is often easier to just use that. But I believe the above should solve the problem.
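For completeness, a sketch of the dill route mentioned above. This assumes the third-party `dill` package is installed (`pip install dill`); unlike pickle, it can serialize lambdas and closures directly:

```python
import dill  # third-party: pip install dill

# dill serializes the lambda itself, not a module-level reference,
# so this succeeds where pickle.dumps would raise.
fn = lambda x: x[:, 1:3]
payload = dill.dumps(fn)
restored = dill.loads(payload)
print(callable(restored))  # True
```

Whether this helps in the multi-GPU case depends on whether the framework lets you swap the serialization backend; the `Slicer` class above avoids the question entirely.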