Tags: reinforcement-learning, dqn, keras-rl

Dueling DQN updates model architecture and causes issues


I create an initial network model with the following architecture:

from keras.models import Sequential
from keras.layers import Flatten, Dense, Dropout, Activation

def create_model(env):

    dropout_prob = 0.8  # aggressive dropout regularization
    num_units = 256  # number of neurons in the hidden units

    model = Sequential()
    model.add(Flatten(input_shape=(1,) + env.input_shape))
    model.add(Dense(num_units))
    model.add(Activation('relu'))

    model.add(Dense(num_units))
    model.add(Dropout(dropout_prob))
    model.add(Activation('relu'))

    model.add(Dense(env.action_size))
    model.add(Activation('softmax'))
    print(model.summary())
    return model

Then I call the DQNAgent, which updates the network architecture:

from rl.agents.dqn import DQNAgent
from keras.optimizers import Adam

dqn = DQNAgent(model=model, nb_actions=env.action_size, memory=memory,
               nb_steps_warmup=settings['train']['warm_up'],
               target_model_update=settings['train']['update_rate'],
               policy=policy, enable_dueling_network=True)
dqn.compile(Adam(lr=settings['train']['learning_rate']), metrics=['mse'])
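For reference, the agent is then trained and its weights saved before I try to reload them. A minimal sketch of that step, using keras-rl's standard fit/save_weights calls (the step count and file name are placeholders, not my actual settings), looks like this:

dqn.fit(env, nb_steps=50000, verbose=1)
dqn.save_weights('dqn_weights.h5', overwrite=True)  # saves the weights of the dueling model (dqn.model)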

Doing this results in an updated network architecture, as expected. The issue is that when I later try to reload the fitted network, the model built by the original create_model function can't accept the saved weights, because the layer architecture no longer matches.

print(model.summary())

Layer (type)                 Output Shape              Param #   
=================================================================
flatten_49 (Flatten)         (None, 106)               0         
_________________________________________________________________
dense_147 (Dense)            (None, 128)               13696     
_________________________________________________________________
activation_145 (Activation)  (None, 128)               0         
_________________________________________________________________
dense_148 (Dense)            (None, 64)                8256      
_________________________________________________________________
dropout_49 (Dropout)         (None, 64)                0         
_________________________________________________________________
activation_146 (Activation)  (None, 64)                0         
_________________________________________________________________
dense_149 (Dense)            (None, 3)                 195       
_________________________________________________________________
activation_147 (Activation)  (None, 3)                 0         
=================================================================
Total params: 22,147
Trainable params: 22,147
Non-trainable params: 0

print(dqn.model.summary())
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_49_input (InputLayer (None, 1, 1, 106)         0         
_________________________________________________________________
flatten_49 (Flatten)         (None, 106)               0         
_________________________________________________________________
dense_147 (Dense)            (None, 128)               13696     
_________________________________________________________________
activation_145 (Activation)  (None, 128)               0         
_________________________________________________________________
dense_148 (Dense)            (None, 64)                8256      
_________________________________________________________________
dropout_49 (Dropout)         (None, 64)                0         
_________________________________________________________________
activation_146 (Activation)  (None, 64)                0         
_________________________________________________________________
dense_149 (Dense)            (None, 3)                 195       
_________________________________________________________________
dense_150 (Dense)            (None, 4)                 16        
_________________________________________________________________
lambda_3 (Lambda)            (None, 3)                 0         
=================================================================
Total params: 22,163
Trainable params: 22,163
Non-trainable params: 0
_________________________________________________________________

So, without training a new DQN, I need a way to build a network that starts from the original architecture but applies the same dueling-network changes, so that the saved weights can be loaded into it.
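Based on the printed summaries, enable_dueling_network appears to drop the final softmax activation and append a Dense(nb_actions + 1) layer (one state-value output plus one advantage per action), followed by a Lambda layer that combines them as Q(s, a) = V(s) + A(s, a) - mean(A). A sketch that mirrors this transformation, so the saved weights can be loaded into a freshly built model, might look like the following (the 'avg' dueling combination and the helper name to_dueling are my assumptions, not part of keras-rl's public API):

from keras.layers import Dense, Lambda
from keras.models import Model
import keras.backend as K

def to_dueling(base_model, nb_actions):
    # Mirror the dueling head: take the last Dense layer of the base model,
    # add a Dense(nb_actions + 1) producing [V, A_1, ..., A_n], then combine.
    last_dense = base_model.layers[-2]  # the Dense(env.action_size) layer
    y = Dense(nb_actions + 1, activation='linear')(last_dense.output)
    q = Lambda(lambda a: K.expand_dims(a[:, 0], -1) + a[:, 1:]
               - K.mean(a[:, 1:], axis=1, keepdims=True),
               output_shape=(nb_actions,))(y)
    return Model(inputs=base_model.input, outputs=q)

dueling_model = to_dueling(create_model(env), env.action_size)
dueling_model.load_weights('dqn_weights.h5')  # weights saved via dqn.save_weights(...)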


Solution

  • The best way is to save the entire model with

    dqn.model.save('xyz')
    

    and then, instead of recreating the original model and converting it to dueling form, load the full model (rather than just the weights) with

    model = keras.models.load_model('xyz')
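
    Once loaded, the full dueling model can be used directly for inference. A minimal usage sketch, assuming the same (1,) + env.input_shape window used when building the network (the reshaping is illustrative):

    import numpy as np
    import keras

    model = keras.models.load_model('xyz')

    obs = env.reset()
    q_values = model.predict(obs[np.newaxis, np.newaxis, ...])  # add batch and window dims
    action = np.argmax(q_values[0])

    Note that because the dueling head uses a Lambda layer, load_model may need to run under the same Python/Keras versions that were used when saving.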