Tags: python, optimization, neural-network, deep-learning, pytorch

How to print the "actual" learning rate in Adadelta in pytorch


In short:

I can't draw an lr/epoch curve when using the Adadelta optimizer in PyTorch because optimizer.param_groups[0]['lr'] always returns the same value.

In detail:

Adadelta dynamically adapts over time using only first-order information and has minimal computational overhead beyond vanilla stochastic gradient descent [1].

In PyTorch, the source code of Adadelta is here: https://pytorch.org/docs/stable/_modules/torch/optim/adadelta.html#Adadelta

Since it requires no manual tuning of the learning rate, to my knowledge, we don't have to set any scheduler after declaring the optimizer:

self.optimizer = torch.optim.Adadelta(self.model.parameters(), lr=1)

The way to check the learning rate is:

current_lr = self.optimizer.param_groups[0]['lr']

The problem is that it always returns 1 (the initial lr).
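
For reference, here is a minimal sketch that reproduces the issue (the toy model and random data are placeholders, not my real setup):

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adadelta(model.parameters(), lr=1)

for epoch in range(3):
    loss = model(torch.randn(8, 10)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # prints 1.0 every epoch, even though Adadelta adapts internally
    print(optimizer.param_groups[0]['lr'])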

Could anyone tell me how I can get the true learning rate, so that I can draw an lr/epoch curve?

[1] Matthew D. Zeiler, "ADADELTA: An Adaptive Learning Rate Method", https://arxiv.org/pdf/1212.5701.pdf


Solution

  • Check self.optimizer.state. This is where Adadelta keeps the per-parameter running averages that, together with the lr, drive the optimization process.

    From the documentation, lr is just:

    lr (float, optional): coefficient that scale delta before it is applied to the parameters (default: 1.0)

    https://pytorch.org/docs/stable/_modules/torch/optim/adadelta.html

    Edit: you can find the acc_delta values in self.optimizer.state, but you need to go through the per-parameter dictionaries contained in that dictionary:

    # Collect the per-parameter state dicts that already contain "acc_delta"
    dicts_with_acc_delta = [s for s in self.optimizer.state.values() if "acc_delta" in s]
    acc_deltas = [d["acc_delta"] for d in dicts_with_acc_delta]
    

    I have eight layers, and the shapes of the elements in the acc_deltas list are as follows (see the sketch after the list for turning these state values into an effective learning rate):

    [torch.Size([25088]),
     torch.Size([25088]),
     torch.Size([4096, 25088]),
     torch.Size([4096]),
     torch.Size([1024, 4096]),
     torch.Size([1024]),
     torch.Size([102, 1024]),
     torch.Size([102])]
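
    Since lr is only a fixed scaling coefficient, the quantity worth plotting is the multiplier Adadelta actually applies to the gradient: lr * sqrt(acc_delta + eps) / sqrt(square_avg + eps), built from the square_avg and acc_delta tensors in the state. Here is a minimal sketch of such a helper (the function name is mine, not part of the PyTorch API; eps mirrors Adadelta's default of 1e-6, and the state is only populated after the first step()):

    def effective_lrs(optimizer, eps=1e-6):
        # One scalar per parameter tensor: the mean of the element-wise
        # multiplier lr * sqrt(acc_delta + eps) / sqrt(square_avg + eps),
        # i.e. approximately what Adadelta will apply on the next step.
        base_lr = optimizer.param_groups[0]["lr"]
        lrs = []
        for state in optimizer.state.values():
            if "acc_delta" not in state:
                continue  # this parameter has not been updated yet
            ratio = (state["acc_delta"] + eps).sqrt() / (state["square_avg"] + eps).sqrt()
            lrs.append(base_lr * ratio.mean().item())
        return lrs

    Calling this once per epoch and plotting, for example, the mean of the returned list gives the lr/epoch curve that param_groups[0]['lr'] cannot provide.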