python, tensorflow, keras

Can't pass multiple callbacks to a keras model


I'm building a Keras LSTM model, and on the first passes I saw that it was overfitting the data a bit, so I initialised two callbacks: one to control a variable learning rate and the other to allow early stopping:

    def _initialise_callback(self):

        # Ensure learning rate decreases with the epoch number
        learning_rate = 0.1
        decay_rate = learning_rate / self.epochs
        momentum = 0.8
        self.sgd = SGD(lr=learning_rate, momentum=momentum, decay=decay_rate, nesterov=False)

        # Allow model to stop early to prevent overfitting
        self.early_stopping = EarlyStopping(monitor='loss', patience=3)

But then for some reason I can't seem to pass them both to the fit() method. What I do is:

    def fit(self):
        self.model.fit(self.train_set, epochs=self.epochs, verbose=2, shuffle=False,
                       callbacks=[self.early_stopping, self.sgd],
                       use_multiprocessing=False)

and this results in the following error:


  File "<ipython-input-1-1532e4234d2a>", line 1, in <module>
    runfile('C:/VULCAN_HOME/sampling_bias/bias_LSTM.py', wdir='C:/VULCAN_HOME/sampling_bias')

  File "C:\ProgramData\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile
    execfile(filename, namespace)

  File "C:\ProgramData\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)

  File "C:/VULCAN_HOME/sampling_bias/bias_LSTM.py", line 174, in <module>
    predictor.fit()

  File "C:/VULCAN_HOME/sampling_bias/bias_LSTM.py", line 164, in fit
    use_multiprocessing=False)

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py", line 1147, in fit
    initial_epoch=initial_epoch)

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py", line 1732, in fit_generator
    initial_epoch=initial_epoch)

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training_generator.py", line 100, in fit_generator
    callbacks.set_model(callback_model)

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\callbacks\callbacks.py", line 68, in set_model
    callback.set_model(model)

AttributeError: 'SGD' object has no attribute 'set_model'

On the other hand, if I pass only sgd or only early_stopping, everything works fine. Does anyone know what's happening here?


Solution

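  • SGD is an optimizer, not a callback, so it should be passed to the compile method through the optimizer argument rather than in the callbacks list of the fit method. Keras calls set_model on every object in callbacks, and an optimizer has no such method, which is what raises the AttributeError in your traceback. I have modified your code below: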

    def fit(self):
        self.model.fit(self.train_set, epochs=self.epochs, verbose=2, shuffle=False,
                       callbacks=[self.early_stopping],
                       use_multiprocessing=False)

    And when you compile the model, pass your optimizer:

    self.model.compile(optimizer=self.sgd, **kwargs)
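
If you would rather control the learning rate from the callbacks list, something along these lines might work. This is only a sketch, not a drop-in replacement for your class: `make_time_based_schedule` and `compile_and_fit` are hypothetical helper names, the `loss="mse"` default is a placeholder assumption, and it assumes a Keras version where `SGD` accepts `learning_rate`. It uses the `LearningRateScheduler` callback, which is a real callback and can sit next to `EarlyStopping`, while the optimizer still goes to compile.

```python
def make_time_based_schedule(initial_lr, decay_rate):
    """Build a schedule(epoch, lr) function: initial_lr / (1 + decay_rate * epoch)."""
    def schedule(epoch, lr=None):
        # The current rate passed in by Keras is ignored here;
        # the new rate depends only on the zero-based epoch index.
        return float(initial_lr / (1.0 + decay_rate * epoch))
    return schedule


def compile_and_fit(model, train_set, epochs, loss="mse"):
    """Hypothetical helper: compile with SGD, then fit with two real callbacks."""
    # Keras is imported lazily so the schedule helper above works without it.
    from keras.callbacks import EarlyStopping, LearningRateScheduler
    from keras.optimizers import SGD

    initial_lr = 0.1
    schedule = make_time_based_schedule(initial_lr, decay_rate=initial_lr / epochs)

    # The optimizer belongs to compile(), not to the callbacks list.
    # loss="mse" is a placeholder assumption; use whatever your model needs.
    model.compile(optimizer=SGD(learning_rate=initial_lr, momentum=0.8), loss=loss)

    # Only Callback instances go into callbacks.
    callbacks = [
        EarlyStopping(monitor="loss", patience=3),
        LearningRateScheduler(schedule),
    ]
    return model.fit(train_set, epochs=epochs, verbose=2, shuffle=False,
                     callbacks=callbacks)
```

Note that this schedule changes the rate once per epoch. As far as I know, the decay argument of the older SGD optimizer is applied on every batch update instead, so the two approaches will not produce identical learning rates.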