I am using a function (more specifically `optim.lr_scheduler.LambdaLR` from torch) to which a lambda function is passed as a parameter:
```python
from torch import optim

lambda1 = lambda epoch: epoch / 10
# model is an existing torch.nn.Module defined elsewhere
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
```
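For context, `LambdaLR` sets each parameter group's learning rate to its initial lr multiplied by the value `lr_lambda(epoch)` returns. The sketch below mimics that rule in plain Python so it runs without torch; `simulate_lambda_lr` is a hypothetical helper for illustration, not part of PyTorch.

```python
def simulate_lambda_lr(initial_lr, lr_lambda, num_epochs):
    """Hypothetical helper: mimic LambdaLR's rule lr = initial_lr * lr_lambda(epoch)."""
    # LambdaLR calls lr_lambda with a single argument: the epoch counter
    return [initial_lr * lr_lambda(epoch) for epoch in range(num_epochs)]


lambda1 = lambda epoch: epoch / 10

# Learning rates for the first four epochs with initial lr 0.1
lrs = simulate_lambda_lr(0.1, lambda1, 4)
print(lrs)
```

Note that with this particular lambda the learning rate is 0 at epoch 0, since the factor is `0 / 10`.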
As long as the lambda function is simple, I can handle it fine. But I found myself needing a regular function instead, so I passed the function name as the parameter:
```python
def func1(epoch, base=10):
    return epoch / base

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=func1)
```
This does exactly the same thing as the lambda function above, but it can more easily be expanded into a complex function (as I needed it to be). The problem is that I do not know how to actually pass the extra parameter (`base` in this case) through the function that accepts the lambda (`optim.lr_scheduler.LambdaLR` in this case).
Is there a way to pass this parameter? In other words, can I pass `base` to `optim.lr_scheduler.LambdaLR` through `func1`?
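One standard way to fix `base` in advance is `functools.partial` from the Python standard library (a different technique from the factory used in the answer below); wrapping the call in a lambda works too. Both produce a callable that takes only `epoch`, which is what `LambdaLR` calls. The sketch checks the callables directly instead of building a scheduler, so it runs without torch.

```python
from functools import partial


def func1(epoch, base=10):
    return epoch / base


# Option 1: freeze base=20 with functools.partial; the result still takes epoch
func1_base20 = partial(func1, base=20)

# Option 2: the same binding written as a lambda wrapper
func1_base20_lambda = lambda epoch: func1(epoch, base=20)

# Either callable could then be passed as
# optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=func1_base20)
print(func1_base20(5), func1_base20_lambda(5))  # 0.25 0.25
```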
Well, inspired by @ekhumoro's comment, I made a factory function that takes as arguments only the parameters I want to control (in this case just `base`, not `epoch`).
Edit: made some adjustments to get the desired function with the correct parameters (as @Lourenço commented).
```python
def get_func1(base=10):
    def func1(epoch, base=base):
        return epoch / base
    return func1

func1 = get_func1(base=20)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=func1)
```
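The factory can be checked on its own, without torch. The sketch repeats `get_func1` and calls the returned function the way `LambdaLR` would, with the epoch only. It also shows why the object to pass is `get_func1(base=20)` and not `get_func1` itself: calling `get_func1(5)` treats 5 as `base` and returns another function rather than a number, which the scheduler could not multiply the initial lr by.

```python
def get_func1(base=10):
    # Factory: returns a function of epoch, with base captured as a default
    def func1(epoch, base=base):
        return epoch / base
    return func1


func1 = get_func1(base=20)
print(func1(5))  # 0.25, because base is 20 inside the returned function

# Passing the factory itself would be a mistake: here 5 is taken as base
wrong = get_func1(5)
print(callable(wrong))  # True: a function, not a learning-rate factor
```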
So the idea is that the `func1` returned by `get_func1` now has the correct value for `base` (and can therefore be more complex without running into the limitations of lambda functions), while at the same time I can still pass it by name alone (the function returned by `get_func1(...)`, not `get_func1` itself) and use it in place of the lambda function.
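Another option, not used above, is a small callable class that stores `base` as an attribute. The PyTorch documentation for `LambdaLR.state_dict` states that lr lambdas are only saved if they are callable objects, not plain functions or lambdas, so this form may matter if you checkpoint the scheduler. `EpochOverBase` is a hypothetical name for illustration.

```python
class EpochOverBase:
    """Hypothetical callable: computes epoch / base, with base fixed at construction."""

    def __init__(self, base=10):
        self.base = base

    def __call__(self, epoch):
        # LambdaLR calls lr_lambda(epoch), so __call__ takes only the epoch
        return epoch / self.base


lr_lambda = EpochOverBase(base=20)
# e.g. optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
print(lr_lambda(5))  # 0.25
```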