
Advanced Machine Learning in Python: Handling Class Imbalance in Multi-class Classification with Custom Loss Function


I am working on a multi-class classification problem in Python using advanced machine learning techniques. The dataset I am dealing with has a significant class imbalance issue, where some classes are underrepresented compared to others. This imbalance is adversely affecting the performance of my model, particularly for the minority classes.

To address this, I am considering the implementation of a custom loss function that can better handle class imbalance. I am using TensorFlow/Keras for model development. My current model structure is as follows:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

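# placeholder values (hypothetical): set these to match your data
input_shape = 20   # number of input features
num_classes = 5    # number of classes

# Example model architecture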
model = Sequential([
    Dense(128, activation='relu', input_shape=(input_shape,)),
    Dense(64, activation='relu'),
    Dense(num_classes, activation='softmax')
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

Here, num_classes is the number of classes in my dataset, and input_shape is the number of input features. The problem is that the 'categorical_crossentropy' loss weights every sample equally, so it does not account for the class imbalance.

I am looking for a way to create a custom loss function that can integrate the class weights into the computation, thereby giving more importance to the minority classes during training. Here are my specific questions:

  • How can I develop a custom loss function in TensorFlow/Keras that incorporates class weights for a multi-class classification problem?
  • What are the best practices to ensure that this custom loss function is computationally efficient and does not significantly increase training time?
  • Are there any potential pitfalls or common mistakes I should be aware of when implementing a custom loss function for handling class imbalance?
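For reference, a custom loss of the kind described in the first question can be sketched as follows. This is a minimal sketch, not the approach taken in the solution below: it assumes one-hot labels and softmax outputs, and the factory name make_weighted_categorical_crossentropy and the example weights are hypothetical. Each sample's categorical crossentropy is multiplied by the weight of its true class.

```python
import numpy as np
import tensorflow as tf


def make_weighted_categorical_crossentropy(class_weights):
    """Hypothetical helper: build a loss that scales each sample's
    categorical crossentropy by the weight of that sample's true class.

    class_weights: one float per class (e.g. inverse class frequencies).
    """
    weights = tf.cast(tf.convert_to_tensor(class_weights), tf.float32)

    def weighted_cce(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        # clip probabilities so log() never sees exactly 0
        eps = 1e-7
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
        # unweighted per-sample crossentropy: -sum_c y_c * log(p_c)
        cce = -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)
        # with one-hot y_true, this picks out the weight of the true class
        sample_weights = tf.reduce_sum(y_true * weights, axis=-1)
        # one value per sample; Keras reduces this to a batch mean
        return cce * sample_weights

    return weighted_cce


# toy batch: 2 samples, 3 classes (made-up values for illustration)
y_true = tf.constant([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
y_pred = tf.constant([[0.7, 0.2, 0.1], [0.3, 0.3, 0.4]])
loss_fn = make_weighted_categorical_crossentropy([1.0, 1.0, 3.0])
per_sample = loss_fn(y_true, y_pred).numpy()

# cross-check against the built-in crossentropy scaled by the same weights
reference = tf.keras.losses.categorical_crossentropy(y_true, y_pred).numpy() * np.array([1.0, 3.0])
print(per_sample, reference)
```

It would be passed to compile like any other loss, e.g. model.compile(optimizer='adam', loss=make_weighted_categorical_crossentropy(weights), metrics=['accuracy']). Because it uses only vectorized TensorFlow ops, it should add little overhead compared with the built-in loss. One pitfall: a model saved with a custom loss has to be given that loss again via custom_objects when it is loaded (or loaded with compile=False).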


Solution

  • Fortunately, Keras has built-in support for weighting samples when computing the loss (per-sample weights, or the class_weight argument of model.fit), so no custom loss function is needed.

    Since you haven't posted any code for loading your input data, I'm assuming you are using tf.data.Dataset, which is the recommended way to feed data to Keras. According to this SO post, a tf.data.Dataset can yield a third value per element, which Keras then uses as the sample weight. Below is a fully reproducible example that uses your model definition. To see the effect of the weights, swap the commented-out return statement in prep_data with the active one.

    PS: If you want to learn more about downsampling/upsampling and how to weight your data, here's some useful documentation by Google on the topic.

    import random
    
    import pandas as pd
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense
    
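    # fix seeds for reproducibility; set_random_seed already seeds Python's
    # random module, NumPy and TensorFlow, so the explicit random.seed call
    # below is redundant but harmless
    tf.keras.utils.set_random_seed(42)
    random.seed(42)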
    
    # column definitions
    LABEL_COLUMN = 'label'
    WEIGHT_COLUMN = 'weight'
    NUMERIC_COLS = ['col_1', 'col_2']
    LABELS = [0, 1, 2]
    
    # generate some example data
    col_1 = [i for i in range(1,101)]
    col_2 = [i for i in range(1,101)]
    col_3 = [random.choice([0, 1, 2]) for i in range(1, 101)]
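    col_4 = [random.choice([1, 3, 1]) for i in range(1, 101)]  # random per-sample weights (1 or 3) for illustration only; in practice derive them from class frequencies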
    
    data = {
        'col_1': col_1,
        'col_2': col_2,
        'label': col_3,
        'weight': col_4
    }
    
    df = pd.DataFrame(data)
    
    # create tf.data.Dataset
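    def prep_data(row_data):
        # split off the label and the weight; the remaining columns are features
        _label = row_data.pop(LABEL_COLUMN)
        weight = tf.cast(row_data.pop(WEIGHT_COLUMN), tf.float32)
        features = tf.cast(tf.stack([row_data[c] for c in NUMERIC_COLS]), tf.float32)
        label = tf.one_hot(_label, len(LABELS))
        # returning a third element makes Keras use it as the sample weight
        # return features, label, weight
        return features, label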
    
    
    ds = tf.data.Dataset.from_tensor_slices(dict(df))
    ds = ds.map(map_func=prep_data)
    ds = ds.batch(16)
    
    # create model
    # Example model architecture
    model = Sequential([
        Dense(128, activation='relu', input_shape=(len(NUMERIC_COLS),)),
        Dense(64, activation='relu'),
        Dense(len(LABELS), activation='softmax')
    ])
    
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    # fit model
    model.fit(
        ds, epochs=1, verbose=1
    )
    
    # using weight
    # 7/7 [==============================] - 4s 6ms/step - loss: 4.6744 - accuracy: 0.3500
    # weight commented out
    # 7/7 [==============================] - 1s 5ms/step - loss: 2.7758 - accuracy: 0.3300
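If you would rather not maintain a weight column by hand, a common choice is to derive the weights from class frequencies. The sketch below assumes the "balanced" heuristic n_samples / (n_classes * count_c) (the formula scikit-learn uses for class_weight='balanced') and synthetic, deliberately imbalanced data; it passes the weights through the class_weight argument of model.fit, and also shows the equivalent per-sample array that could fill a weight column like WEIGHT_COLUMN above.

```python
import numpy as np
import tensorflow as tf

tf.keras.utils.set_random_seed(42)

# hypothetical imbalanced toy data: 70 / 20 / 10 samples in classes 0 / 1 / 2
num_classes = 3
labels = np.array([0] * 70 + [1] * 20 + [2] * 10)
features = np.random.normal(size=(labels.size, 2)).astype("float32")

# "balanced" weights: n_samples / (n_classes * n_samples_in_class)
counts = np.bincount(labels, minlength=num_classes)
class_weight = {c: float(labels.size / (num_classes * counts[c])) for c in range(num_classes)}

# the same weights expressed per sample, e.g. to fill a weight column
sample_weight = np.array([class_weight[label] for label in labels], dtype="float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(2,)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(num_classes, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

y = tf.keras.utils.to_categorical(labels, num_classes)
# class_weight scales each sample's loss by the weight of its (argmax) class
history = model.fit(features, y, epochs=1, batch_size=16,
                    class_weight=class_weight, verbose=0)
print(class_weight)
```

With this heuristic the weighted sample count still equals the dataset size, so the overall loss scale stays comparable to the unweighted case while the minority classes contribute proportionally more.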