python-3.x, keras, confusion-matrix

Confusion Matrix in Keras+Tensorflow


Q1

I have trained a CNN model and saved it as model.h5. I am trying to classify 3 kinds of objects, say "cat", "dog" and "other". My test set has 300 images, 100 from each category: the first 100 are "cat", the next 100 are "dog" and the last 100 are "other". I am using the Keras class ImageDataGenerator with flow_from_directory. Here is sample code:

test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=(150, 150),
        batch_size=20,
        class_mode='sparse',
        shuffle=False)

Now to use

from sklearn.metrics import confusion_matrix

cnf_matrix = confusion_matrix(y_test, y_pred)

I need y_test and y_pred. I can get y_pred using the following code:

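import numpy as np

# predict_generator is deprecated in newer TensorFlow releases;
# model.predict(test_generator) is the replacement there
probabilities = model.predict_generator(test_generator)
y_pred = np.argmax(probabilities, axis=1)
print(y_pred)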

[0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 1 0 0 0 0 0 0 1 0 0 0
 0 0 0 0 1 0 0 0 0 1 2 0 2 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 1 1
 0 2 0 0 0 0 1 0 0 0 0 0 0 1 0 2 0 1 0 0 1 0 0 1 0 0 1 1 1 1 1 1 1 1 1 1 2
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1
 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 2 2 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 2
 1 1 1 1 1 2 1 1 1 1 1 2 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 1 2 2 2 1 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2]

This is basically predicting the objects as 0, 1 and 2. I know that the first 100 objects (cat) are 0, the second 100 (dog) are 1 and the third 100 (other) are 2. Do I create a list manually using numpy, where the first 100 entries are 0, the next 100 are 1 and the last 100 are 2, to get y_test? Is there any Keras class that can do this (create y_test)?

Q2

How can I see the wrongly detected objects? If you look at print(y_pred), the 3rd point is 1, which is wrongly predicted. How can I see that image without going into my "test_dir" folder manually?


Solution

  • Since you're not using any augmentation and shuffle=False, you can simply get the images from the generator:

    imgBatch, labelBatch = next(test_generator)
        # with class_mode='sparse' the generator yields (images, labels) tuples;
        # call test_generator.reset() first if you're not sure the generator
        # is still at its first batch
    

    Plot each image in imgBatch using a plotting library, such as Pillow (PIL) or Matplotlib.

    For plotting only the desired images, compare y_test with y_pred:

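    compare = y_test == y_pred   # both must be numpy arrays of equal length
    
    test_generator.reset()       # start again from the first batch
    position = 0
    while position < len(y_test):
        imgBatch, _ = next(test_generator)   # the generator yields (images, labels)
        batch = imgBatch.shape[0]
    
        for i in range(position, position + batch):
            if not compare[i]:
                # plot is a placeholder for your own display function
                plot(imgBatch[i - position])
    
        position += batch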
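
For Q1, the iterator returned by flow_from_directory already holds the true labels: with shuffle=False, test_generator.classes gives one integer label per file in directory order, and test_generator.filenames gives the matching relative paths. Below is a minimal sketch of how these could be used to list the wrong predictions. The helper find_misclassified and the stand-in file names are hypothetical, and small arrays replace the real generator so the snippet runs without Keras.

```python
import numpy as np

def find_misclassified(y_true, y_pred, filenames):
    """Return (index, filename, true_label, predicted_label) for each wrong prediction."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    wrong = np.flatnonzero(y_true != y_pred)
    return [(int(i), filenames[i], int(y_true[i]), int(y_pred[i])) for i in wrong]

# With the generator from the question (shuffle=False) the inputs would be:
#   y_test    = test_generator.classes     # integer label per file
#   filenames = test_generator.filenames   # relative paths, same order
# Stand-in data (assumption): 3 images per class instead of 100.
y_test = np.repeat([0, 1, 2], 3)
y_pred = np.array([0, 0, 1, 1, 1, 1, 2, 1, 2])
filenames = [f"{name}/img{i}.jpg"
             for name in ("cat", "dog", "other")
             for i in range(3)]

errors = find_misclassified(y_test, y_pred, filenames)
for index, name, true_label, predicted in errors:
    print(index, name, "true:", true_label, "predicted:", predicted)
```

The same y_test can be passed straight to confusion_matrix(y_test, y_pred) from the question, and test_generator.class_indices maps the class names to the integers used in it.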