Tags: python · tensorflow · keras · deep-learning · keras-layer

Removing auxiliary branch in Keras model at inference time


I am trying my hand at implementing a forecasting algorithm from this paper that combines LSTM and CNN models. Essentially, the paper proposes a model with three branches: a CNN branch, an LSTM branch, and a merged branch that combines both. The first two branches are present only during training, to prevent overfitting and to ensure the final model learns both CNN and LSTM features. Here is the diagram from the paper (alpha, beta, and gamma in the total loss function are just the weights for those particular losses):

[diagram: lstm-cnn model structure]

As I understand it, these are similar to the auxiliary branches in the likes of ResNet and Inception, which ensure that every layer contributes to the model output. I implemented this accordingly:

from tensorflow.keras.layers import Flatten, Dense, Dropout, concatenate
from tensorflow.keras.models import Model

def construct_lstm_cnn(look_forward, look_back=30):
    cnn = construct_cnn(look_forward, fc=False)
    cnn_flatten = Flatten()(cnn.output)
    lstm = construct_lstm(look_forward, look_back, 2, fc=False)

    #Merged layer (the main branch that will be making prediction after training)
    cnn_lstm = concatenate([cnn_flatten, lstm.output])
    fc_merged    = Dense(500, activation='relu')(cnn_lstm)
    drop_merged  = Dropout(0.5)(fc_merged)
    fc2_merged   = Dense(100, activation='relu')(drop_merged)
    drop2_merged = Dropout(0.5)(fc2_merged)
    fc3_merged   = Dense(25, activation='relu')(drop2_merged)
    drop3_merged = Dropout(0.5)(fc3_merged)
    pred_merged  = Dense(look_forward, activation='linear')(drop3_merged)

    #Auxiliary branch for cnn (want to remove at inference time)
    fc_cnn    = Dense(500, activation='relu')(cnn_flatten)
    drop_cnn  = Dropout(0.5)(fc_cnn)
    fc2_cnn   = Dense(100, activation='relu')(drop_cnn)
    drop2_cnn = Dropout(0.5)(fc2_cnn)
    fc3_cnn   = Dense(25, activation='relu')(drop2_cnn)
    drop3_cnn = Dropout(0.5)(fc3_cnn)
    pred_cnn_aux  = Dense(look_forward, activation='linear')(drop3_cnn)

    #Auxiliary branch for lstm (want to remove at inference time)
    fc_lstm    = Dense(500, activation='relu')(lstm.output)
    drop_lstm  = Dropout(0.5)(fc_lstm)
    fc2_lstm   = Dense(100, activation='relu')(drop_lstm)
    drop2_lstm = Dropout(0.5)(fc2_lstm)
    fc3_lstm   = Dense(25, activation='relu')(drop2_lstm)
    drop3_lstm = Dropout(0.5)(fc3_lstm)
    pred_lstm_aux  = Dense(look_forward, activation='linear')(drop3_lstm)

    #Final model with three branches
    model = Model(inputs=[cnn.input, lstm.input],
                  outputs=[pred_merged, pred_cnn_aux, pred_lstm_aux],
                  name="lstm-cnn")
    return model

However, I can't seem to find a way in Keras to remove the auxiliary branches listed above. Is there a way to remove the layers that are not needed at inference time?


Solution

  • Here is a simplified example.

    First, the full model with all the branches; this is the model to fit:

    from tensorflow.keras.layers import (Input, LSTM, Conv2D, Flatten,
                                         Concatenate, Dense, Dropout)
    from tensorflow.keras.models import Model

    def construct_lstm_cnn():
    
        inp_lstm = Input((20,30))
        lstm = LSTM(32, activation='relu')(inp_lstm)
        inp_cnn = Input((32,32,3))
        cnn = Conv2D(32, 3, activation='relu')(inp_cnn)
        cnn = Flatten()(cnn)
    
        cnn_lstm = Concatenate()([cnn, lstm])
        cnn_lstm = Dense(1)(cnn_lstm)
    
        fc_cnn = Dense(32, activation='relu')(cnn)
        fc_cnn = Dropout(0.5)(fc_cnn)
        fc_cnn = Dense(1)(fc_cnn)
    
        fc_lstm = Dense(32, activation='relu')(lstm)
        fc_lstm = Dropout(0.5)(fc_lstm)
        fc_lstm = Dense(1)(fc_lstm)
    
        model = Model(inputs=[inp_cnn, inp_lstm], outputs=[cnn_lstm, fc_cnn, fc_lstm])
        return model
    
    lstm_cnn = construct_lstm_cnn()
    lstm_cnn.compile(...)
    lstm_cnn.summary()
    
    lstm_cnn.fit(...)
    
    __________________________________________________________________________________________________
    Layer (type)                    Output Shape         Param #     Connected to                     
    ==================================================================================================
    input_10 (InputLayer)           [(None, 32, 32, 3)]  0                                            
    __________________________________________________________________________________________________
    conv2d_18 (Conv2D)              (None, 30, 30, 32)   896         input_10[0][0]                   
    __________________________________________________________________________________________________
    input_9 (InputLayer)            [(None, 20, 30)]     0                                            
    __________________________________________________________________________________________________
    flatten_3 (Flatten)             (None, 28800)        0           conv2d_18[0][0]                  
    __________________________________________________________________________________________________
    lstm_5 (LSTM)                   (None, 32)           8064        input_9[0][0]                    
    __________________________________________________________________________________________________
    dense_13 (Dense)                (None, 32)           921632      flatten_3[0][0]                  
    __________________________________________________________________________________________________
    dense_15 (Dense)                (None, 32)           1056        lstm_5[0][0]                     
    __________________________________________________________________________________________________
    concatenate_1 (Concatenate)     (None, 28832)        0           flatten_3[0][0]                  
                                                                     lstm_5[0][0]                     
    __________________________________________________________________________________________________
    dropout_3 (Dropout)             (None, 32)           0           dense_13[0][0]                   
    __________________________________________________________________________________________________
    dropout_4 (Dropout)             (None, 32)           0           dense_15[0][0]                   
    __________________________________________________________________________________________________
    dense_12 (Dense)                (None, 1)            28833       concatenate_1[0][0]              
    __________________________________________________________________________________________________
    dense_14 (Dense)                (None, 1)            33          dropout_3[0][0]                  
    __________________________________________________________________________________________________
    dense_16 (Dense)                (None, 1)            33          dropout_4[0][0]                  
    ==================================================================================================
    

    At inference time, after training, we can simply remove the unused branches in this way:

    lstm_cnn_inference = Model(lstm_cnn.input, lstm_cnn.output[0])
    lstm_cnn_inference.summary()
    
    __________________________________________________________________________________________________
    Layer (type)                    Output Shape         Param #     Connected to                     
    ==================================================================================================
    input_10 (InputLayer)           [(None, 32, 32, 3)]  0                                            
    __________________________________________________________________________________________________
    conv2d_18 (Conv2D)              (None, 30, 30, 32)   896         input_10[0][0]                   
    __________________________________________________________________________________________________
    input_9 (InputLayer)            [(None, 20, 30)]     0                                            
    __________________________________________________________________________________________________
    flatten_3 (Flatten)             (None, 28800)        0           conv2d_18[0][0]                  
    __________________________________________________________________________________________________
    lstm_5 (LSTM)                   (None, 32)           8064        input_9[0][0]                    
    __________________________________________________________________________________________________
    concatenate_1 (Concatenate)     (None, 28832)        0           flatten_3[0][0]                  
                                                                     lstm_5[0][0]                     
    __________________________________________________________________________________________________
    dense_12 (Dense)                (None, 1)            28833       concatenate_1[0][0]              
    ==================================================================================================
    

    In this way we keep only the central (merged) branch. Note that the new `Model` reuses the same layer objects, so it shares the trained weights with the full model; nothing is retrained or copied.
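
    To tie the pieces together, here is a minimal self-contained sketch. The shapes, layer sizes, optimizer, loss, and the concrete alpha/beta/gamma values are all illustrative assumptions, not the paper's; the point is that the paper's weighted total loss maps onto `loss_weights` in `compile()` (one weight per output, in the order of the `outputs` list), and that the single-output inference model shares the trained weights:

    ```python
    # Toy stand-in for the three-branch model: merged head plus two auxiliary
    # heads. All sizes and weights below are illustrative assumptions.
    import numpy as np
    from tensorflow.keras.layers import Input, Dense, Concatenate
    from tensorflow.keras.models import Model

    inp_a = Input((4,))
    inp_b = Input((6,))
    feat_a = Dense(8, activation='relu')(inp_a)
    feat_b = Dense(8, activation='relu')(inp_b)
    merged = Dense(1, name='merged')(Concatenate()([feat_a, feat_b]))
    aux_a  = Dense(1, name='aux_a')(feat_a)
    aux_b  = Dense(1, name='aux_b')(feat_b)

    train_model = Model([inp_a, inp_b], [merged, aux_a, aux_b])
    # The paper's alpha, beta, gamma become per-output loss_weights here
    # (1.0 / 0.3 / 0.3 are placeholder values).
    train_model.compile(optimizer='adam', loss='mse',
                        loss_weights=[1.0, 0.3, 0.3])

    x = [np.random.rand(2, 4), np.random.rand(2, 6)]
    y = np.random.rand(2, 1)
    train_model.fit(x, [y, y, y], epochs=1, verbose=0)

    # Drop the auxiliary heads: same layers, same trained weights.
    inference_model = Model(train_model.inputs, train_model.outputs[0])
    pred = inference_model.predict(x, verbose=0)
    ```

    Because `inference_model` is built from tensors of the already-trained graph, its prediction is identical to the first output of the full model.
    
    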