
PyTorch throws RuntimeError: result type Float can't be cast to the desired output type Long


How should I get rid of the following error?

>>> import torch
>>> t = torch.tensor([[1, 0, 1, 1]]).T
>>> p = torch.rand(4, 1)
>>> torch.nn.BCEWithLogitsLoss()(p, t)

The code above raises the following error:

RuntimeError: result type Float can't be cast to the desired output type Long
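A quick way to see where the mismatch comes from is to print the dtypes of the two tensors. This is a minimal sketch reusing the same t and p as above, assuming the default floating-point dtype has not been changed with torch.set_default_dtype: torch.tensor infers torch.int64 (Long) from Python integer literals, while torch.rand returns torch.float32.

```python
import torch

t = torch.tensor([[1, 0, 1, 1]]).T  # integer literals -> dtype inferred as torch.int64 (Long)
p = torch.rand(4, 1)                # torch.rand returns float32 under the default dtype

print(t.dtype)  # torch.int64
print(p.dtype)  # torch.float32
```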


Solution

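  • BCEWithLogitsLoss expects its target to be a floating-point tensor with the same shape as its input, but torch.tensor([[1, 0, 1, 1]]) infers a Long (torch.int64) tensor from the integer literals. Specify the dtype of t explicitly with dtype=torch.float32: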

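    import torch
    
    # Integer literals would make torch.tensor infer int64 (Long); force float32 instead
    t = torch.tensor([[1, 0, 1, 1]], dtype=torch.float32).T
    p = torch.rand(4, 1)  # random logits, float32 by default
    loss_fn = torch.nn.BCEWithLogitsLoss()
    
    print(loss_fn(p, t))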
    

    Output (the exact value varies between runs, because p is random):

    tensor(0.5207)
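If the labels already exist as an integer tensor (for example, loaded from a dataset), converting them with .float() (equivalent to .to(torch.float32)) should work just as well as rebuilding them. The sketch below also compares the result with the binary cross-entropy formula that BCEWithLogitsLoss computes under its default reduction='mean', which shows why the target has to be floating point: it is multiplied with log-probabilities. The torch.manual_seed(0) call is only there to make runs repeatable; no specific printed values are claimed here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # only for repeatability; any seed works

t = torch.tensor([[1, 0, 1, 1]]).T  # int64 labels, as in the question
p = torch.rand(4, 1)                # logits

loss_fn = torch.nn.BCEWithLogitsLoss()
loss = loss_fn(p, t.float())        # convert the target instead of rebuilding it

# Same quantity by hand: mean of -[t*log(sigmoid(p)) + (1-t)*log(1-sigmoid(p))],
# using log(1 - sigmoid(p)) == logsigmoid(-p) for numerical stability
tf = t.float()
manual = -(tf * F.logsigmoid(p) + (1 - tf) * F.logsigmoid(-p)).mean()

print(loss, manual)
```

Using logsigmoid rather than log(sigmoid(...)) mirrors the reason BCEWithLogitsLoss exists at all: it folds the sigmoid into the loss so the computation stays numerically stable for large logits.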