
Doing Gradient Descent Ascent using PyTorch


I want to perform GDA for minimax problems of the form $\min_x\max_y f(x,y)$. This is the same objective as in GANs, except that GDA takes the gradient steps for $x$ and $y$ simultaneously. In particular, we need

$$ x_{k+1} = x_k - \nabla_x f(x_k,y_k); y_{k+1} = y_k + \nabla_y f(x_k,y_k) $$

How can I achieve this using PyTorch? It does not seem possible out of the box, since the optimizers are implemented to always perform gradient descent. I came up with the following method, which honestly feels like a cheap hack.

import torch

descent_optim = torch.optim.SGD([x], lr=lr)
ascent_optim = torch.optim.SGD([y], lr=lr)
while True:
  x_copy = x.detach().clone()  # snapshot of x_k so the ascent step is evaluated at (x_k, y_k)
  # descent step on x
  loss = f(x, y)
  descent_optim.zero_grad()
  loss.backward()
  descent_optim.step()
  # ascent step on y (negate the loss so SGD's descent becomes ascent), using the old x_k
  loss = -f(x_copy, y)
  ascent_optim.zero_grad()
  loss.backward()
  ascent_optim.step()

Solution

  • PyTorch does support using an optimizer to maximize a loss by passing maximize=True when constructing it; see the torch.optim.SGD documentation.
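
Here is a minimal sketch of how the update could then be written with two SGD optimizers, the ascent one constructed with maximize=True (available in recent PyTorch versions; the argument was added to SGD around 1.11). The toy objective f, the tensor shapes, the learning rate, and the iteration count are placeholder assumptions, not taken from the question.

import torch

def f(x, y):
  # placeholder bilinear saddle objective: minimize over x, maximize over y
  return (x * y).sum()

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
lr = 0.01

descent_optim = torch.optim.SGD([x], lr=lr)                # gradient descent on x
ascent_optim = torch.optim.SGD([y], lr=lr, maximize=True)  # gradient ascent on y

for _ in range(1000):
  descent_optim.zero_grad()
  ascent_optim.zero_grad()
  loss = f(x, y)
  loss.backward()       # gradients of f at the current point (x_k, y_k)
  descent_optim.step()  # x_{k+1} = x_k - lr * grad_x f(x_k, y_k)
  ascent_optim.step()   # y_{k+1} = y_k + lr * grad_y f(x_k, y_k)

Because a single backward pass computes both gradients at (x_k, y_k) before either parameter is updated, the two steps are simultaneous, and the x_copy snapshot and the negated loss from the original workaround are no longer needed.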