
Confusion matrix with a Keras model: does anybody know how to build a confusion matrix for this model?


I am currently doing research on a classification problem, but I don't know how to build a confusion matrix for this model; the code is below. I am using the Keras library on Colab because my local environment was not compatible with tf.keras.preprocessing.image_dataset_from_directory. I get good accuracy and good predictions, but I am lost when it comes to the confusion matrix. Thanks in advance to the community.

image_size = (180, 180)
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "/content/gdrive/MyDrive/Collab Notebooks/LungCells/seg_train",
    validation_split=0.2,
    subset="training",
    seed=42,
    image_size=image_size,
    batch_size=batch_size,
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "/content/gdrive/MyDrive/Collab Notebooks/LungCells/seg_train",
    validation_split=0.2,
    subset="validation",
    #seed=1337,
    seed=42,
    image_size=image_size,
    batch_size=batch_size,
)

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)

def make_model(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)
    # Image augmentation block
    x = data_augmentation(inputs)

    # Entry block
    x = layers.experimental.preprocessing.Rescaling(1.0 / 255)(x)
    x = layers.Conv2D(32, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    for size in [128, 256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        activation = "sigmoid"
        units = 1
    else:
        activation = "softmax"
        units = num_classes

    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(units, activation=activation)(x)
    return keras.Model(inputs, outputs)


model = make_model(input_shape=image_size + (3,), num_classes=2)
keras.utils.plot_model(model, show_shapes=True)

epochs = 50
callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.h5"),
]
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
model.fit(
    train_ds, epochs=epochs, callbacks=callbacks, validation_data=val_ds,
)

Solution

  • Maybe I don't fully understand your exact problem, but here is a similar working example that might be helpful.

    Let's say I train a model on MNIST as a binary classifier (same as yours), predicting whether a digit is odd or even, and then compute a confusion matrix and classification report for it.

    Dataset

    import tensorflow as tf
    import numpy as np
    
    (xtrain, ytrain), (_, _) = tf.keras.datasets.mnist.load_data()
    
    # make the grayscale digits 3-channel and scale to [0, 1]
    xtrain = np.expand_dims(xtrain, axis=-1)
    xtrain = np.repeat(xtrain, 3, axis=-1)
    xtrain = xtrain.astype('float32') / 255
    
    # 2 categories: label 1 if the digit is even, 0 if odd
    ytrain = tf.keras.utils.to_categorical((ytrain % 2 == 0).astype(int),
                                           num_classes=2)
    
    print(xtrain.shape, ytrain.shape)
    (60000, 28, 28, 3) (60000, 2)
    

    Model

    # declare input shape
    inputs = tf.keras.Input(shape=(28, 28, 3))
    # Block 1
    x = tf.keras.layers.Conv2D(32, 3, strides=2, activation="relu")(inputs)
    x = tf.keras.layers.MaxPooling2D(3)(x)
    
    # apply global max pooling
    gap = tf.keras.layers.GlobalMaxPooling2D()(x)
    
    # finally, add a classification layer
    outputs = tf.keras.layers.Dense(2, activation='softmax')(gap)
    
    # bind it all into a Model
    model = tf.keras.Model(inputs, outputs)
    

    Compile and Run

    model.compile(
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=[tf.keras.metrics.CategoricalAccuracy()],
        optimizer=tf.keras.optimizers.Adam())
    # fit 
    model.fit(xtrain, ytrain, batch_size=128, epochs=3, verbose=2)
    

    Classification Report

    from sklearn.metrics import classification_report, confusion_matrix
    
    # the model was trained on MNIST as odd vs. even (class 0 = odd, class 1 = even)
    target_names = ['odd', 'even']
    
    # get predicted probabilities and convert them to class labels
    ypred = model.predict(xtrain, verbose=1)
    ypred = np.argmax(ypred, axis=1)
    
    print(classification_report(np.argmax(ytrain, axis=1), ypred, target_names=target_names))
    
                  precision    recall  f1-score   support
    
             odd       0.80      0.72      0.75     30508
            even       0.73      0.81      0.77     29492
    
        accuracy                           0.76     60000
       macro avg       0.77      0.76      0.76     60000
    weighted avg       0.77      0.76      0.76     60000
    

    Confusion Matrix

    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    
    cm = confusion_matrix(np.argmax(ytrain, axis=1), ypred)
    # label the rows/columns with the class names instead of 0/1
    cm = pd.DataFrame(cm, index=target_names, columns=target_names)
    plt.figure(figsize=(10, 10))
    
    # fmt='d' shows raw counts instead of scientific notation
    sns.heatmap(cm, annot=True, fmt='d', annot_kws={"size": 12})
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()
    

    (Output: seaborn heatmap of the 2x2 confusion matrix.)
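
    Adapting this to your pipeline

    To apply the same idea to your own code: your model ends in a single sigmoid unit and your data comes from tf.keras.preprocessing.image_dataset_from_directory, so (as a rough sketch, reusing your variable names model and val_ds) you can iterate the validation dataset once, keep each batch's labels paired with its predictions, and threshold the sigmoid output at 0.5. Iterating the dataset twice (once to collect labels, once to predict) can misalign them if the dataset reshuffles between passes, so a single pass is safer.

    import numpy as np
    from sklearn.metrics import classification_report, confusion_matrix
    
    y_true, y_pred = [], []
    for images, labels in val_ds:                 # one pass over the validation set
        probs = model.predict(images, verbose=0)  # shape (batch, 1), sigmoid outputs
        y_true.append(labels.numpy())
        y_pred.append((probs.ravel() > 0.5).astype(int))
    
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)
    
    print(confusion_matrix(y_true, y_pred))
    print(classification_report(y_true, y_pred, target_names=val_ds.class_names))

    From there you can feed the resulting matrix into the same pandas/seaborn heatmap code shown above.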