
Do all variables in the loss function have to be tensors with gradients in PyTorch?


I have the following function:


import torch

def msfe(ys, ts):
    ys = ys.detach().numpy()  # output from the network
    ts = ts.detach().numpy()  # targets (true labels)
    pred_class = (ys >= 0.5)
    n_0 = sum(ts == 0)  # number of actual negatives
    n_1 = sum(ts == 1)  # number of actual positives
    FPE = sum((ts == 0)[[bool(p) for p in (pred_class == 1)]]) / n_0  # false positive error
    FNE = sum((ts == 1)[[bool(p) for p in (pred_class == 0)]]) / n_1  # false negative error
    loss = FPE**2 + FNE**2

    loss = torch.tensor(loss, dtype=torch.float64, requires_grad=True)

    return loss

I wonder whether autograd in PyTorch works properly here, since ys and ts do not have the grad flag.

So my question is: do all the variables (FPE, FNE, ys, ts, n_1, n_0) have to be tensors for optimizer.step() to work, or is it okay if only the final value (loss) is one?


Solution

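  • Everything you want to optimise via optimizer.step() (i.e. the model parameters) needs a gradient, and that gradient can only reach them if every step from the loss back to them is tracked by autograd.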

    In your case, the gradient has to pass through ys, the output predicted by the network, so you shouldn't detach it from the graph.

    Usually you don't change your targets, so those don't need gradients. You don't have to detach them either: tensors don't require gradients by default, so nothing is backpropagated through them.

    The loss will require a gradient if at least one of its ingredients does.
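A minimal sketch of how requires_grad propagates, assuming a hypothetical leaf tensor `ys` that stands in for the network output (in a real model the gradient would continue on into the parameters, and `ys` itself would not store `.grad`):

```python
import torch

# Hypothetical stand-in for the network output
ys = torch.rand(4, requires_grad=True)
# Targets are plain tensors; requires_grad defaults to False
ts = torch.tensor([0.0, 1.0, 0.0, 1.0])

# One ingredient requires grad, so the result does too
loss = ((ys - ts) ** 2).mean()
print(loss.requires_grad)  # True
print(ts.requires_grad)    # False

# A count computed only from targets needs no gradient, and none is tracked
n_1 = (ts == 1).sum()
print(n_1.requires_grad)   # False

loss.backward()
print(ys.grad is not None)  # True: the gradient reached the "prediction"
print(ts.grad)              # None: targets never take part in the backward pass
```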

    Overall, you rarely need to take care of this manually; autograd tracks which tensors require gradients for you.

    BTW, don't switch to NumPy inside a PyTorch loss; there is rarely a reason to. Most operations you can perform on a NumPy array are also available on PyTorch tensors, and only the latter are tracked by autograd.
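To see why the NumPy round-trip in the question breaks training, here is a small sketch with a hypothetical one-layer model: rebuilding the loss with torch.tensor(..., requires_grad=True) creates a fresh leaf tensor, so backward() runs without error but never reaches the model's weights.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)  # hypothetical stand-in for the real network
x = torch.rand(5, 3)

# Round-trip through NumPy, as in the question: the rebuilt loss is a brand-new leaf
ys_np = model(x).detach().numpy()
loss_np = torch.tensor((ys_np ** 2).mean(), requires_grad=True)
loss_np.backward()                   # runs without error...
grad_after_numpy = model.weight.grad
print(grad_after_numpy)              # ...but prints None: nothing reached the model
print(loss_np.grad)                  # the new leaf only receives a gradient of 1

# The same computation kept in PyTorch stays connected to the model
loss_pt = (model(x) ** 2).mean()
loss_pt.backward()
print(model.weight.grad is not None)  # True
```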

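    BTW2, Variable is no longer a separate concept in PyTorch: since version 0.4 it has been merged into Tensor (torch.autograd.Variable remains only as a deprecated wrapper), so there are only tensors that require gradients and tensors that don't.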

    Non-differentiability

    1.1 Problems with existing code

    Indeed, you are using operations that are not differentiable (namely >= and ==): they return boolean tensors with no grad_fn, so the graph is cut at that point. This only causes trouble when they are applied to your outputs, since those require gradients; using == and >= on targets is fine.
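A quick check of this point, using an arbitrary example tensor: the result of >= is a boolean tensor with no grad_fn, so anything built from it alone no longer depends on `outputs` as far as autograd is concerned.

```python
import torch

outputs = torch.tensor([0.2, 0.7, 0.9], requires_grad=True)

pred_class = outputs >= 0.5
print(pred_class.dtype)          # torch.bool
print(pred_class.requires_grad)  # False
print(pred_class.grad_fn)        # None: the graph ends here

# Anything computed only from pred_class is cut off from outputs as well
fpe = (pred_class == 1).float().mean()
print(fpe.requires_grad)         # False
```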

    Below is your loss function, with the problems outlined in the comments:

    import torch

    # Gradient can't propagate if you detach and work in another framework.
    # Most Python constructs are fine; detaching is what ruins it.
    def msfe(outputs, targets):
        # outputs = outputs.detach().numpy()  # Do not detach, no need to do that
        # targets = targets.detach().numpy()  # No need for NumPy either
        pred_class = outputs >= 0.5  # This one is non-differentiable
        # n_0 = sum(targets == 0)  # Do not use Python's built-in sum, use torch.sum
        # n_1 = sum(targets == 1)

        n_0 = torch.sum(targets == 0)  # Those are not differentiable, but...
        n_1 = torch.sum(targets == 1)  # it does not matter, as those are targets

        # FPE = sum((targets==0)[[bool(p) for p in (pred_class==1)]])/n_0  # Do not use Python bools
        # FNE = sum((targets==1)[[bool(p) for p in (pred_class==0)]])/n_1  # Stay within PyTorch
        # The two lines below are non-differentiable as well, because of pred_class and ==
        FPE = torch.sum((targets == 0.0) * (pred_class == 1.0)).float() / n_0
        FNE = torch.sum((targets == 1.0) * (pred_class == 0.0)).float() / n_1
        # The arithmetic itself is fine, but since FPE and FNE come from
        # comparisons, loss still does not require grad at this point
        loss = FPE ** 2 + FNE ** 2

        # Loss is already a tensor, so don't wrap it like this:
        # torch.tensor creates a brand-new leaf tensor cut off from the graph,
        # so backward() only yields a gradient of 1 for that leaf and nothing reaches the model
        # loss = torch.tensor(loss, dtype=torch.float64, requires_grad=True)

        return loss
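Once the torch.tensor(...) wrapping is removed, the problem at least becomes visible. A sketch using a condensed copy of the annotated function above and made-up inputs: the loss does not require grad, so backward() raises a RuntimeError instead of silently doing nothing.

```python
import torch

def msfe(outputs, targets):
    # Condensed copy of the annotated loss above (no NumPy, no re-wrapping)
    pred_class = outputs >= 0.5
    n_0 = torch.sum(targets == 0)
    n_1 = torch.sum(targets == 1)
    FPE = torch.sum((targets == 0.0) * (pred_class == 1.0)).float() / n_0
    FNE = torch.sum((targets == 1.0) * (pred_class == 0.0)).float() / n_1
    return FPE ** 2 + FNE ** 2

outputs = torch.tensor([0.2, 0.7, 0.9, 0.4], requires_grad=True)
targets = torch.tensor([0, 1, 0, 1])

loss = msfe(outputs, targets)
print(loss.item())         # 0.5 (one false positive, one false negative)
print(loss.requires_grad)  # False

raised = False
try:
    loss.backward()
except RuntimeError as err:
    raised = True
    print("backward() failed:", err)
```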
    

    1.2 Possible solution

    So, you need to get rid of the three non-differentiable parts (the >= threshold and the two == comparisons on pred_class). In principle, you could approximate them with the continuous outputs of your network (provided you use a sigmoid as the final activation, so the outputs lie in [0, 1]). Here is my take:

    import torch

    def msfe_approximation(outputs, targets):
        n_0 = torch.sum(targets == 0)  # Gradient does not flow through it, it's okay
        n_1 = torch.sum(targets == 1)  # Same as above
        FPE = torch.sum((targets == 0) * outputs).float() / n_0
        FNE = torch.sum((targets == 1) * (1 - outputs)).float() / n_1
    
        return FPE ** 2 + FNE ** 2
    

    Notice that, to minimize FPE, the outputs are pushed toward zero at the indices where targets are zero. Similarly for FNE: where targets are 1, the network is pushed to output 1.
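To make this concrete, here is a small sketch with made-up outputs and targets (the function is repeated so the snippet runs on its own). The gradient is positive on the negatives and negative on the positives, so a gradient-descent step pushes those outputs toward 0 and 1 respectively:

```python
import torch

def msfe_approximation(outputs, targets):
    # Same soft loss as above
    n_0 = torch.sum(targets == 0)
    n_1 = torch.sum(targets == 1)
    FPE = torch.sum((targets == 0) * outputs).float() / n_0
    FNE = torch.sum((targets == 1) * (1 - outputs)).float() / n_1
    return FPE ** 2 + FNE ** 2

outputs = torch.tensor([0.2, 0.9, 0.4, 0.7], requires_grad=True)
targets = torch.tensor([0, 1, 0, 1])

loss = msfe_approximation(outputs, targets)
loss.backward()
# FPE = (0.2 + 0.4) / 2 = 0.3, FNE = (0.1 + 0.3) / 2 = 0.2, loss = 0.09 + 0.04 = 0.13
print(loss.item())
# Negatives get +2 * FPE / n_0 = +0.3, positives get -2 * FNE / n_1 = -0.2
print(outputs.grad)
```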

    Notice the similarity of this idea to BCELoss (binary cross-entropy).
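One way to see the similarity is to evaluate both losses on the same made-up predictions; torch.nn.BCELoss expects probabilities and float targets. Both losses are smaller when the outputs are closer to the targets:

```python
import torch

def msfe_approximation(outputs, targets):
    n_0 = torch.sum(targets == 0)
    n_1 = torch.sum(targets == 1)
    FPE = torch.sum((targets == 0) * outputs).float() / n_0
    FNE = torch.sum((targets == 1) * (1 - outputs)).float() / n_1
    return FPE ** 2 + FNE ** 2

bce = torch.nn.BCELoss()
targets = torch.tensor([0.0, 1.0, 0.0, 1.0])  # BCELoss expects float targets
poor = torch.tensor([0.6, 0.4, 0.6, 0.4])     # illustrative, mostly wrong
good = torch.tensor([0.1, 0.9, 0.1, 0.9])     # illustrative, mostly right

results = {}
for name, outputs in [("poor", poor), ("good", good)]:
    results[name] = (msfe_approximation(outputs, targets).item(),
                     bce(outputs, targets).item())
    print(name, results[name])
```

A difference worth noting: msfe_approximation normalizes each error by its class count (n_0, n_1), so both classes weigh equally regardless of imbalance, whereas BCELoss with its default reduction averages over all samples.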

    And lastly, here is an example you can run this on, as a quick sanity check:

    if __name__ == "__main__":
        model = torch.nn.Sequential(
            torch.nn.Linear(30, 100),
            torch.nn.ReLU(),
            torch.nn.Linear(100, 200),
            torch.nn.ReLU(),
            torch.nn.Linear(200, 1),
            torch.nn.Sigmoid(),
        )
        optimizer = torch.optim.Adam(model.parameters())
        targets = torch.randint(high=2, size=(64, 1)) # random targets
        inputs = torch.rand(64, 30) # random data
        for _ in range(1000):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = msfe_approximation(outputs, targets)
            print(loss)
            loss.backward()
            optimizer.step()
    
        print(((model(inputs) >= 0.5) == targets).float().mean())