Let's try to make MobileNet V2 locate a bright band on a noisy image. Yes, it is overkill to use a deep convolutional network for such a task, but it was originally intended just as a smoke test to make sure the model works. We will train it on synthetic data:
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
SHAPE = (32, 320, 1)
def gen_sample():
    # Endless stream of noisy images with a vertical bright band 8 pixels wide;
    # the target is the column where the band starts.
    while True:
        data = np.random.normal(0, 1, SHAPE)
        i = np.random.randint(0, SHAPE[1] - 8)
        data[:, i:i+8, :] += 4
        yield data.astype(np.float32), np.float32(i)

ds = tf.data.Dataset.from_generator(gen_sample, output_signature=(
    tf.TensorSpec(shape=SHAPE, dtype=tf.float32),
    tf.TensorSpec(shape=(), dtype=tf.float32))).batch(100)

# Show one sample (drop the single channel axis for imshow)
d, i = next(gen_sample())
plt.figure()
plt.imshow(d[:, :, 0])
plt.show()
Now we build and train a model:
model = tf.keras.models.Sequential([
    tf.keras.applications.MobileNetV2(
        input_shape=SHAPE, include_top=False, weights=None, alpha=0.5),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(
        learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=0.01, decay_steps=1000, decay_rate=0.9)),
    loss='mean_squared_error')
history = model.fit(ds, steps_per_epoch=10, epochs=40)
We use generated data, so we don't need a validation set, do we? We can just watch the loss decrease, and it does decrease quite well:
Epoch 1/40
10/10 [==============================] - 27s 2s/step - loss: 15054.8417
Epoch 2/40
10/10 [==============================] - 23s 2s/step - loss: 193.9126
Epoch 3/40
10/10 [==============================] - 24s 2s/step - loss: 76.9586
Epoch 4/40
10/10 [==============================] - 25s 2s/step - loss: 68.8521
...
Epoch 37/40
10/10 [==============================] - 20s 2s/step - loss: 4.5258
Epoch 38/40
10/10 [==============================] - 20s 2s/step - loss: 22.1212
Epoch 39/40
10/10 [==============================] - 20s 2s/step - loss: 28.4854
Epoch 40/40
10/10 [==============================] - 20s 2s/step - loss: 18.0123
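Training did not happen to stop at the best epoch, but the result should still be reasonable: a mean squared error of about 18 corresponds to an RMS error of about 4 pixels, so the predictions should be within roughly ±8 of the true value. Let's test it: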
d, i = list(ds.take(1))[0]
model.evaluate(d, i)
np.stack((model.predict(d).ravel(), i.numpy()), 1)[:10,]
4/4 [==============================] - 0s 32ms/step - loss: 16955.7871
array([[ 66.84666 , 222. ],
[ 66.846664, 46. ],
[ 66.846664, 71. ],
[ 66.84668 , 268. ],
[ 66.846664, 86. ],
[ 66.84668 , 121. ],
[ 66.846664, 301. ],
[ 66.84667 , 106. ],
[ 66.84665 , 138. ],
[ 66.84667 , 95. ]], dtype=float32)
Wow! Where does this huge evaluation loss come from? And why does the model keep predicting the same stupid value? Everything was so good during training!
Actually, it took me a day or so to realize what was going on, but I leave it to others to solve this puzzle and earn some points.
The problem was that a network that worked reasonably well in training mode failed in inference mode. What might be the cause? There are basically two layer types that behave differently in the two modes: dropout and batch normalization. MobileNet V2 has only batch normalization, so let's consider how it works.
In training mode, a BN layer calculates the batch mean and variance and normalizes the data using these batch values. At the same time, it keeps track of the mean and variance as a moving average weighted with a coefficient called momentum:
moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
moving_var = moving_var * momentum + var(batch) * (1 - momentum)
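To make the two modes concrete, here is a minimal NumPy sketch of a batch normalization forward pass. `SimpleBatchNorm` is a hypothetical class for illustration, not the Keras implementation: it omits the learnable scale and shift (gamma, beta), but it uses the update rule above and the Keras default initial values (moving mean 0, moving variance 1). The usage part feeds it 400 batches whose true mean is 5 and true variance is 0.1 (arbitrary illustrative values) with momentum 0.999.

```python
import numpy as np

class SimpleBatchNorm:
    """Hypothetical, stripped-down BN over axis 0 (no gamma/beta)."""

    def __init__(self, momentum=0.99, epsilon=1e-3):
        self.momentum = momentum
        self.epsilon = epsilon
        # Same initial values as Keras: moving mean 0, moving variance 1
        self.moving_mean = 0.0
        self.moving_var = 1.0

    def __call__(self, x, training):
        if training:
            # Training mode: normalize with the statistics of this batch...
            mean, var = x.mean(axis=0), x.var(axis=0)
            # ...and update the moving averages with the formulas above
            m = self.momentum
            self.moving_mean = self.moving_mean * m + mean * (1 - m)
            self.moving_var = self.moving_var * m + var * (1 - m)
        else:
            # Inference mode: use only the remembered moving statistics
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + self.epsilon)

rng = np.random.default_rng(0)
bn = SimpleBatchNorm(momentum=0.999)
for _ in range(400):
    batch = rng.normal(5.0, np.sqrt(0.1), size=100)
    train_out = bn(batch, training=True)
infer_out = bn(batch, training=False)
print(train_out.mean(), train_out.std())  # close to 0 and 1
print(infer_out.mean(), infer_out.std())  # far from 0 and 1
```

On the very same batch, the training-mode output is properly normalized, while the inference-mode output is shifted by about 4 because the moving statistics are still far from the true ones.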
This momentum is an important hyperparameter, especially if the true batch statistics are far from the initial values. Suppose the initial variance value is 1.0, the momentum is 0.99 (which is the Keras default), and the true data variance is 0.1. Then a 10% error (var < 0.11) is reached only after 448 batches.
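The batch count follows from the update rule: if every batch has exactly the true variance v, then after n updates the moving variance is v + (v0 - v) * momentum**n, so the error drops below the tolerance once momentum**n < tol / (v0 - v). A small sketch, assuming constant batch statistics (real batches fluctuate), computes this both in closed form and by simply iterating the update; both function names are hypothetical.

```python
import math

def batches_to_converge(momentum, v0=1.0, v_true=0.1, rel_err=0.1):
    """Smallest n with |moving_var - v_true| < rel_err * v_true, assuming
    every batch variance equals v_true exactly (closed form)."""
    # After n updates: moving_var = v_true + (v0 - v_true) * momentum**n
    ratio = rel_err * v_true / abs(v0 - v_true)
    # Smallest integer n strictly greater than log(ratio) / log(momentum)
    return math.floor(math.log(ratio) / math.log(momentum)) + 1

def batches_by_simulation(momentum, v0=1.0, v_true=0.1, rel_err=0.1):
    """Same quantity, found by applying the moving-average update step by step."""
    v, n = v0, 0
    while abs(v - v_true) >= rel_err * v_true:
        v = v * momentum + v_true * (1 - momentum)
        n += 1
    return n

for m in (0.9, 0.99, 0.999):
    print(m, batches_to_converge(m), batches_by_simulation(m))
```

For momentum 0.99 both give 448 batches (after 447 updates the variance is still just above 0.11); for 0.999 they give 4,498, and for 0.9 only 43.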
Now the root cause of the problem: in MobileNet V2, all the numerous BN layers have momentum=0.999, which means it takes 4,498 batch steps to reach the same 10% error! When you are training on a very large heterogeneous dataset like ImageNet in small batches, this is a perfectly reasonable hyperparameter choice. But in this toy example, the BN layers simply fail to learn the true data statistics during the 400 training batches and use completely wrong values during inference!
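The training run above performed only 10 steps × 40 epochs = 400 updates. The sketch below shows where the moving statistics end up after 400 updates for a few momentum values, under the idealized assumption that every batch has the same mean and variance (here an arbitrary 1.0 and 0.1; real activations vary from batch to batch). `moving_stats_after` is a hypothetical helper, not a Keras function; the initial values 0 and 1 match the Keras defaults.

```python
def moving_stats_after(n, momentum, true_mean, true_var,
                       init_mean=0.0, init_var=1.0):
    """Idealized moving mean/variance after n BN updates, assuming every
    batch has exactly true_mean and true_var (hypothetical helper)."""
    decay = momentum ** n  # weight still carried by the initial value
    mean = true_mean + (init_mean - true_mean) * decay
    var = true_var + (init_var - true_var) * decay
    return mean, var

steps = 10 * 40  # steps_per_epoch * epochs in the training run above
for m in (0.999, 0.99, 0.9):
    mean, var = moving_stats_after(steps, m, true_mean=1.0, true_var=0.1)
    print(f"momentum={m}: moving_mean={mean:.3f}, moving_var={var:.3f}")
```

With momentum 0.999 the moving mean covers only about a third of the distance to the true value (0.330 instead of 1.0), and the moving variance is still about 0.703 instead of 0.1, whereas momentum 0.9 has fully converged.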
And the fix is very simple: just change the momenta before calling model.compile:
for layer in model.layers[0].layers:  # model.layers[0] is the MobileNetV2 sub-model
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.momentum = 0.9