I am trying to do some custom calculations in a custom loss function. But when I print debug statements from the custom loss function, it seems that it is called only once, at the beginning of the .fit() method.
Example of the loss function:

```python
import tensorflow as tf

def loss(y_true, y_pred):
    print("--- Starting of the loss function ---")
    print(y_true)
    loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
    print("--- Ending of the loss function ---")
    return loss
```
Using a callback to check when each batch starts and ends:

```python
from tensorflow.keras.callbacks import Callback

class monitor(Callback):
    def on_batch_begin(self, batch, logs=None):
        print("\n >> Starting a new batch (batch index) :: ", batch)

    def on_batch_end(self, batch, logs=None):
        print(">> Ending a batch (batch index) :: ", batch)
```
The .fit() method is called as:

```python
history = model.fit(
    x=[inputs],
    y=[outputs],
    shuffle=False,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCH,
    verbose=1,
    callbacks=[monitor()]
)
```
Parameters used:

```
BATCH_SIZE = 128
NUM_EPOCH = 3
inputs.shape = (512, 8)
outputs.shape = (512, 2)
```
And the output:

```
Epoch 1/3
>> Starting a new batch (batch index) :: 0
--- Starting of the loss function ---
Tensor("IteratorGetNext:5", shape=(128, 2), dtype=float32)
--- Ending of the loss function ---
--- Starting of the loss function ---
Tensor("IteratorGetNext:5", shape=(128, 2), dtype=float32)
--- Ending of the loss function ---
1/4 [======>.......................] - ETA: 0s - loss: 0.5551
>> Ending a batch (batch index) :: 0
>> Starting a new batch (batch index) :: 1
>> Ending a batch (batch index) :: 1
>> Starting a new batch (batch index) :: 2
>> Ending a batch (batch index) :: 2
>> Starting a new batch (batch index) :: 3
>> Ending a batch (batch index) :: 3
4/4 [==============================] - 0s 5ms/step - loss: 0.5307
Epoch 2/3
>> Starting a new batch (batch index) :: 0
1/4 [======>.......................] - ETA: 0s - loss: 0.5443
>> Ending a batch (batch index) :: 0
>> Starting a new batch (batch index) :: 1
>> Ending a batch (batch index) :: 1
>> Starting a new batch (batch index) :: 2
>> Ending a batch (batch index) :: 2
>> Starting a new batch (batch index) :: 3
>> Ending a batch (batch index) :: 3
4/4 [==============================] - 0s 5ms/step - loss: 0.5246
Epoch 3/3
>> Starting a new batch (batch index) :: 0
1/4 [======>.......................] - ETA: 0s - loss: 0.5433
>> Ending a batch (batch index) :: 0
>> Starting a new batch (batch index) :: 1
>> Ending a batch (batch index) :: 1
>> Starting a new batch (batch index) :: 2
>> Ending a batch (batch index) :: 2
>> Starting a new batch (batch index) :: 3
>> Ending a batch (batch index) :: 3
4/4 [==============================] - 0s 4ms/step - loss: 0.5219
```
Why is the custom loss function called only at the start, and not for every batch? I would also like to know when the loss function is called/triggered.
The loss function debug messages were printed only at the beginning of the training.
This is because, for the sake of performance, Keras internally converts your loss function into a TensorFlow graph, and Python's print only runs while the function is being traced. In other words, the messages were printed only at the beginning of training because that is when your loss function was being traced. Please refer to the following page for more information: https://www.tensorflow.org/guide/function
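To see the tracing behaviour in isolation, here is a minimal sketch that uses tf.function directly rather than model.fit (the names traced_fn and trace_log are illustrative, not part of any API). Plain Python code such as print runs only while TensorFlow traces the function, whereas tf.print is recorded as an op in the graph and runs on every call:

```python
import tensorflow as tf

trace_log = []  # plain Python list, appended only while the function is traced

@tf.function
def traced_fn(x):
    # Python side effects (print, list.append) run only during tracing.
    print("Python print, runs while tracing:", x)
    trace_log.append("traced")
    # tf.print is a graph op, so it runs every time the graph executes.
    tf.print("tf.print, runs on every call:", x)
    return x * 2.0

# Three calls with the same input signature (float32 scalar):
# the function is traced once and executed three times.
results = [float(traced_fn(tf.constant(v))) for v in [1.0, 2.0, 3.0]]
print("traces:", len(trace_log), "results:", results)
```

Running it should show the Python print line once and the tf.print line three times, the same pattern as in the question's log, where the loss-function prints appear only at the start of training.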
Short answer: to print on every step, use tf.print() instead of print().
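Applied to the loss function from the question, the fix could look like the sketch below. Only the print calls change; the local variable is renamed to value so it no longer shadows the function name, which is a cosmetic choice:

```python
import tensorflow as tf

def loss(y_true, y_pred):
    # tf.print is part of the graph, so these lines run on every training
    # step, not only while the function is being traced.
    tf.print("--- Starting of the loss function ---")
    tf.print(y_true)  # prints the actual values (summarized for large tensors)
    value = tf.keras.losses.mean_squared_error(y_true, y_pred)
    tf.print("--- Ending of the loss function ---")
    return value
```

If you would rather keep plain print while debugging, compiling with model.compile(..., run_eagerly=True) makes Keras run each training step eagerly, so Python print executes on every batch. This is considerably slower, so it is best reserved for debugging.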
And I would also like to know when the loss function is called/triggered?
After you switch to tf.print(), the debug messages will be printed on every step. You will see that your loss function is called at least once per training step, to compute the loss value and, from it, the gradients.
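One way to check the "at least once per step" claim is to count calls with a tf.Variable, whose assign_add is a graph op and therefore runs on every execution rather than only during tracing. This sketch reproduces the question's setup with random data of the same shapes (512 samples, batch size 128, 3 epochs, i.e. 12 steps); the model, the optimizer, and the names loss_calls and counting_loss are assumptions for illustration:

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE = 128
NUM_EPOCH = 3
# Synthetic stand-ins for the question's data: 512 samples, 8 features, 2 targets.
inputs = np.random.rand(512, 8).astype("float32")
outputs = np.random.rand(512, 2).astype("float32")

# Hypothetical counter; updated inside the graph on every loss evaluation.
loss_calls = tf.Variable(0, dtype=tf.int64, trainable=False)

def counting_loss(y_true, y_pred):
    loss_calls.assign_add(1)  # graph op: runs on every step, not just at trace time
    return tf.keras.losses.mean_squared_error(y_true, y_pred)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(2),
])
# jit_compile=False keeps XLA out of the picture, a conservative choice for debugging.
model.compile(optimizer="adam", loss=counting_loss, jit_compile=False)
model.fit(inputs, outputs, batch_size=BATCH_SIZE, epochs=NUM_EPOCH,
          shuffle=False, verbose=0)

steps = NUM_EPOCH * (len(inputs) // BATCH_SIZE)  # 3 epochs * 4 batches
print("training steps:", steps, "loss calls:", int(loss_calls.numpy()))
```

With a single output the counter should end up equal to the number of training steps; it can be higher, for example when the same loss is applied to several model outputs.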