Tags: python, tensorflow, keras, keras-layer

Repeat Keras LSTM output of all hidden states


Data:

Macro time-series with shape = (T, M)     # same for all firms
Micro time-series with shape = (T, N, K)  # firm-specific data

Where T is the time dimension in months, M is the number of macro features, N is the number of firms, and K is the number of micro features such as the P/E ratio.
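For concreteness, a toy instantiation of these shapes (the dimension values below are illustrative, not from the original post):

import numpy as np

T, M, N, K = 120, 10, 50, 5        # e.g. 120 months, 10 macro features, 50 firms, 5 micro features
macro = np.random.rand(T, M)       # same macro series for every firm
micro = np.random.rand(T, N, K)    # firm-specific micro series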

Task: take the LSTM output, repeat it N times, and concatenate it with the micro data

I want to use the output of the LSTM, namely all (here H=4) hidden states for every t in T, so I set return_sequences=True, and repeat them N times in order to concatenate them with my micro data, i.e. obtain new data with shape = (T, N, K+H).

This new data will then be reshaped to (T*N, K+H) and fed into a feed-forward neural network with a custom loss function that applies to both networks and, by construction, can only be computed at time T, so batch_size=1.
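A minimal sketch of this downstream stage, assuming the concatenation below succeeds and yields a tensor conc of shape (batch, T, N, K+H); the Dense widths are illustrative placeholders, not from the original post:

from keras.layers import Reshape, Dense

H = 4                                         # LSTM hidden size
flat = Reshape((T * N, K + H))(conc)          # (batch, T*N, K+H)
hidden = Dense(16, activation='relu')(flat)   # Dense acts on the last axis
out = Dense(1)(hidden)                        # one output per firm-month
# With batch_size=1, each batch holds the full (T, N) panel, so the custom
# loss can be evaluated at time T across all firms at once.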

I have tried the following, but it fails because of the dimensions:

from keras.layers import concatenate, Input, LSTM, RepeatVector
from keras.utils.vis_utils import plot_model
from keras.models import Model

in_macro = Input(shape=(T, M), name='macro_input')     # (batch, T, M)
in_micro = Input(shape=(T, N, K), name='micro_input')  # (batch, T, N, K)

lstm = LSTM(4, return_sequences=True)(in_macro)        # (batch, T, 4)
rep = RepeatVector(N)(lstm)                            # raises the ValueError below
conc = concatenate([in_micro, rep])
model = Model(inputs=[in_micro, in_macro], outputs=conc)

plot_model(model, show_shapes=True)

ValueError: Input 0 is incompatible with layer repeat_vector_1: expected ndim=2, found ndim=3
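The error is expected: RepeatVector only accepts a 2-D input (batch, features) and repeats it into (batch, n, features), whereas return_sequences=True makes the LSTM emit a 3-D tensor (batch, T, H). A minimal sketch of what RepeatVector does accept (the dimensions here are illustrative):

from keras.layers import Input, LSTM, RepeatVector

x = Input(shape=(10, 3))      # (batch, T=10, M=3)
last = LSTM(4)(x)             # (batch, 4): only the last hidden state
rep = RepeatVector(5)(last)   # (batch, 5, 4): works, input is 2-D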

Is there an efficient way to achieve this, perhaps even without the need to repeat?


Solution

  • Found a solution using a Lambda layer and keras.backend.stack:

    from keras.layers import concatenate, Input, LSTM, Lambda
    from keras.utils.vis_utils import plot_model
    from keras.models import Model
    import keras.backend as k

    in_macro = Input(shape=(T, M), name='macro_input')
    in_micro = Input(shape=(T, N, K), name='micro_input')

    lstm = LSTM(4, return_sequences=True)(in_macro)           # (batch, T, 4)
    # Stack N copies of the LSTM output along a new firm axis: (batch, T, N, 4)
    stack = Lambda(lambda x: k.stack(N * [x], axis=2))(lstm)
    conc = concatenate([in_micro, stack])                     # (batch, T, N, K+4)
    model = Model(inputs=[in_micro, in_macro], outputs=conc)

    plot_model(model, show_shapes=True)
    

    (plot_model output: the concatenated tensor has shape (T, N, K+4).)
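
    Stacking N copies of the LSTM output along axis=2 turns (batch, T, H) into (batch, T, N, H), which lines up with the micro input for concatenation on the last axis. An equivalent formulation, assuming standard keras.backend semantics, avoids building a Python list of N tensors by expanding a singleton firm axis and tiling it:

    from keras.layers import Lambda
    import keras.backend as k

    # (batch, T, H) -> (batch, T, 1, H) -> (batch, T, N, H)
    tile = Lambda(lambda x: k.tile(k.expand_dims(x, axis=2), [1, 1, N, 1]))(lstm)
    conc = concatenate([in_micro, tile])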