Build a multimodal LSTM


I have the following LSTM network, and I want to add the connection shown as a red line in the figure (this is the model I want). Here is my model:

import numpy as np
# tensorflow.keras imports; the unused np_utils, mean_squared_error and keras_self_attention imports are dropped
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, LSTM, Input, concatenate

X1 = np.random.normal(size=(100, 1, 2))
X2 = np.random.normal(size=(100, 1, 2))
X3 = np.random.normal(size=(100, 1, 2))
Y = np.random.normal(size=(100, 18))

input_1 = Input(shape=(X1.shape[1], X1.shape[2]), name='input_1')
input_2 = Input(shape=(X2.shape[1], X2.shape[2]), name='input_2')
input_3 = Input(shape=(X3.shape[1], X3.shape[2]), name='input_3')
# lstms
lstm1 = LSTM(200, name='lstm1')(input_1)
lstm2 = LSTM(200, name='lstm2')(input_2)
lstm3 = LSTM(200, name='lstm3')(input_3)
## per-branch outputs (not connected to the model below)
output1 = Dense(18, activation="linear", name='out1')(lstm1)
output2 = Dense(18, activation="linear", name='out2')(lstm2)
output3 = Dense(18, activation="linear", name='out3')(lstm3)
# merged output; renamed from 'out1', which is already used above
concat = concatenate([lstm1, lstm2, lstm3])
output = Dense(18, activation="linear", name='out')(concat)
model = Model(inputs=[input_1, input_2, input_3], outputs=output)
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae'])
model.fit([X1, X2, X3], Y, epochs=1, batch_size=100)

Can anybody help me build this model? Thanks.


Solution

  • Try return_state=True in the LSTM layer. It makes the LSTM also return its final hidden state h and cell state c, so you can pass them to the next LSTM as its initial_state:

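    # return_state=True: the layer returns [output, h, c]
    lstm1, h1, c1 = LSTM(200, name='lstm1', return_state=True)(input_1)
    # lstm2 starts from lstm1's final states
    lstm2, h2, c2 = LSTM(200, name='lstm2', return_state=True)(input_2, initial_state=[h1, c1])
    # lstm3 starts from lstm2's final states
    lstm3 = LSTM(200, name='lstm3')(input_3, initial_state=[h2, c2])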
    

    This gives you the following model (the plot does not display well):

    [Figure: plot of the resulting model]
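For reference, here is a minimal end-to-end sketch that combines the answer's state chaining with the question's concatenation and output layer. It assumes the red line in the (unavailable) figure means feeding each LSTM's final states into the next LSTM; it uses `tensorflow.keras` imports instead of standalone `keras`, and the output layer name `out` is an illustrative choice.

```python
import numpy as np
from tensorflow.keras.layers import Input, LSTM, Dense, concatenate
from tensorflow.keras.models import Model

# Random data with the same shapes as in the question
X1 = np.random.normal(size=(100, 1, 2))
X2 = np.random.normal(size=(100, 1, 2))
X3 = np.random.normal(size=(100, 1, 2))
Y = np.random.normal(size=(100, 18))

input_1 = Input(shape=(1, 2), name='input_1')
input_2 = Input(shape=(1, 2), name='input_2')
input_3 = Input(shape=(1, 2), name='input_3')

# return_state=True: each call returns [output, h, c].
# With return_sequences=False (the default), output is the final h.
lstm1, h1, c1 = LSTM(200, name='lstm1', return_state=True)(input_1)
# Assumed extra connection: lstm2 starts from lstm1's final states
lstm2, h2, c2 = LSTM(200, name='lstm2', return_state=True)(
    input_2, initial_state=[h1, c1])
# lstm3 starts from lstm2's final states; units must match (200)
lstm3 = LSTM(200, name='lstm3')(input_3, initial_state=[h2, c2])

# Same merge and output head as in the question
concat = concatenate([lstm1, lstm2, lstm3])
output = Dense(18, activation='linear', name='out')(concat)

model = Model(inputs=[input_1, input_2, input_3], outputs=output)
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae'])
history = model.fit([X1, X2, X3], Y, epochs=1, batch_size=100, verbose=0)
preds = model.predict([X1, X2, X3], verbose=0)
```

Passing states through `initial_state` only works when the chained LSTMs have the same number of units, because h and c must match the receiving layer's state size.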