Search code examples
tensorflow deep-learning object-detection bounding-box

How To Visualize A Trained Model With Bounding Boxes For Object Detection


I am trying to plot flower images with both the label and prediction that have a bounding box for each. I am using some lower layers of a pre-trained Xception model.

I have set the localization output layer to have 4 units, since a bounding box is described by four coordinates:

loc_output = keras.layers.Dense(4)(avg)

For simplicity, I just set the four coordinates for the label as random numbers using tf.random.uniform.

How do I write a function using matplotlib that generates something like this:

import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

# Load with metadata; `info` is used below to look up the label names.
dataset, info = tfds.load("tf_flowers", with_info=True, as_supervised=True)

# Carve three subsets out of the "train" split: the first 10% for testing,
# the next 15% for validation and the remaining 75% for training.
test_set_raw, valid_set_raw, train_set_raw = tfds.load(
    "tf_flowers",
    as_supervised=True,
    split=[
        "train[:10%]",
        "train[10%:25%]",
        "train[25%:]",
    ],
)

# Human-readable class names and the number of classes, taken from the
# dataset's "label" feature.
n_classes = info.features["label"].num_classes
class_names = info.features["label"].names


## Shuffle & Preprocess
def preprocess(image, label):
    """Resize an image to 224x224 and apply Xception input preprocessing.

    The label is passed through unchanged, so the function can be used
    directly with ``Dataset.map`` on ``(image, label)`` pairs.
    """
    resized = tf.image.resize(image, [224, 224])
    return keras.applications.xception.preprocess_input(resized), label

batch_size = 32

# The training set is shuffled and repeated without a count, so model.fit()
# has to be given steps_per_epoch to know where an epoch ends.
train_set = train_set_raw.shuffle(1000).repeat()
train_set = train_set.map(preprocess).batch(batch_size).prefetch(1)
valid_set = valid_set_raw.map(preprocess).batch(batch_size).prefetch(1)
test_set = test_set_raw.map(preprocess).batch(batch_size).prefetch(1)

# Reuse the lower layers of the ImageNet-pretrained Xception model (no
# classification head) and attach two heads on top of the pooled features.
base_model = keras.applications.xception.Xception(weights="imagenet",
                                                  include_top=False)
avg = keras.layers.GlobalAveragePooling2D()(base_model.output)
class_output = keras.layers.Dense(n_classes, activation="softmax")(avg)
loc_output = keras.layers.Dense(4)(avg)  # 4 coordinates for our bounding box
model = keras.models.Model(inputs=base_model.input,
                           outputs=[class_output, loc_output])

# Optional: uncomment to freeze the pretrained base and train only the heads.
# for layer in base_model.layers:
#     layer.trainable = False

# The legacy `lr=` alias and `decay=` argument are deprecated/rejected by
# newer Keras optimizers. InverseTimeDecay with decay_steps=1 reproduces the
# legacy behaviour: lr_t = 0.2 / (1 + 0.01 * step).
lr_schedule = keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.2, decay_steps=1, decay_rate=0.01)
optimizer = keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)

# One loss per output, in the same order as `outputs` above:
# classification first, bounding-box regression second.
model.compile(loss=["sparse_categorical_crossentropy", "mse"],
              loss_weights=[0.8, 0.2],
              optimizer=optimizer, metrics=["accuracy"])

def add_random_bounding_boxes(images, labels):
    """Pair every image in a batch with a random 4-value bounding box.

    Returns ``(images, (labels, boxes))`` where ``boxes`` is drawn with
    ``tf.random.uniform`` and has one row of 4 values per image, matching
    the model's two outputs (class, box).
    """
    n_images = tf.shape(images)[0]
    random_boxes = tf.random.uniform([n_images, 4])
    targets = (labels, random_boxes)
    return images, targets

# Five batches, repeated twice: 10 batches in total, which is exactly what
# 2 epochs of 5 steps consume.
fake_train_set = (
    train_set
    .take(5)
    .repeat(2)
    .map(add_random_bounding_boxes)
)
model.fit(fake_train_set, epochs=2, steps_per_epoch=5)

Solution

  • Here is one way to achieve what you want. Note, however, that the dummy bounding boxes produced by tf.random.uniform are not very meaningful: by default minval=0 and maxval=1, so the dummy coordinates all fall within that range, which is not a sensible range for a bounding box on a 224x224 image. That is why the following demonstration rescales the coordinates by a scalar value (here, 150); hopefully you get the point.


    After training, prepare the test set for inference.

    import numpy as np
    import matplotlib.pyplot as plt
    
    print(class_names)
    # Rebuild the test pipeline with a batch size of 1 and attach the same
    # kind of dummy boxes that were used as targets during training.
    test_set = (
        test_set_raw
        .map(preprocess)
        .batch(1)
        .prefetch(1)
        .map(add_random_bounding_boxes)
    )
    ['dandelion', 'daisy', 'tulips', 'sunflowers', 'roses']
    

    Display the ground-truth and predicted boxes using matplotlib.

    def _draw_labeled_box(ax, box, text, color):
        """Draw one (xmin, ymin, xmax, ymax) box on `ax` with a caption at (xmin, ymin)."""
        x1, y1, x2, y2 = box
        # width (w) = xmax - xmin ; height (h) = ymax - ymin
        ax.add_patch(
            plt.Rectangle(
                [x1, y1], x2 - x1, y2 - y1, fill=False, edgecolor=color, linewidth=1
            )
        )
        ax.text(
            x1,
            y1,
            text,
            bbox={"facecolor": [1, 1, 1], "alpha": 0.5},
            clip_box=ax.clipbox,
            clip_on=True,
        )

    for X, y in test_set.take(1):
        # true labels
        true_label = y[0].numpy()
        true_bboxs = y[1].numpy()

        # model predicts
        pred_probs, pred_boxes = model.predict(X)
        pred_label = np.argmax(pred_probs, axis=-1)

        # rescaling: stretch the values by 150 and clip them to the 224x224 image
        dummy_true_boxes = (true_bboxs*150).astype(np.int32).clip(min=0, max=224)
        dummy_predict_boxes = (pred_boxes*150).astype(np.int32).clip(min=0, max=224)

        # Info printing (the batch holds a single image, hence index 0;
        # indexing avoids calling int() on a whole array)
        print('GT bbox scores: ', true_bboxs)
        print('PRED bbox scores: ', pred_boxes)
        print('After Rescaling and Clipped True BBOX: ', dummy_true_boxes)
        print('After Rescaling and Clipped Pred BBOX: ', dummy_predict_boxes)
        print('True label : {}, Predicted label {}'.format(class_names[int(true_label[0])],
                                                           class_names[int(pred_label[0])]))

        plt.figure(figsize=(10, 10))
        plt.axis("off")
        # Xception's preprocess_input scales pixels to [-1, 1], while imshow
        # clips float RGB data to [0, 1]; map the image back to [0, 1] so it
        # is displayed with its real colours.
        plt.imshow((X[0].numpy() + 1.0) / 2.0)
        ax = plt.gca()

        for tbox, tcls, pbox, pcls in zip(dummy_true_boxes, true_label, dummy_predict_boxes, pred_label):
            # ground truth in green, prediction in white
            _draw_labeled_box(ax, tbox, "GT: {}".format(class_names[tcls]), [0, 1, 0])
            _draw_labeled_box(ax, pbox, "Pred: {}".format(class_names[pcls]), [1, 1, 1])
        plt.show()
    
    GT bbox scores:  [[0.75246954 0.36959255 0.18266702 0.7125735 ]]
    PRED bbox scores:  [[1.1755341  0.98745024 0.90438926 1.285707  ]]
    After Rescaling and Clipped True BBOX:  [[112  55  27 106]]
    After Rescaling and Clipped Pred BBOX:  [[176 148 135 192]]
    True label : tulips, Predicted label sunflowers
    

    enter image description here