Tags: python, keras, time-series, lstm, multiclass-classification

Saving LSTM hidden states while training and predicting for multi-class time series classification


I am trying to use an LSTM for multi-class classification of time series data.

The training set has dimensions (390, 179), i.e. 390 objects with 179 time steps each.

There are 37 possible classes.

I would like to use a Keras model consisting of just an LSTM layer followed by a softmax output layer to classify the input data.

I also need the hidden states for all the training data and test data passed through the model, at every step of the LSTM (not just the final state).

I know return_sequences=True is needed, but I'm having trouble getting the dimensions to match.
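To make "hidden state at every step" concrete, here is a NumPy sketch (not Keras's own implementation) of what a single LSTM layer computes. It assumes Keras's weight layout (kernel of shape (features, 4*units), recurrent kernel (units, 4*units), bias (4*units), gates stacked as input, forget, cell candidate, output) and the TF2 defaults of tanh activation and sigmoid recurrent activation. The function name `lstm_forward` and the random weights are hypothetical. With return_sequences=True the layer keeps one hidden state per time step; with return_sequences=False it keeps only the last one.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def lstm_forward(x, kernel, recurrent_kernel, bias, return_sequences=True):
    """Hypothetical helper: run one LSTM layer over x of shape (batch, timesteps, features).

    Assumes Keras's gate order (input, forget, cell candidate, output).
    """
    batch, timesteps, _ = x.shape
    units = recurrent_kernel.shape[0]
    h = np.zeros((batch, units))  # hidden state
    c = np.zeros((batch, units))  # cell state
    states = []
    for t in range(timesteps):
        gates = x[:, t, :] @ kernel + h @ recurrent_kernel + bias
        i = sigmoid(gates[:, :units])                # input gate
        f = sigmoid(gates[:, units:2 * units])       # forget gate
        g = np.tanh(gates[:, 2 * units:3 * units])   # candidate cell values
        o = sigmoid(gates[:, 3 * units:])            # output gate
        c = f * c + i * g
        h = o * np.tanh(c)
        states.append(h)
    seq = np.stack(states, axis=1)                   # (batch, timesteps, units)
    return seq if return_sequences else seq[:, -1, :]  # else (batch, units)

# Hypothetical random weights for 1 input feature and 8 units.
rng = np.random.default_rng(0)
n_features, n_units = 1, 8
x = rng.normal(size=(4, 179, n_features))
kernel = rng.normal(scale=0.1, size=(n_features, 4 * n_units))
recurrent_kernel = rng.normal(scale=0.1, size=(n_units, 4 * n_units))
bias = np.zeros(4 * n_units)

all_h = lstm_forward(x, kernel, recurrent_kernel, bias, return_sequences=True)
last_h = lstm_forward(x, kernel, recurrent_kernel, bias, return_sequences=False)
print(all_h.shape, last_h.shape)  # (4, 179, 8) (4, 8)
```

Note that the number of hidden states equals the number of time steps along axis 1 of the input, so an input shaped (samples, 1, 179) yields only a single hidden state per sample. For a trained Keras LSTM layer with default settings, `layer.get_weights()` should return the kernel, recurrent kernel and bias in this same order.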

Below is one version of the code I've tried. I've also tried many other combinations of calls gathered from Stack Exchange answers and GitHub issues, and every one of them fails with some dimension mismatch or another.

I don't know how to extract the hidden state representations from the model.

The arrays have X_train.shape = (390, 1, 179) and Y_train.shape = (390, 37) (one-hot binary vectors).
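Keras LSTMs read input as (samples, time steps, features), so Input(shape=(1, 179)) describes one time step with 179 features. If the 179 values per row are successive time steps (as the description above suggests), the data would need to be laid out as (390, 179, 1) to get a hidden state per step. A minimal NumPy sketch, using hypothetical zero-filled stand-ins for the real arrays:

```python
import numpy as np

# Hypothetical stand-ins with the shapes described above.
X_raw = np.zeros((390, 179), dtype="float32")     # (objects, time steps)
X_train = np.zeros((390, 1, 179), dtype="float32")  # layout used in the question

# From the 2-D array: add a trailing feature axis -> (390, 179, 1).
X_steps_a = X_raw[:, :, np.newaxis]

# From the (390, 1, 179) array: swap the last two axes -> (390, 179, 1).
X_steps_b = np.transpose(X_train, (0, 2, 1))

print(X_steps_a.shape, X_steps_b.shape)  # (390, 179, 1) (390, 179, 1)
```

With this layout the model input would become Input(shape=(179, 1)), i.e. (n_sequence, 1).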

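from keras.layers import Input, LSTM, Dense
from keras.models import Model

n_units = 8         # LSTM hidden units
n_sequence = 179    # length of each series
n_class = 37        # number of classes

# X_train: (390, 1, 179); Y_train: (390, 37), one-hot
x = Input(shape=(1, n_sequence))
y = LSTM(n_units, return_sequences=True)(x)    # output: (None, 1, n_units)
z = Dense(n_class, activation='softmax')(y)    # output: (None, 1, n_class)

model = Model(inputs=[x], outputs=[z])         # softmax output, as in the error below
model.compile(loss='categorical_crossentropy', optimizer='adam')

model.fit(X_train, Y_train, epochs=100, batch_size=128)
Y_test_predict = model.predict(X_test, batch_size=128)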

This is what the above gives me:

ValueError: A target array with shape (390, 37) was passed for an output of shape (None, 1, 37) while using as loss 'categorical_crossentropy'. This loss expects targets to have the same shape as the output.
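The mismatch comes from the extra time axis: with return_sequences=True the Dense layer is applied at each time step, so the model outputs (None, 1, 37), while each target row is (37,). The NumPy sketch below mimics the shape check with a hypothetical `categorical_crossentropy` helper and random stand-in predictions (it is not Keras's loss code). Keeping only the last time step removes the extra axis, which is what LSTM(return_sequences=False) does inside a model.

```python
import numpy as np

n_samples, n_class = 390, 37
rng = np.random.default_rng(0)

# One-hot targets with the question's shape: (390, 37).
Y_train = np.eye(n_class)[rng.integers(0, n_class, size=n_samples)]

# Stand-in softmax output of Dense on top of LSTM(return_sequences=True)
# with a single time step: shape (390, 1, 37).
logits = rng.normal(size=(n_samples, 1, n_class))
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Hypothetical helper: mean cross-entropy, requiring identical shapes."""
    if y_true.shape != y_pred.shape:
        raise ValueError(f"target {y_true.shape} vs output {y_pred.shape}")
    return float(-np.mean(np.sum(y_true * np.log(np.clip(y_pred, eps, 1.0)), axis=-1)))

# probs has shape (390, 1, 37) and would be rejected; the last step matches.
loss = categorical_crossentropy(Y_train, probs[:, -1, :])
```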


Solution

  • I couldn't find a way to build a single trainable model that also returned the hidden states via return_sequences=True.

    The workaround I found was to build a classifier model, train it, and save its weights. I then built a second model that ends with the LSTM layer (using return_sequences=True) and loaded the trained weights into it. Predicting with this second model on new data returns the data's hidden state at every time step.
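A minimal sketch of this weight-transfer approach, under several assumptions: standalone `keras` imports, random stand-in data, input laid out as (samples, 179, 1) so the LSTM steps through all 179 values, and illustrative layer names `"lstm"` and `"lstm_seq"`. The weights are copied in memory with `get_weights()`/`set_weights()` instead of being saved to a file, which is a small variation on the steps described above.

```python
import numpy as np
from keras.layers import Input, LSTM, Dense
from keras.models import Model

n_units, n_steps, n_features, n_class = 8, 179, 1, 37

# Hypothetical random data standing in for X_train / Y_train / X_test.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(390, n_steps, n_features)).astype("float32")
Y_train = np.eye(n_class, dtype="float32")[rng.integers(0, n_class, size=390)]
X_test = rng.normal(size=(10, n_steps, n_features)).astype("float32")

# 1) Trainable classifier: the LSTM returns only its final hidden state,
#    so the softmax output is (None, n_class) and matches Y_train.
x = Input(shape=(n_steps, n_features))
h_last = LSTM(n_units, name="lstm")(x)
z = Dense(n_class, activation="softmax")(h_last)
classifier = Model(inputs=x, outputs=z)
classifier.compile(loss="categorical_crossentropy", optimizer="adam")
classifier.fit(X_train, Y_train, epochs=2, batch_size=128, verbose=0)

# 2) Second model ending with an LSTM that returns every hidden state.
x2 = Input(shape=(n_steps, n_features))
h_seq = LSTM(n_units, return_sequences=True, name="lstm_seq")(x2)
hidden_model = Model(inputs=x2, outputs=h_seq)

# Copy the trained LSTM weights (kernel, recurrent kernel, bias).
hidden_model.get_layer("lstm_seq").set_weights(
    classifier.get_layer("lstm").get_weights())

# One hidden state per time step: (samples, n_steps, n_units).
train_hidden = hidden_model.predict(X_train, batch_size=128, verbose=0)
test_hidden = hidden_model.predict(X_test, batch_size=128, verbose=0)
```

Copying works because the LSTM's weight shapes depend only on the number of input features and units, not on return_sequences, so the last time step of `test_hidden` should equal the final hidden state the classifier computes.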