I am trying to implement K-fold cross-validation with an LSTM architecture,
but I get this error (edit):
Traceback (most recent call last):
File "/Users/me/Desktop/dynamicsoundtreatments/DST-features-RNN.py", line 58, in <module>
model.fit(training_data, training_label, epochs=100, batch_size=nbr_de_son)
File "/Users/me/miniforge3/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/Users/me/miniforge3/lib/python3.9/site-packages/keras/engine/training.py", line 1395, in fit
raise ValueError('Unexpected result of `train_function` '
ValueError: Unexpected result of `train_function` (Empty logs). Please use `Model.compile(..., run_eagerly=True)`, or `tf.config.run_functions_eagerly(True)` for more information of where went wrong, or file a issue/bug to `tf.keras`.
I tried adding run_eagerly=True to model.compile, but got the same error again.
I also tried a few alternatives, such as defining def train(training_data, training_label): model.fit(training_data, training_label, epochs=100, batch_size=nbr_de_son)
outside of the for loop, and got the same error.
I was wondering whether I should use the Functional API, but I am very new to data science. I really don't know why I get this error. Thanks for your help.
nbr_anal = int(6)
nbr_de_son = int(samples.shape[0]/nbr_anal)
sequence = int(samples.shape[1])
input_shape=(nbr_anal, sequence)
# ------------------------------------------------------------------------
# PREPROCESSING
# batch size, sequence length, features
samples = samples.reshape(nbr_de_son, nbr_anal, sequence)
labels_extrait = np.argmax(labels_extrait, axis=1)
print(labels_extrait.shape)
# ------------------------------------------------------------------------
# K-Fold
k = 4
num_validation_samples = len(samples) // k
num_validation_labels = len(labels_extrait) // k
validation_scores = []
model = Sequential()
model.add(LSTM(sequence,input_shape=input_shape))
model.add(Dropout(0.3))
model.add(Dense(8, activation='softmax'))
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics='sparse_categorical_accuracy', run_eagerly=True)
for fold in range(k):
    rng_state = np.random.get_state()
    np.random.shuffle(samples)
    np.random.set_state(rng_state)
    np.random.shuffle(labels_extrait)
    validation_data = samples[num_validation_samples * fold:num_validation_samples * (fold + 1)]
    print(validation_data.shape)
    validation_label = labels_extrait[num_validation_labels * fold:num_validation_labels * (fold + 1)]
    print(validation_label.shape)
    training_data = samples[:num_validation_samples * (fold + 1)] + samples[num_validation_samples * (fold)]
    training_label = labels_extrait[:num_validation_labels * (fold + 1)] + labels_extrait[num_validation_labels * (fold)]
    model.fit(training_data, training_label, epochs=100, batch_size=nbr_de_son)
    validation_score = evaluate(validation_data, validation_label)
    validation_scores.append(validation_score)
validation_score = np.average(validation_scores)
print(validation_score)
You can use StratifiedKFold from the sklearn package to do the cross-validation. It is much clearer and is the standard way to do it. You should also reset the model weights at each fold before fitting the model; otherwise each fold starts from the weights learned in the previous call to the fit method.
The modified code:
from sklearn.model_selection import StratifiedKFold
# K-Fold cross validation
k = 4
skf = StratifiedKFold(n_splits=k)
validation_scores = []
# store the initial model's weights
weights_init = model.get_weights()
for train_index, test_index in skf.split(samples, labels_extrait):
    training_data = samples[train_index]
    training_label = labels_extrait[train_index]
    validation_data = samples[test_index]
    validation_label = labels_extrait[test_index]
    # reset the model's weights
    model.set_weights(weights_init)
    # fit
    model.fit(training_data, training_label, epochs=100, batch_size=nbr_de_son)
    # with a metric compiled in, evaluate returns [loss, metric] for the fold
    validation_score = model.evaluate(validation_data, validation_label)
    validation_scores.append(validation_score)
# average per column so that loss and accuracy are not mixed together
validation_score = np.average(validation_scores, axis=0)
print(validation_score)
I don't know exactly where your error comes from, but the above code works.
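One thing worth checking in the question's manual split, independently of the traceback: on NumPy arrays `+` adds values elementwise (with broadcasting) rather than joining two slices, so `samples[:a] + samples[b]` does not build "everything except the validation fold". Reshuffling inside the loop also means the validation slices of different folds can overlap. Below is a minimal sketch of a manual split that avoids both, using small made-up arrays in place of the real `samples` and `labels_extrait`; the shapes, the seed and the name `fold_size` are assumptions for illustration only, and this is not claimed to be the cause of the `Empty logs` error.

```python
import numpy as np

# Toy stand-ins for the question's arrays (shapes are assumptions).
samples = np.arange(8 * 2 * 3, dtype=float).reshape(8, 2, 3)
labels = np.arange(8) % 2

# Shuffle once, before the loop, with a single permutation applied to
# both arrays so that samples and labels stay aligned.
perm = np.random.default_rng(0).permutation(len(samples))
samples, labels = samples[perm], labels[perm]

k = 4
fold_size = len(samples) // k
fold = 1
start, stop = fold_size * fold, fold_size * (fold + 1)

# `+` broadcasts and adds values; it does not join the two pieces.
added = samples[:stop] + samples[start]

# np.concatenate joins what lies before and after the validation slice.
validation_data = samples[start:stop]
validation_label = labels[start:stop]
training_data = np.concatenate([samples[:start], samples[stop:]])
training_label = np.concatenate([labels[:start], labels[stop:]])

print(added.shape, training_data.shape, validation_data.shape)
```

With 8 samples and k = 4, the training set holds the 6 samples outside the validation slice, whereas the `+` expression only returns an array of summed values with the shape of its first operand.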