I'm building a Keras LSTM model, and on first passes I see that it's overfitting the data a bit, so I initialised two callbacks: one to control a variable learning rate and the other to allow for early stopping:
    def _initialise_callback(self):
        # Ensure learning rate decreases with the epoch number
        learning_rate = 0.1
        decay_rate = learning_rate / self.epochs
        momentum = 0.8
        self.sgd = SGD(lr=learning_rate, momentum=momentum, decay=decay_rate, nesterov=False)
        # Allow model to stop early to prevent overfitting
        self.early_stopping = EarlyStopping(monitor='loss', patience=3)
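A side note on the comment in _initialise_callback: as far as I know, the legacy decay argument of the Keras 2.x SGD optimizer applies time-based decay per weight update (one update per batch), lr_t = lr / (1 + decay * t), where t counts update steps rather than epochs. The plain-Python sketch below computes the rate this formula gives at the start of each epoch; the epoch count and batches per epoch are hypothetical values, not taken from the question. Newer Keras versions spell the argument learning_rate and favour learning-rate schedules over decay.

```python
# Sketch of the time-based decay used by the legacy Keras 2.x SGD optimizer.
# Assumption: lr_t = lr0 / (1 + decay * t), where t counts weight updates
# (batches), not epochs.

def effective_lr(initial_lr: float, decay: float, iterations: int) -> float:
    """Learning rate applied at update step `iterations` (0-based)."""
    return initial_lr / (1.0 + decay * iterations)


def lr_at_epoch_starts(initial_lr: float, decay: float,
                       steps_per_epoch: int, epochs: int) -> list:
    """Learning rate at the first batch of each epoch."""
    return [effective_lr(initial_lr, decay, epoch * steps_per_epoch)
            for epoch in range(epochs)]


if __name__ == "__main__":
    learning_rate = 0.1
    epochs = 10                 # hypothetical value for self.epochs
    decay_rate = learning_rate / epochs
    steps_per_epoch = 100       # hypothetical number of batches per epoch
    schedule = lr_at_epoch_starts(learning_rate, decay_rate, steps_per_epoch, epochs)
    for epoch, lr in enumerate(schedule):
        print(f"epoch {epoch}: lr = {lr:.5f}")
```

With 100 batches per epoch the rate has already halved by the start of the second epoch (0.1 / (1 + 0.01 * 100) = 0.05), so dividing the initial rate by the number of epochs decays faster than a per-epoch reading of the comment suggests.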
But for some reason I can't pass both of them to the fit() method. This is what I do:
    def fit(self):
        self.model.fit(self.train_set, epochs=self.epochs, verbose=2, shuffle=False,
                       callbacks=[self.early_stopping, self.sgd],
                       use_multiprocessing=False)
and this results in the following error:
      File "<ipython-input-1-1532e4234d2a>", line 1, in <module>
        runfile('C:/VULCAN_HOME/sampling_bias/bias_LSTM.py', wdir='C:/VULCAN_HOME/sampling_bias')
      File "C:\ProgramData\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile
        execfile(filename, namespace)
      File "C:\ProgramData\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile
        exec(compile(f.read(), filename, 'exec'), namespace)
      File "C:/VULCAN_HOME/sampling_bias/bias_LSTM.py", line 174, in <module>
        predictor.fit()
      File "C:/VULCAN_HOME/sampling_bias/bias_LSTM.py", line 164, in fit
        use_multiprocessing=False)
      File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py", line 1147, in fit
        initial_epoch=initial_epoch)
      File "C:\ProgramData\Anaconda3\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
        return func(*args, **kwargs)
      File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py", line 1732, in fit_generator
        initial_epoch=initial_epoch)
      File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training_generator.py", line 100, in fit_generator
        callbacks.set_model(callback_model)
      File "C:\ProgramData\Anaconda3\lib\site-packages\keras\callbacks\callbacks.py", line 68, in set_model
        callback.set_model(model)
    AttributeError: 'SGD' object has no attribute 'set_model'
On the other hand, if I pass only sgd or only early_stopping, everything works fine. Does anyone know what's happening here?
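The SGD optimizer is not a callback. It should be passed as the optimizer argument of the compile() method, not in the callbacks list of the fit() method. Every entry in callbacks must be a Keras callback (EarlyStopping is one), and fit() calls callback methods such as set_model() on each entry, which is why the SGD object raises the AttributeError. I have modified your code below: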
    def fit(self):
        self.model.fit(self.train_set, epochs=self.epochs, verbose=2, shuffle=False,
                       callbacks=[self.early_stopping],
                       use_multiprocessing=False)
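And when you compile the model, pass your optimizer to compile():

    self.model.compile(optimizer=self.sgd, **kwargs)

Here **kwargs stands for your other compile() arguments, such as loss. Make sure _initialise_callback() has run before this call so that self.sgd exists.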
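The failure can be illustrated without Keras. The last frames of the traceback show a callback container calling callback.set_model(model) on every element of the callbacks list. Keras callbacks such as EarlyStopping inherit that method from the keras.callbacks.Callback base class, while an optimizer does not have it. The sketch below uses hypothetical stand-in classes (not the real Keras implementation) to reproduce the same kind of AttributeError.

```python
# Hypothetical stand-ins (NOT Keras's real classes) that mimic how Keras 2.x
# dispatches callbacks: fit() wraps the `callbacks` list in a container that
# calls set_model() on every element, as the traceback's last frames show.

class Callback:
    """Stand-in for keras.callbacks.Callback: provides the hook methods."""

    def set_model(self, model):
        self.model = model


class EarlyStoppingStub(Callback):
    """Stand-in for EarlyStopping, which subclasses Callback."""


class SGDStub:
    """Stand-in for an optimizer: it is not a Callback, so it has no set_model."""


class CallbackList:
    """Mimics the container that forwards set_model to each callback."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def set_model(self, model):
        for callback in self.callbacks:
            callback.set_model(model)  # raises AttributeError for non-callbacks


def try_fit(callbacks, model="model"):
    """Return None if dispatch succeeds, otherwise the AttributeError message."""
    try:
        CallbackList(callbacks).set_model(model)
    except AttributeError as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    print(try_fit([EarlyStoppingStub()]))             # dispatch succeeds
    print(try_fit([EarlyStoppingStub(), SGDStub()]))  # optimizer entry fails
```

Moving the optimizer out of the list and into compile() is what makes the dispatch succeed: only objects that implement the callback hooks belong in callbacks.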