When and how does the call function work in Keras model subclassing?

I read in Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow about using the Subclassing API to build dynamic models, which mainly involves writing a subclass with two methods: the constructor and a call function. The constructor is fairly easy to understand. However, I had trouble understanding exactly when and how the call function runs while the model is being built.

I took code from the book and experimented a bit, as shown below (using the California housing dataset from sklearn):

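from tensorflow import keras

class WideAndDeepModel(keras.Model):
    def __init__(self, units=30, activation='relu', **kwargs):
        super().__init__(**kwargs)  # must run before layers are assigned as attributes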
        self.hidden1 = keras.layers.Dense(units, activation=activation)
        self.hidden2 = keras.layers.Dense(units, activation=activation)
        self.main_output = keras.layers.Dense(1)
        self.aux_output = keras.layers.Dense(1)
    def call(self, inputs):
        print('call function running')
        input_A, input_B = inputs
        hidden1 = self.hidden1(input_B)
        hidden2 = self.hidden2(hidden1)
        concat = keras.layers.concatenate([input_A, hidden2])
        main_output = self.main_output(concat)
        aux_output = self.aux_output(hidden2)
        return main_output, aux_output

model = WideAndDeepModel()
model.compile(loss=['mse','mse'], loss_weights=[0.9,0.1], optimizer='sgd')
history = model.fit([X_train_A, X_train_B], [y_train, y_train], epochs=20,
                    validation_data=([X_val_A, X_val_B], [y_val, y_val]))

Below is the output during training:

Epoch 1/20
call function running
call function running
353/363 [============================>.] - ETA: 0s - loss: 1.6398 - output_1_loss: 1.5468 - output_2_loss: 2.4769
call function running
363/363 [==============================] - 1s 1ms/step - loss: 1.6224 - output_1_loss: 1.5296 - output_2_loss: 2.4571 - val_loss: 4.3588 - val_output_1_loss: 4.7174 - val_output_2_loss: 1.1308
Epoch 2/20
363/363 [==============================] - 0s 1ms/step - loss: 0.6073 - output_1_loss: 0.5492 - output_2_loss: 1.1297 - val_loss: 75.1126 - val_output_1_loss: 81.6632 - val_output_2_loss: 16.1572

The call function runs twice at the start of the first epoch and once more near the end of that epoch. It never runs again after that.

It looks to me as though the layers are instantiated early, in the constructor, while the connections between them (defined in the call function) are established quite late, at the start of training. It also seems that there is no logical entity representing this so-called connection between layers; the connection is just the process of passing one layer's output to the next in a specific order. Is my understanding correct?

The second question is why the call function gets run three times at the early stage of the training instead of just once.


    Also correct: the weights are initialized when you build the model or call it for the first time, as you can see in this guide to subclassing Keras layers:

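    import tensorflow as tf
    from tensorflow import keras

    class Linear(keras.layers.Layer):
        def __init__(self, units=32):
            super(Linear, self).__init__()
            self.units = units  # only the hyperparameter is stored here

        def build(self, input_shape):
            # Weights are created here, once the input shape is known.
            self.w = self.add_weight(
                shape=(input_shape[-1], self.units),
                initializer="random_normal",
                trainable=True,
            )
            self.b = self.add_weight(
                shape=(self.units,), initializer="random_normal", trainable=True
            )

        def call(self, inputs):
            return tf.matmul(inputs, self.w) + self.b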
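To make the deferred weight creation concrete without depending on TensorFlow, here is a toy, pure-Python analogy of the same pattern. `LazyLinear` is a hypothetical stand-in written for illustration, not a Keras class; the only point is that the constructor stores hyperparameters, and the weights appear on the first call, once the input size is known.

```python
import random


class LazyLinear:
    """Toy analogy of a layer that creates weights on first call (not Keras code)."""

    def __init__(self, units):
        self.units = units      # hyperparameter only; no weights yet
        self.built = False
        self.w = None
        self.b = None

    def build(self, input_dim):
        # Weight shapes depend on the input size, which is unknown in __init__.
        self.w = [[random.gauss(0.0, 0.05) for _ in range(self.units)]
                  for _ in range(input_dim)]
        self.b = [0.0] * self.units
        self.built = True

    def __call__(self, x):
        if not self.built:
            self.build(len(x))  # the first call triggers build
        return [sum(x[i] * self.w[i][j] for i in range(len(x))) + self.b[j]
                for j in range(self.units)]


layer = LazyLinear(units=4)
print(layer.built)                                  # False
out = layer([1.0, 2.0, 3.0])
print(layer.built, len(layer.w), len(layer.w[0]))   # True 3 4
```

Constructing the object does nothing expensive; the 3x4 weight matrix only exists after the first input has been seen, which mirrors why a subclassed model has no weights until it is built or called.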

    why the call function gets run three times at the early stage

    Probably the first call happens when the model is called for the first time and the weights are instantiated. The second call builds the TensorFlow graph, which is the non-Python code that actually runs TensorFlow models. The model is called once to create this graph, and later steps execute the graph outside of Python, so your print function is no longer part of it. You can change this behavior by using model.compile(..., run_eagerly=True). Finally, the third call would be the first time the validation data is passed.
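The "runs once to build the graph, then never again" behaviour can be imitated in plain Python. The sketch below is a toy analogy and not how TensorFlow is implemented: `Sym`, `trace_once`, and `model_call` are hypothetical names invented for this example. The decorated function's Python body runs a single time against a symbolic placeholder to record its operations; later calls replay the recording, so the `print` inside the body is skipped.

```python
class Sym:
    """Symbolic placeholder that records operations instead of computing them."""

    def __init__(self, ops=()):
        self.ops = tuple(ops)

    def __mul__(self, k):
        return Sym(self.ops + (("mul", k),))

    def __add__(self, k):
        return Sym(self.ops + (("add", k),))


def trace_once(fn):
    """Run fn's Python body once to record a 'graph', then replay it on real inputs."""
    graph = None

    def wrapper(x):
        nonlocal graph
        if graph is None:
            graph = fn(Sym()).ops        # the only time the Python body executes
        for op, k in graph:              # replay: the body is not re-run, so no print
            x = x * k if op == "mul" else x + k
        return x

    return wrapper


python_runs = []


@trace_once
def model_call(x):
    print("call function running")       # Python side effect, seen during tracing only
    python_runs.append(1)
    return x * 2 + 1


results = [model_call(v) for v in (1, 2, 3)]
print(results, len(python_runs))         # [3, 5, 7] 1
```

The message is printed once even though the function is used three times. Running the real model with run_eagerly=True, as mentioned above, corresponds to skipping the recording step and executing the Python body on every call.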