I am trying to use the load_weights function from tf.keras.Model, but it does not seem to work properly.
If I call model.load_weights(weights_path, by_name=True, skip_mismatch=True),
I still get a shape-mismatch error, which is precisely what I expected the skip_mismatch
argument to take care of.
The following snippet is a relatively simple MNIST example that reproduces my error. I ran it in Google Colab, and it behaves exactly like my own code.
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
# load the MNIST dataset as (x, y) tuples
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
#build a simple model with sequential dense layers
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Rescaling(1./255),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])
#compile the model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
#create checkpoint directory
os.makedirs("checkpoints", exist_ok=True)
cp_callback = tf.keras.callbacks.ModelCheckpoint("checkpoints/cp-{epoch:04d}.ckpt",
                                                 save_weights_only=True)
#train the model
history = model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test), callbacks=[cp_callback])
model2 = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Rescaling(1./255),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(5)
])
#compile model2
model2.compile(optimizer='adam',
               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])
#load weights of model into model2
model2.load_weights("checkpoints/cp-0010.ckpt", by_name=True, skip_mismatch=True)
I get "Received incompatible tensor with shape (10,) when attempting to restore variable with shape (5,) and name dense_5/bias:0." whether I use skip_mismatch=True or skip_mismatch=False.
Am I using the function wrong? What is the proper way to use it?
If you look at the docs of load_weights and save_weights, it becomes clearer.
From load_weights:
Weight loading by name
If your weights are saved as a .h5 file created via model.save_weights(), you can use the argument by_name=True.
In this case, weights are loaded into layers only if they share the same name. This is useful for fine-tuning or transfer-learning models where some of the layers have changed.
And from save_weights:
save_format Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise, None becomes 'tf'. Defaults to None.
From ModelCheckpoint:
save_weights_only if True, then only the model's weights will be saved (model.save_weights(filepath)), else the full model is saved (model.save(filepath)).
It seems you are not saving in the .h5 format, and therefore cannot use by_name=True. ModelCheckpoint just calls model.save_weights(path) without a save_format, and your save path ends in neither .keras nor .h5, so the weights are written in the TensorFlow checkpoint format instead.
Note: For now I can't test this solution, and it is just citing the docs. I'll come back later and test this, if necessary.