Tags: python, deep-learning, neural-network, pytorch, backpropagation

How to transform the output of a neural network while still being able to train?


I have a neural network that produces a prediction tensor, outputs. I want to transform outputs before the loss is computed and backpropagation happens.

Here is my general code:

with torch.set_grad_enabled(training):
    outputs = net(x_batch[:, 0], x_batch[:, 1])  # the prediction of the NN
    # My issue is here:
    outputs = transform_torch(outputs)
    loss = my_loss(outputs, y_batch)

    if training:
        scheduler.step()
        loss.backward()
        optimizer.step()

Following the advice in "How to transform output of neural network and still train?", I have a transformation function that I pass my output through:

def transform_torch(predictions):
    new_tensor = []
    for i in range(len(predictions)):
        arr = predictions[i]
        a = arr.clone().detach()

        # My transformation: the first element becomes positive, and each
        # following element represents a decrement of that first element.
        b = torch.negative(a)
        b[0] = abs(b[0])
        new_tensor.append(torch.cumsum(b, dim=0))

        # new_tensor[i].requires_grad = True
    new_tensor = torch.stack(new_tensor, 0)

    return new_tensor

Note: in addition to clone().detach(), I also tried the methods described in "PyTorch preferred way to copy a tensor", with similar results.

My problem is that no training actually happens when I use this transformed tensor.

If I instead try to modify the tensor in place (e.g. modify arr directly), PyTorch complains that I can't apply an in-place operation to a tensor that has a gradient attached to it.
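
For example, here is a minimal standalone repro of the kind of error I get, using a plain leaf tensor that requires grad as a stand-in for my network output:

import torch

a = torch.randn(3, requires_grad=True)

# An in-place write like this raises a RuntimeError along the lines of
# "a leaf Variable that requires grad is being used in an in-place operation":
a[0] = a[0].abs()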

Any suggestions?


Solution

  • Calling detach on your predictions stops gradient propagation to your model. Nothing you do after that can change your parameters.

    How about modifying your code to avoid this:

    def transform_torch(predictions):
      b = torch.cat([predictions[:, :1, ...].abs(), -predictions[:, 1:, ...]], dim=1)
      new_tensor = torch.cumsum(b, dim=1)
      return new_tensor
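
    As a quick numeric check (with made-up inputs), this vectorized version produces the same values as the loop-based transform in the question, just without detaching:

    import torch

    def transform_torch(predictions):
        # same function as above, repeated so this snippet runs standalone
        b = torch.cat([predictions[:, :1, ...].abs(), -predictions[:, 1:, ...]], dim=1)
        return torch.cumsum(b, dim=1)

    pred = torch.tensor([[-3.0, 1.0, 2.0],
                         [ 4.0, 0.5, 1.5]])
    print(transform_torch(pred))
    # tensor([[3.0000, 2.0000, 0.0000],
    #         [4.0000, 3.5000, 2.0000]])
    # i.e. |p[0]| followed by successive decrements by p[1], p[2], ...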
    

    A little test you can run to verify that gradients do propagate through this transformation:

    # start with some random tensor representing the input predictions
    # make sure it requires_grad
    pred = torch.rand((4, 5, 2, 3)).requires_grad_(True)
    # transform it
    tpred = transform_torch(pred)
    
    # make up some "default" loss function and back-prop
    tpred.mean().backward()
    
    # check to see all gradients of the original prediction:
    pred.grad
    # as you can see, all gradients are non-zero
    Out[]:
    tensor([[[[ 0.0417,  0.0417,  0.0417],
              [ 0.0417,  0.0417,  0.0417]],
    
             [[-0.0333, -0.0333, -0.0333],
              [-0.0333, -0.0333, -0.0333]],
    
             [[-0.0250, -0.0250, -0.0250],
              [-0.0250, -0.0250, -0.0250]],
    
             [[-0.0167, -0.0167, -0.0167],
              [-0.0167, -0.0167, -0.0167]],
    
             [[-0.0083, -0.0083, -0.0083],
              [-0.0083, -0.0083, -0.0083]]],
    
    
            [[[ 0.0417,  0.0417,  0.0417],
              [ 0.0417,  0.0417,  0.0417]],
    
             [[-0.0333, -0.0333, -0.0333],
              [-0.0333, -0.0333, -0.0333]],
    
             [[-0.0250, -0.0250, -0.0250],
              [-0.0250, -0.0250, -0.0250]],
    
             [[-0.0167, -0.0167, -0.0167],
              [-0.0167, -0.0167, -0.0167]],
    
             [[-0.0083, -0.0083, -0.0083],
              [-0.0083, -0.0083, -0.0083]]],
    
    
            [[[ 0.0417,  0.0417,  0.0417],
              [ 0.0417,  0.0417,  0.0417]],
    
             [[-0.0333, -0.0333, -0.0333],
              [-0.0333, -0.0333, -0.0333]],
    
             [[-0.0250, -0.0250, -0.0250],
              [-0.0250, -0.0250, -0.0250]],
    
             [[-0.0167, -0.0167, -0.0167],
              [-0.0167, -0.0167, -0.0167]],
    
             [[-0.0083, -0.0083, -0.0083],
              [-0.0083, -0.0083, -0.0083]]],
    
    
            [[[ 0.0417,  0.0417,  0.0417],
              [ 0.0417,  0.0417,  0.0417]],
    
             [[-0.0333, -0.0333, -0.0333],
              [-0.0333, -0.0333, -0.0333]],
    
             [[-0.0250, -0.0250, -0.0250],
              [-0.0250, -0.0250, -0.0250]],
    
             [[-0.0167, -0.0167, -0.0167],
              [-0.0167, -0.0167, -0.0167]],
    
             [[-0.0083, -0.0083, -0.0083],
              [-0.0083, -0.0083, -0.0083]]]])
    

    If you try this little test with your original code, you'll either get an error saying that you are trying to back-propagate through tensors that do not require grad, or you'll get no gradients for the input pred.
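
    For completeness, a quick sketch (separate from the test above) of that failure mode once detach is involved:

    import torch

    pred = torch.rand((4, 5, 2, 3)).requires_grad_(True)

    detached = pred.detach()            # no longer connected to pred
    out = detached.abs().cumsum(dim=1)  # any further graph is rooted at detached, not pred

    print(out.requires_grad)            # False
    print(pred.grad)                    # None -- nothing flows back to pred
    # out.mean().backward() would raise an error along the lines of:
    # "element 0 of tensors does not require grad and does not have a grad_fn"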