Tags: machine-learning, keras, lstm, federated-learning

LSTM sequence prediction overfits on one specific value only


Hello, I am new to machine learning. I am implementing federated learning with an LSTM to predict the next label in a sequence. My sequences look like this: [2,3,5,1,4,2,5,7]. The goal is to predict the last value (the 7 in this example) from the values before it. I tried a simple federated learning setup with Keras. The same approach worked for another model (not an LSTM), but here it always overfits on 2: it predicts 2 for any input. I balanced the input data so that each label appears almost equally often in the last position (where the 7 is here). I tested this data with a simple deep learning model and it works well, so it seems that either the data is not suitable for an LSTM or there is some other issue. This is my code for the federated learning. Please let me know if more information is needed. Thanks.

from tensorflow.keras import layers
from tensorflow.keras.models import Model


def get_lstm(units):
    """LSTM(Long Short-Term Memory)
    Build LSTM Model.

    # Arguments
        units: List(int), [sequence length, units of first LSTM layer,
            units of second LSTM layer, number of output classes].
    # Returns
        model: Model, nn model.
    """
    # Input shape is (timesteps, features): one integer feature per step
    inp = layers.Input((units[0], 1))
    x = layers.LSTM(units[1], return_sequences=True)(inp)
    x = layers.LSTM(units[2])(x)
    x = layers.Dropout(0.2)(x)
    out = layers.Dense(units[3], activation='softmax')(x)

    model = Model(inp, out)
    return model



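import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K

optimizer = keras.optimizers.Adam(learning_rate=0.01)

seqLen = 8 - 1  # 7 input steps; the 8th value is the label
# 14 categories; the output has 15 units because indices start at 0, but class 0 is never a valid label
global_model = Mymodel.get_lstm([seqLen, 64, 64, 15])
global_model.compile(loss="sparse_categorical_crossentropy", optimizer=optimizer, metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)])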

def main(argv):
    # comms_round, train and userDataSize are defined in code not shown here
    for comm_round in range(comms_round):
            print("round_%d" %( comm_round))
            scaled_local_weight_list = list()
            global_weights = global_model.get_weights()
            np.random.shuffle(train) 
            temp_data = train[:]
            
            # data divided among ten users and shuffled
            for user in range(10):
                user_data = temp_data[user * userDataSize: (user+1)*userDataSize]

                X_train = user_data[:, 0:seqLen]
                X_train = np.asarray(X_train).astype(np.float32)
                Y_train = user_data[:, seqLen]    
                Y_train = np.asarray(Y_train).astype(np.float32)
                local_model = Mymodel.get_lstm([seqLen, 64, 64, 15])
                X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
                                                         
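                # Fresh optimizer per local model, so Adam state is not shared between users or with global_model
                local_model.compile(loss="sparse_categorical_crossentropy", optimizer=keras.optimizers.Adam(learning_rate=0.01), metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)])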
                local_model.set_weights(global_weights)
                

                local_model.fit(X_train, Y_train)
                scaling_factor = 1 / 10 # 10 is number of users
                scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
                scaled_local_weight_list.append(scaled_weights)
                K.clear_session()

            average_weights = sum_scaled_weights(scaled_local_weight_list)
            global_model.set_weights(average_weights)


predictions=global_model.predict(X_test)
for i in range(len(X_test)):
    print('%d,%d' % ((np.argmax(predictions[i])), Y_test[i]),file=f2 )
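
The loop above calls scale_model_weights and sum_scaled_weights, but the question does not show them. The following is only a guess at what such helpers might look like for plain federated averaging, assuming each user's weights are a list of per-layer arrays as returned by get_weights(). The bodies are an assumption, not the author's code.

```python
def scale_model_weights(weights, scalar):
    """Multiply every layer's weights by scalar (e.g. 1 / number_of_users)."""
    return [scalar * w for w in weights]


def sum_scaled_weights(scaled_weight_list):
    """Add the scaled weights of all users, layer by layer."""
    # zip(*...) groups the first layer of every user, then the second, and so on.
    # The built-in sum works for NumPy arrays as well as for plain numbers.
    return [sum(layer_group) for layer_group in zip(*scaled_weight_list)]


# Tiny check with plain floats standing in for the per-layer weight arrays.
user_a = [1.0, 10.0]
user_b = [3.0, 30.0]
scaled = [scale_model_weights(w, 0.5) for w in (user_a, user_b)]
average = sum_scaled_weights(scaled)
print(average)
```

Scaling every user by the same factor (1 / 10 in the loop above) gives a plain mean, which matches federated averaging only when every user holds the same number of samples, as is the case with the fixed userDataSize slices.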

Solution

  • I found some reasons for my problem, so I thought I would share them:
    1- The proportions of the different items in the sequences were not balanced. For example, I had 1000 occurrences of "2" and 100 of each other number, so after a few rounds the model fitted on 2 because there was much more data for that value.
    2- I changed my sequences so that no two items in a sequence have the same value. This let me remove some repetitive data from the sequences and make them more balanced. Maybe this is not a complete representation of the activities, but in my case it makes sense.
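
Both points can be checked with a few lines of standard-library Python. This is a minimal sketch on made-up data, not the original dataset. The answer does not say whether only adjacent repeats or all repeats were removed, so both variants are shown; all function names here are hypothetical.

```python
from collections import Counter
from itertools import groupby


def value_counts(sequences):
    """Count how often each value occurs across all sequences."""
    return Counter(v for seq in sequences for v in seq)


def collapse_adjacent_repeats(seq):
    """Keep one item from each run of equal neighbours: [2, 2, 3, 2] -> [2, 3, 2]."""
    return [value for value, _run in groupby(seq)]


def drop_all_repeats(seq):
    """Keep only the first occurrence of each value: [2, 2, 3, 2] -> [2, 3]."""
    return list(dict.fromkeys(seq))


# Made-up sequences dominated by the value 2, only to show the effect.
sequences = [[2, 2, 2, 3, 5, 2, 2, 7], [2, 2, 4, 4, 2, 1, 2, 2]]
before = value_counts(sequences)
after = value_counts(collapse_adjacent_repeats(s) for s in sequences)
print(before.most_common(3))
print(after.most_common(3))
```

Counting the values before training shows whether one of them dominates. Removing repeats shortens the sequences, so they would need to be cut or padded back to a fixed length before being fed to the LSTM.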