Tags: keras, deep-learning, neural-network, lstm, recurrent-neural-network

keras, LSTM - predict on inputs of different length?


I have fitted an LSTM that handles inputs of different lengths:

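from keras.models import Sequential
from keras.layers import LSTM, Dense

# input_shape=(None, 5): any number of timesteps, 5 features per timestep
model = Sequential()
model.add(LSTM(units=10, return_sequences=False, input_shape=(None, 5)))
model.add(Dense(units=1, activation='sigmoid'))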

Having fitted the model, I want to test it on inputs of different sizes.

x_test.shape # = (100000,)
x_test[0].shape # = (1, 5)
x_test[1].shape # = (3, 5)
x_test[2].shape # = (8, 5)
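To reproduce a test set shaped like this, here is a minimal sketch, assuming (hypothetically) random data, 1000 samples instead of 100000, and row counts between 1 and 10. Each sample is a 2D array of shape (timesteps, 5), and the collection is a 1D NumPy object array, which is why x_test.shape is a one-element tuple rather than a 3D shape.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 1000, 5  # hypothetical sizes; the question has 100000 samples

# Each sample gets a random number of timesteps (rows) and a fixed number of features.
samples = [rng.normal(size=(rng.integers(1, 11), n_features)) for _ in range(n_samples)]

# Store the ragged samples in a 1D object array: element i is itself a (t_i, 5) array.
x_test = np.empty(n_samples, dtype=object)
for i, s in enumerate(samples):
    x_test[i] = s

print(x_test.shape)        # (1000,)
print(x_test[0].shape[1])  # 5
```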

Predicting on a single instance j is not a problem (model.predict(x_test[j])), but looping over all of them is really slow.
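A Keras LSTM expects a 3D batch of shape (batch, timesteps, features), so predicting one 2D sample means adding a leading batch axis first. The sketch below mimics the per-sample loop; StandInModel is a hypothetical stand-in for the fitted Keras model (it only averages each sequence and counts its predict calls), used so the example runs without Keras. It shows that the loop issues one predict call per sample, and each call carries fixed overhead.

```python
import numpy as np


class StandInModel:
    """Hypothetical stand-in for the fitted Keras model (not the Keras API).

    predict() takes a 3D batch (batch, timesteps, features) and returns one
    value per sample, shape (batch, 1), like the LSTM + Dense(1) model.
    """

    def __init__(self):
        self.calls = 0  # number of predict() calls, i.e. batches processed

    def predict(self, batch):
        assert batch.ndim == 3, "expects (batch, timesteps, features)"
        self.calls += 1
        return batch.mean(axis=(1, 2)).reshape(-1, 1)


rng = np.random.default_rng(0)
x_test = [rng.normal(size=(rng.integers(1, 11), 5)) for _ in range(200)]

model = StandInModel()
# Per-sample loop: add a leading batch axis so each sample becomes (1, timesteps, 5).
y_loop = np.vstack([model.predict(x[np.newaxis]) for x in x_test])

print(y_loop.shape)  # (200, 1)
print(model.calls)   # 200, one predict() call per sample
```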

Is there a way of speeding up the computation? model.predict(x_test) does not work.
Thank you!
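A single model.predict(x_test) call cannot work because a batch has to be one dense array of shape (batch, timesteps, features), and samples with different numbers of rows cannot be stacked into one. A small NumPy illustration, with sample shapes made up to mirror the question:

```python
import numpy as np

x_test = [np.zeros((1, 5)), np.zeros((3, 5)), np.zeros((8, 5))]

# Samples with different row counts cannot be stacked into one 3D batch.
try:
    np.stack(x_test, axis=0)
    stacked = True
except ValueError as err:
    stacked = False
    print("cannot stack:", err)

# Samples that share a row count stack fine into (batch, timesteps, features).
same_length = [np.zeros((3, 5)), np.ones((3, 5))]
batch = np.stack(same_length, axis=0)
print(batch.shape)  # (2, 3, 5)
```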


Solution

  • The best solution I have found so far is to group together data windows of the same length and predict each group in a single call. For my problem, this is enough to speed up the computation significantly.

    I hope this trick helps other people.

    import numpy as np
    from collections import defaultdict
    
    def predict_custom(model, x):
        """x should be a list of 2D np.arrays with different numbers of rows (timesteps)
        but the same number of columns (features)."""
        
        # dictionary with key = length of the window, value = indices of samples with that length
        dic = defaultdict(list)
        for i, sample in enumerate(x):  # use a new name so x is not overwritten
            dic[sample.shape[0]].append(i)
        
        # one prediction per sample (the model has a single output unit)
        y_pred = np.full((len(x), 1), np.nan)
        
        # loop over the dictionary and predict samples of the same length together
        for length, indexes in dic.items():
            # gather the samples of this length in a 3D np.array (n_samples, length, n_features)
            x_3d = np.stack([x[i] for i in indexes], axis=0)
            # use the dictionary values to insert results in the corresponding rows of y_pred
            y_pred[indexes] = model.predict(x_3d)
            
        return y_pred
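To check the grouping idea without Keras, the sketch below uses a hypothetical StandInModel (averages each sequence, counts predict calls) in place of the fitted LSTM, together with a condensed version of predict_custom named predict_grouped. It compares the grouped result against a plain per-sample loop and counts how many predict calls were made.

```python
import numpy as np
from collections import defaultdict


class StandInModel:
    """Hypothetical stand-in for the Keras model: averages each sequence, counts calls."""

    def __init__(self):
        self.calls = 0

    def predict(self, batch):
        assert batch.ndim == 3, "expects (batch, timesteps, features)"
        self.calls += 1
        return batch.mean(axis=(1, 2)).reshape(-1, 1)


def predict_grouped(model, x):
    """Same idea as predict_custom: one predict() call per distinct sequence length."""
    groups = defaultdict(list)  # length -> indices of samples with that length
    for i, sample in enumerate(x):
        groups[sample.shape[0]].append(i)

    y_pred = np.full((len(x), 1), np.nan)
    for indexes in groups.values():
        batch = np.stack([x[i] for i in indexes], axis=0)  # (n, length, features)
        y_pred[indexes] = model.predict(batch)  # write results back in original order
    return y_pred


rng = np.random.default_rng(0)
x_test = [rng.normal(size=(rng.integers(1, 11), 5)) for _ in range(1000)]

model = StandInModel()
y_grouped = predict_grouped(model, x_test)

# Reference: plain per-sample loop with a batch axis added to each sample.
y_loop = np.vstack([StandInModel().predict(s[np.newaxis]) for s in x_test])

print(np.allclose(y_grouped, y_loop))  # True
print(model.calls)  # number of distinct lengths (at most 10 here) instead of 1000
```

The speed-up comes from replacing one predict call per sample with one call per distinct length; within each call, Keras still processes the group in mini-batches of batch_size samples (32 by default).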