PyTorch: Comparing predicted label and target label to compute accuracy

I'm trying to implement this loop to get the accuracy of my PyTorch CNN (the complete code is here). My version of the loop so far is:

correct = 0
test_total = 0
for itera, testdata2 in enumerate(test_loader, 0):
    test_images2, test_labels2 = testdata2
    if use_gpu:
        test_images2 = Variable(test_images2.cuda())
    else:
        test_images2 = Variable(test_images2)
    outputs = model(test_images2)
    # indices of the highest score per sample = predicted class
    _, predicted = torch.max(outputs.data, 1)
    test_total += test_labels2.size(0)
    test_labels2 = test_labels2.type_as(predicted)
    correct += (predicted == test_labels2[0]).sum()
print('Accuracy of the network on all the test images: %d %%' % (
    100 * correct / test_total))

If I run it like this, I get:

> Traceback (most recent call last):
>   File "c:/python_code/Customized-DataLoader-master_two/", line 186, in <module>
>     main()
>   File "c:/python_code/Customized-DataLoader-master_two/", line 177, in main
>     correct += (predicted == test_labels2[0]).sum()
>   File "C:\anaconda\envs\pytorch_cuda\lib\site-packages\torch\", line 360, in __eq__
>     return self.eq(other)
> RuntimeError: invalid argument 3: sizes do not match at c:\anaconda2\conda-bld\pytorch_1519501749874\work\torch\lib\thc\generated\../THCTensorMathCompareT.cuh:65

I used test_labels2 = test_labels2.type_as(predicted) so that both tensors are LongTensors, which seems to avoid the "Expected this...but got..." errors. They now look like this:

test_labels2 after conversion:
 0  1
 1  0
 1  0
[torch.cuda.LongTensor of size 3x2 (GPU 0)]

predicted:
[torch.cuda.LongTensor of size 3 (GPU 0)]
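To illustrate what the type_as call does, here is a minimal CPU-only sketch with hypothetical tensors (not the question's actual data). It casts the calling tensor to the type of its argument, so float labels become int64 like the indices returned by torch.max. Note that only the element type changes: the labels stay 3x2 while predicted has size 3.

```python
import torch

# Hypothetical stand-ins: class scores for 3 samples and float labels
outputs = torch.tensor([[0.2, 0.8], [0.9, 0.1], [0.3, 0.7]])
labels = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0]])

# torch.max along dim 1 returns (max values, indices); the indices are int64
_, predicted = torch.max(outputs, 1)

# type_as casts labels to predicted's type; the 3x2 shape is unchanged
labels = labels.type_as(predicted)
print(predicted.dtype, labels.dtype)  # torch.int64 torch.int64
print(labels.shape, predicted.shape)  # torch.Size([3, 2]) torch.Size([3])
```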

I suppose the problem now is that test_labels2[0] returns a row, not a column.
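The row-versus-column suspicion can be checked on a small CPU tensor with the same values as the 3x2 printout (the predictions here are hypothetical). Indexing with [0] yields the first row, of size 2, which cannot be compared element-wise with the size-3 predictions; current PyTorch versions raise a RuntimeError for this shape mismatch, the same cause as the "sizes do not match" error above.

```python
import torch

# Labels with the same values as the 3x2 printout (3 samples, 2 columns)
labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
# Hypothetical predictions: one class index per sample, shape (3,)
predicted = torch.tensor([1, 0, 1])

row = labels[0]     # first ROW: tensor([0, 1]), shape (2,)
col = labels[:, 0]  # first COLUMN: tensor([0, 1, 1]), shape (3,)
print(row.shape, col.shape)  # torch.Size([2]) torch.Size([3])

# Shapes (3,) and (2,) cannot be broadcast together, so this comparison fails
try:
    predicted == row
except RuntimeError as err:
    print("RuntimeError:", err)
```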

How do I get this to work?


  • Indexing in PyTorch works mostly like indexing in NumPy. To index all rows of a certain column j, use:

    tensor[:, j]

    Alternatively, PyTorch's select function can be used: tensor.select(1, j) returns the same column (dimension 1, index j).
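A minimal sketch on current PyTorch (no Variable wrapper), using the label values from the question and hypothetical network outputs. It first shows that tensor[:, j] and select(1, j) return the same column, then counts correct predictions. ASSUMPTION: the label rows are one-hot encoded, as the 3x2 printout suggests, so argmax over dimension 1 recovers each sample's class index; for two classes that is the same as taking column 1. The helper name count_correct is hypothetical.

```python
import torch

# Labels copied from the question's 3x2 printout
labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
j = 1

# Both expressions return every row of column j, shape (3,)
col_index = labels[:, j]
col_select = labels.select(1, j)  # select(dim, index)
print(torch.equal(col_index, col_select))  # True


def count_correct(outputs, labels):
    """Count samples whose predicted class matches the label.

    ASSUMPTION: each label row is one-hot, so argmax over dim 1
    gives the class index (for two classes this equals labels[:, 1]).
    """
    _, predicted = torch.max(outputs, 1)  # (batch,) predicted class indices
    targets = labels.argmax(dim=1)        # (batch,) target class indices
    return (predicted == targets).sum().item()


# Hypothetical network outputs (class scores) for the 3 samples
outputs = torch.tensor([[0.2, 0.8], [0.9, 0.1], [0.3, 0.7]])
correct = count_correct(outputs, labels)
total = labels.size(0)
print(f"Accuracy: {100 * correct / total:.1f} %")  # Accuracy: 66.7 %
```

In a loop like the one in the question, correct and total would be accumulated over all batches and divided once at the end; both sides of the comparison then have one entry per sample, so the size mismatch goes away.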