I have a model that uses a custom LambdaLayer
as follows:
from pytorch_lightning import LightningModule
import torch


class LambdaLayer(LightningModule):
    def __init__(self, fun):
        super(LambdaLayer, self).__init__()
        self.fun = fun

    def forward(self, x):
        return self.fun(x)
class TorchCatEmbedding(LightningModule):
    def __init__(self, start, end):
        super(TorchCatEmbedding, self).__init__()
        self.lb = LambdaLayer(lambda x: x[:, start:end])
        self.embedding = torch.nn.Embedding(50, 5)

    def forward(self, inputs):
        o = self.lb(inputs).to(torch.int32)
        o = self.embedding(o)
        return o.squeeze()
The model runs perfectly fine on CPU or 1 GPU. However, when running it with PyTorch Lightning over 2+ GPUs, this error happens:
AttributeError: Can't pickle local object 'TorchCatEmbedding.__init__.<locals>.<lambda>'
The purpose of using a lambda function here is that, given an inputs tensor, I want to pass only inputs[:, start:end] to the embedding layer.
My question: how can I work around this pickling error while still slicing the inputs before the embedding layer?
So the problem isn't the lambda function per se; it's that pickle only works with module-level functions, because it serializes a function simply as a reference to a module-level name. That also means, unfortunately, that if you need to capture the start and end arguments you can't use a closure either, even though that's what you'd normally reach for:
def function_maker(start, end):
    def function(x):
        return x[:, start:end]
    return function
But this will get you right back to where you started, as far as the pickling problem is concerned.
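You can reproduce the failure with nothing but the standard library; here's a minimal sketch (the 0 and 3 bounds are just placeholder values):

import pickle

def function_maker(start, end):
    def function(x):
        return x[:, start:end]
    return function

# The inner function is a "local object", so the standard pickle module
# rejects it, just like the lambda in the question:
try:
    pickle.dumps(function_maker(0, 3))
except AttributeError as e:
    print(e)  # Can't pickle local object 'function_maker.<locals>.function'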
So, try something like:
class Slicer:
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __call__(self, x):
        return x[:, self.start:self.end]
Then you can use:
LambdaLayer(Slicer(start, end))
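In the context of your model, that would look roughly like this (a sketch; I haven't run it under Lightning's multi-GPU spawning myself):

class TorchCatEmbedding(LightningModule):
    def __init__(self, start, end):
        super(TorchCatEmbedding, self).__init__()
        # Slicer is defined at module level, so its instances pickle cleanly,
        # unlike the original lambda closure.
        self.lb = LambdaLayer(Slicer(start, end))
        self.embedding = torch.nn.Embedding(50, 5)

    def forward(self, inputs):
        o = self.lb(inputs).to(torch.int32)
        o = self.embedding(o)
        return o.squeeze()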
I'm not familiar with PyTorch, though I'm surprised it doesn't offer a way to swap in a different serialization backend. The pathos/dill project, for example, can pickle arbitrary functions, and it's often easier to just use that. But I believe the above should solve the problem.
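To illustrate that last point, here's what dill does with the kind of function the standard pickle rejects (just a demonstration of dill itself, not necessarily something Lightning will use automatically when spawning workers):

import pickle
import dill

fn = lambda x: x[:, 0:3]  # throwaway lambda, standing in for the one in the model

# The standard pickle module refuses it ...
try:
    pickle.dumps(fn)
except (pickle.PicklingError, AttributeError) as e:
    print("pickle failed:", e)

# ... while dill serializes the function by value and restores it.
restored = dill.loads(dill.dumps(fn))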