Tags: optimization, pytorch, minimax

How to minimize wrt one set of parameters and maximize wrt another set of parameters simultaneously in a training loop in PyTorch?


I have a loss function that involves two sets of parameters to learn. One is a matrix, wrt which I want to maximize the loss; the other is the set of parameters of a logistic regression, wrt which I want to minimize the loss. In PyTorch, whenever I call loss.backward() the loss is minimized wrt both sets of parameters, and (-loss).backward() maximizes it wrt both. How do I do minimax optimization over the two sets of parameters in PyTorch? TensorFlow has the GradientTape and tape.watch() mechanism for this kind of control over gradients. What's the alternative in PyTorch?


Solution

  • You can refer to the gradient reversal idea from https://arxiv.org/abs/1409.7495.

    But the crux of the idea is this: you have some loss function l(X, Y) where X and Y are parameters. You want to update X to minimize the loss and update Y to maximize it, which is the same as updating Y to minimize -l(X, Y).

    Essentially you want to update the parameters X with dl/dX and Y with d(-l)/dY = -dl/dY. You can do this by running one backpropagation step, negating the gradients of Y, and then applying the optimizer update. In PyTorch terms, that is the snippet below (a fuller, runnable sketch follows it):

    loss = compute_loss()
    loss.backward()
    # negate Y's gradient so the optimizer's descent step ascends the loss wrt Y
    Y.grad = -Y.grad
    optimizer.step()
    optimizer.zero_grad()
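
    Below is a fuller, self-contained sketch of the same loop. The toy loss, the names X and Y, and the SGD hyperparameters are illustrative assumptions chosen to make the snippet runnable; they are not details from the original question or the paper.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    n_samples, n_features = 64, 5

    # X: logistic-regression parameters, updated to MINIMIZE the loss
    X = torch.zeros(n_features, requires_grad=True)
    # Y: matrix, updated to MAXIMIZE the loss
    Y = torch.randn(n_features, n_features, requires_grad=True)

    optimizer = torch.optim.SGD([X, Y], lr=0.1)

    def compute_loss():
        # toy binary-classification loss: features transformed by Y, classified by X
        inputs = torch.randn(n_samples, n_features)
        targets = (inputs.sum(dim=1) > 0).float()
        logits = (inputs @ Y) @ X
        return F.binary_cross_entropy_with_logits(logits, targets)

    for step in range(100):
        loss = compute_loss()
        loss.backward()
        Y.grad = -Y.grad   # flip Y's gradient: descending -l is ascending l wrt Y
        optimizer.step()
        optimizer.zero_grad()

    A single optimizer can hold both parameter sets because only the sign of Y's gradient differs; if you need different learning rates for the two sets, use separate parameter groups or separate optimizers.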