
How to fix AttributeError: 'torch.return_types.max' object has no attribute 'eq'


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_2828\2342855618.py in <module>
     21         predicted = torch.max(putputs.data,1)
     22         total += labels.size(0)
---> 23         correct += predicted.eq(labels.data).cpu().sum()
     24     print('[epoch:%d,iter:%d] loss: %0.3f | Acc: %.3f%%' %(epoch + 1,(i + 1 + epoch * length ),sum_loss/(i + 1),100.*correct/total))

AttributeError: 'torch.return_types.max' object has no attribute 'eq'

I have read the explanation of torch.max() in the documentation, but it still doesn't solve the problem.


Solution

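  • When called with a dimension argument, as in torch.max(putputs.data, 1), torch.max returns a named tuple of (values, indices) rather than a tensor, which is why the result has no eq method. If you are only interested in the maximum values, you can do: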

    predicted = torch.max(putputs.data,1).values
    

    If you need the indices of the maximum values (the predicted class labels, which is what you compare against labels when computing accuracy), you can use:

    predicted = torch.max(putputs.data,1).indices
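A minimal runnable sketch of the corrected accuracy computation, assuming a classification model whose outputs have shape (batch, num_classes). The function name batch_accuracy_counts and the sample tensors are hypothetical, for illustration only. Unpacking the named tuple with "_, predicted = ..." is equivalent to taking .indices.

```python
from typing import Tuple

import torch


def batch_accuracy_counts(outputs: torch.Tensor, labels: torch.Tensor) -> Tuple[int, int]:
    """Return (correct, total) for one batch of class scores (hypothetical helper)."""
    # With a dim argument, torch.max returns a named tuple (values, indices).
    # Unpacking keeps only the indices, i.e. the predicted class for each row.
    _, predicted = torch.max(outputs, 1)
    # predicted is now a plain tensor, so .eq() works; .item() turns the count into an int.
    correct = predicted.eq(labels).sum().item()
    return correct, labels.size(0)


if __name__ == "__main__":
    # Hypothetical batch: scores for 4 samples over 3 classes, plus the true labels.
    outputs = torch.tensor([[0.1, 2.0, 0.3],
                            [1.5, 0.2, 0.1],
                            [0.0, 0.1, 3.0],
                            [0.9, 1.0, 0.2]])
    labels = torch.tensor([1, 0, 2, 0])

    result = torch.max(outputs, 1)
    print(result.values)   # maximum score per row
    print(result.indices)  # predicted class per row

    correct, total = batch_accuracy_counts(outputs, labels)
    print(f"Acc: {100.0 * correct / total:.3f}%")  # Acc: 75.000%
```

If you do not need the maximum values at all, outputs.argmax(1) should give the same predicted classes directly as a tensor.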