How to override fit() and predict() in a Keras model


I've created a subclass of the keras.models.Sequential class in order to override its fit() and predict() methods.

My goal is to 'hide' a sklearn LabelEncoder inside the model. This way I can call fit() and predict() directly with a y array made up of arbitrary labels, without requiring them to be integers in the range [0, 1, ..., num_classes - 1].

Implementation example:

import numpy as np

from keras.models import Sequential
from keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder

class SuperSequential(Sequential):

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.encoder = LabelEncoder()

  def fit(self, X: np.ndarray, y: np.ndarray, **kwargs) -> Sequential:
    # encode arbitrary labels to [0, num_classes - 1], then one-hot encode
    y_enc = self.encoder.fit_transform(y)
    y_enc = to_categorical(y_enc, len(np.unique(y_enc)))

    return super().fit(X, y_enc, **kwargs)

  def predict(self, X: np.ndarray) -> np.ndarray:
    y_pred = super().predict(X)
    y_pred = np.argmax(y_pred, axis=1)

    # decode the predicted indices back to the original labels
    return self.encoder.inverse_transform(y_pred)

Unfortunately, this isn't very convenient for my use case. I'd like to save a trained model using keras.models.save_model() and then load everything via keras.models.load_model(). However, the loaded model is always of the base Sequential class, which does not include the overridden fit() and predict().
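
A minimal reproduction of the problem (the architecture and data below are placeholders I've made up for illustration):

import numpy as np
import keras

model = SuperSequential()
model.add(keras.layers.Dense(8, activation="relu", input_shape=(4,)))
model.add(keras.layers.Dense(3, activation="softmax"))
model.compile(optimizer="adam", loss="categorical_crossentropy")

# arbitrary, non-contiguous labels, handled by the hidden LabelEncoder
X = np.random.rand(12, 4).astype("float32")
y = np.array([10, 20, 30] * 4)
model.fit(X, y)

keras.models.save_model(model, "model_path")
loaded = keras.models.load_model("model_path")
print(type(loaded))  # plain Sequential, not SuperSequential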

UPDATE: If I load the model passing the appropriate custom_objects field (as shown below), the loaded object does have the expected type (SuperSequential), but the LabelEncoder isn't 'fitted'.

keras.models.load_model("model_path", custom_objects={"SuperSequential": SuperSequential})
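
Concretely, sticking with the toy setup above: __init__() runs again on load, so the restored model holds a brand-new, unfitted LabelEncoder, and decoding fails:

loaded = keras.models.load_model(
    "model_path",
    custom_objects={"SuperSequential": SuperSequential},
)
# self.encoder was never fitted on this instance, so the
# inverse_transform() call inside the overridden predict()
# raises sklearn.exceptions.NotFittedError
loaded.predict(X)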

I've also found that Keras allows the use of pre-processing layers such as keras.layers.IntegerLookup, which seems to do what I want, but it isn't clear to me how to use it as part of a Sequential model for label encoding.
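
In isolation, IntegerLookup does seem to handle both directions (a quick sketch of what I tried, with made-up labels):

import numpy as np
from keras import layers

labels = np.array([10, 20, 30, 20, 10])

encode = layers.IntegerLookup(output_mode="one_hot", num_oov_indices=0)
encode.adapt(labels)
print(encode(labels))               # one row of one-hot values per label

decode = layers.IntegerLookup(num_oov_indices=0, invert=True)
decode.adapt(labels)
print(decode(np.array([0, 1, 2])))  # indices back to the original labels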

My questions are:

  1. How can I save and load a subclass of Sequential, if at all possible?
  2. How could I achieve the same goal as with the subclass, but with a Keras pre-processing layer such as keras.layers.IntegerLookup?
  3. The lack of supporting references for what I'm trying to do leads me to think that this doesn't make sense. Is there a better method to achieve my goals?

Solution

  • I'll answer my own question, as it may help someone in the future.

    The 'overriding' approach was not the correct one.

    The label encoding and decoding steps are pre- and post-processing steps. As such, they should not be 'shoehorned' into the fit() and predict() methods, but rather added as additional layers in the Sequential model.

    This keeps concerns separated and doesn't hide the pre- and post-processing steps, as they'll be visible when one inspects a loaded model via tf.keras.Model.summary(), for example.

    I ended up following a two-step approach:

    1. Training: I create a label encoder object that takes care of encoding the original labels into a 'one-hot-encoded' 2D array. I've used a keras.layers.IntegerLookup object to accomplish this. I then pass the original labels to this label encoder, and simply fit() the model with the encoded labels.
    2. Inference: After training the model, I create an 'inference' version of the model (perhaps a better term would be 'pipeline' instead of 'model'), by adding two post-processing layers to it: (a) a custom argmax layer that extracts the encoded label with the highest probability; and (b) a label decoding layer (also based on a keras.layers.IntegerLookup object) that essentially does the opposite of the pre-processing object from step 1.

    After step 2, I can save the 'inference' version of the model using keras.models.save_model(), which includes the post-processing layers. After loading it back with keras.models.load_model(), all I have to do is call predict(), which directly gives me an array with the predicted class labels in their original format.

    To implement the argmax step, I had to write a custom Keras layer, as shown in the concrete example below.

    For reference, here's a concrete example:

    import numpy as np
    import tensorflow as tf
    
    from keras.models import Sequential
    from keras.datasets import mnist
    from keras import layers
    
    
    class ArgMax(tf.keras.layers.Layer):
        """
        Custom Keras layer that returns, for each row of predicted
        probabilities, the index of the label with the highest probability.
        """
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def call(self, inputs):
            return tf.math.argmax(inputs, axis=1)
    
    
    def load_dataset(discard: list = []):
        """
        Loads the mnist dataset, filters out unwanted labels and re-shapes arrays.
        """
        (X_tr, y_tr), (X_val, y_val) = mnist.load_data()

        X_tr = X_tr[~np.isin(y_tr, discard), :]
        y_tr = y_tr[~np.isin(y_tr, discard)]

        X_val = X_val[~np.isin(y_val, discard), :]
        y_val = y_val[~np.isin(y_val, discard)]
    
        NUM_ROWS = X_tr.shape[1]
        NUM_COLS = X_tr.shape[2]
    
        X_tr = X_tr.reshape((X_tr.shape[0], NUM_ROWS * NUM_COLS))
        X_val = X_val.reshape((X_val.shape[0], NUM_ROWS * NUM_COLS))
    
        X_tr = X_tr.astype('float32') / 255
        X_val = X_val.astype('float32') / 255
    
        return (X_tr, y_tr), (X_val, y_val)
    
    
    if __name__ == "__main__":
        # load dataset : discard some of the labels 
        # to test correct operation of pre- and post-processing layers
        (X_tr, y_tr), (X_val, y_val) = load_dataset(discard=[1, 3, 5])
    
        # label pre-processing
        label_preprocessing = layers.IntegerLookup(
            output_mode="one_hot", 
            num_oov_indices=0
        )
        label_preprocessing.adapt(y_tr)
        print(f"vocabulary : {label_preprocessing.get_vocabulary()}")
        print(f"vocabulary size : {len(label_preprocessing.get_vocabulary())}")
    
        # label post-processing 
        label_postprocessing = layers.IntegerLookup(
            num_oov_indices=0,
            invert=True
        )
        label_postprocessing.adapt(y_tr)
        print(f"vocabulary : {label_postprocessing.get_vocabulary()}")
        print(f"vocabulary size : {len(label_postprocessing.get_vocabulary())}")
    
        # create model using Sequential API
        model = Sequential()
        model.add(layers.Dense(512, activation='relu', input_shape=(X_tr.shape[1],)))
        model.add(layers.Dropout(0.5))
        model.add(layers.Dense(256, activation='relu'))
        model.add(layers.Dropout(0.25))
        model.add(layers.Dense(len(np.unique(y_tr)), activation='softmax'))
    
        model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
    
        # fit the model using the pre-processed labels
        model.fit(X_tr, label_preprocessing(y_tr),
                  batch_size=128,
                  epochs=10,
                  verbose=1,
                  validation_data=(X_val, label_preprocessing(y_val)))
    
        # create model for inference, i.e., with 2 post-processing layers:
        #   - add a layer that does argmax() operation
        #   - add a layer to invert the integer labels
        model.add(ArgMax())
        model.add(label_postprocessing)
    
        # save the model
        model.save('inference_model')
    
        # load the model
        loaded_model = tf.keras.models.load_model('inference_model')
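
        # the pre- and post-processing steps are not hidden: the ArgMax and
        # IntegerLookup layers show up when inspecting the loaded model
        loaded_model.summary()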
    
        # compare the first 20 predictions of the loaded model to the ground truth
        print(loaded_model.predict(X_val[:20]))
        print(y_val[:20])