Tags: python, keras, deep-learning, multilabel-classification, imbalanced-data

How to handle strongly imbalanced data for multi-label image classification with transfer learning


I tried to solve this myself but couldn't get to a working result, which is why I'm posting here; please guide me. I'm working on multi-label image classification in a slightly unusual scenario: I have a large and significantly imbalanced dataset. You can see the dataset here.

  • The dataset is significantly imbalanced and I don't know how to process it efficiently (classes 98 and 99 aren't used in my scenario): [count plot]

    [value counts]

And here is the model code I'm working with:

import os
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.applications import EfficientNetB0

def load_network(labels):
    # EfficientNetB0 backbone pretrained on ImageNet, without its classifier head
    cnn = EfficientNetB0(include_top=False, input_shape=(224, 224, 3),
                         weights="imagenet")
    # Rebuild top
    x = layers.GlobalAveragePooling2D(name="avg_pool")(cnn.output)
    x = layers.BatchNormalization()(x)
    fcn = layers.Dense(2048, activation='relu')(x)
    # fcn_1 = layers.Dense(1024, activation='relu')(fcn)
    # Sigmoid (not softmax) so each label is predicted independently
    fcn_classification = layers.Dense(len(labels), activation='sigmoid')(fcn)
    model = tf.keras.Model(inputs=cnn.inputs, outputs=fcn_classification)

    # Freeze only the first 20 layers; the rest of the backbone and the new head stay trainable
    for layer in model.layers[:20]:
        layer.trainable = False
    for layer in model.layers[20:]:
        layer.trainable = True
    # model.summary()
    return model

train_data, valid_data, test_data = split(new_data)
train_generator = DataGenerator(train_data, labels, os.path.join(PATH, 'images'), num_classes=len(labels),
                                batch_size=BATCH_SIZE)
valid_generator = DataGenerator(valid_data, labels, os.path.join(PATH, 'images'), num_classes=len(labels),
                                batch_size=BATCH_SIZE)
test_generator = DataGenerator(test_data, labels, os.path.join(PATH, 'images'), num_classes=len(labels),
                               batch_size=BATCH_SIZE)

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy', min_delta=0, patience=5, verbose=1,
    mode='auto', baseline=None, restore_best_weights=True
)
model = load_network(labels)
model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
              optimizer=tf.keras.optimizers.Adam(learning_rate=0.00002),
              metrics=['accuracy'])
history = model.fit(train_generator, epochs=EPOCHS, shuffle=True, validation_data=valid_generator,
                    callbacks=[early_stopping], verbose=1) 
test_results = model.evaluate(test_generator)

After training, my model ends up with

loss: 0.1171 - accuracy: 0.4098 - val_loss: 0.2396 - val_accuracy: 0.2229

and the predictions were quite bad. My goal is to take one image as input and return the Vienna class(es) for that image (if you don't know the Vienna classification, you can read about it here). How can I handle this imbalanced data and improve my model? Could I unfreeze more layers for transfer learning?


Solution

  • Balancing datasets for training is always tricky, and only a few approaches can help.

    1. Get more data to balance out the dataset. This is obvious, but often not feasible.

    2. Reduce the size of the larger classes (undersampling); see the first sketch after this list. This sometimes works, but it can leave you with a very small set of images.

    3. Add class weights to try and balance the dataset, i.e., make each class's weight inversely proportional to its number of images, so that images from the rarer classes impact the model more than images from the more common ones (see the second sketch after this list).

    4. Remove the classes with too few images, or group them as "other" (see the third sketch after this list). It's a workaround that doesn't solve the problem, but it is often used in business settings.

    5. Live with it. If the data proportions accurately reflect real life (e.g., identifying humans and cars, where there are many more humans than cars) and your test scenario has the same proportions, leaving the imbalance in might not even be that disadvantageous. Your model will identify humans much more accurately than cars, but that's fine, since the bias works in your favor.

    Personally, I try options 2 and 3 whenever possible to see if they increase my model's accuracy. Otherwise, option 4 or 5 is sometimes the best fit for business cases.
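    A minimal sketch of option 2 (undersampling), assuming a pandas DataFrame df with one row per image and a single label column. That is the single-label case; with multi-label data an image can carry both rare and common labels, so a per-class cap is less clean. The cap of 500 is purely illustrative.

import pandas as pd

def undersample(df, cap=500, seed=42):
    # Keep at most `cap` images per class, sampled without replacement,
    # so the largest classes no longer dominate training.
    return (df.groupby('label', group_keys=False)
              .apply(lambda g: g.sample(min(len(g), cap), random_state=seed)))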
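    For option 3, note that Keras's built-in class_weight argument to model.fit is designed for single-label targets, so a common alternative for multi-label work is a per-class weighted binary cross-entropy. A minimal sketch, assuming y_train is a (num_samples, num_classes) binary indicator matrix; make_class_weights and weighted_bce are hypothetical helper names:

import numpy as np
import tensorflow as tf

def make_class_weights(y_train):
    # Weights inversely proportional to the number of positives per class:
    # rare classes get large weights, common classes get small ones.
    pos_counts = np.maximum(y_train.sum(axis=0), 1)  # avoid division by zero
    weights = y_train.shape[0] / (y_train.shape[1] * pos_counts)
    return tf.constant(weights, dtype=tf.float32)

def weighted_bce(class_weights):
    # Element-wise binary cross-entropy, scaled per class before averaging,
    # so errors on rare labels move the model more than errors on common ones.
    def loss(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
        bce = -(y_true * tf.math.log(y_pred)
                + (1.0 - y_true) * tf.math.log(1.0 - y_pred))
        return tf.reduce_mean(bce * class_weights, axis=-1)
    return loss

# Drop-in replacement for the plain BinaryCrossentropy loss in the question:
# model.compile(loss=weighted_bce(make_class_weights(y_train)), ...)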
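    For option 4, a minimal sketch of folding rare classes into a single "other" column, assuming y is a (num_samples, num_classes) binary indicator matrix and labels is the matching list of class names; the min_count threshold is arbitrary:

import numpy as np

def group_rare_labels(y, labels, min_count=50):
    counts = y.sum(axis=0)
    keep = counts >= min_count                       # classes kept as-is
    other = y[:, ~keep].any(axis=1).astype(y.dtype)  # any rare label -> "other"
    y_new = np.column_stack([y[:, keep], other])
    labels_new = [name for name, k in zip(labels, keep) if k] + ['other']
    return y_new, labels_new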