I am passing sample_weight as the third element of each tuple in a tf.data.Dataset. I use it as a mask, so every sample weight is either 0 or 1. The problem is that these sample weights don't seem to be applied when the metrics are computed. (Ref: https://www.tensorflow.org/guide/keras/train_and_evaluate#sample_weights)
Here's the code snippet:
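```python
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
AUTO = tf.data.AUTOTUNE
# Each element is (image, label, sample_weight); masks holds 0/1 weights
train_ds = tf.data.Dataset.from_tensor_slices((imgs, labels, masks))
train_ds = train_ds.shuffle(1024).repeat().batch(32).prefetch(buffer_size=AUTO)
model.compile(optimizer=Adam(learning_rate=1e-4),
              loss=SparseCategoricalCrossentropy(),
              metrics=['sparse_categorical_accuracy'])
model.fit(train_ds, steps_per_epoch=len(imgs) // 32, epochs=20)
```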
After training, the loss is very close to zero, but sparse_categorical_accuracy is not (it stays at about 0.89). So I strongly suspect that the sample_weight (masks) passed in when constructing the tf.data.Dataset is NOT applied to the metrics reported during training, while the loss seems to be correct. I confirmed this by running prediction separately on the subset of samples that are not masked out: the accuracy there is 1.0.
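The gap is consistent with masked-out samples still being counted by an unweighted metric. Below is a NumPy sketch of the arithmetic, using hypothetical toy data (ten samples, one of which is masked out and mispredicted) and a hypothetical helper `accuracy`. With weights it computes sum(w * correct) / sum(w), which, as far as I know, is how the Keras accuracy metrics combine sample weights.

```python
import numpy as np

def accuracy(y_true, y_pred_probs, sample_weight=None):
    """Argmax accuracy; with sample_weight it is sum(w * correct) / sum(w)."""
    correct = (np.argmax(y_pred_probs, axis=-1) == y_true).astype(np.float64)
    if sample_weight is None:
        return correct.mean()
    w = np.asarray(sample_weight, dtype=np.float64)
    return (correct * w).sum() / w.sum()

# Hypothetical toy data: 10 samples; the last one is masked out (weight 0)
# and is also the only one the "model" gets wrong.
y_true = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
predicted = y_true.copy()
predicted[-1] = 1 - predicted[-1]
y_pred_probs = np.eye(2)[predicted]  # one-hot "probabilities"
mask = np.ones(10)
mask[-1] = 0.0

unweighted = accuracy(y_true, y_pred_probs)      # counts the masked sample
weighted = accuracy(y_true, y_pred_probs, mask)  # ignores it
print(unweighted, weighted)
```

Here the unweighted value is 0.9 and the weighted one is 1.0, the same pattern as the 0.89 vs 1.0 observed above.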
Also, according to the documentation, a metric takes three arguments: y_true, y_pred, and sample_weight.
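When a metric object is called directly, it does honour sample_weight. A small sketch, assuming TensorFlow 2.x with made-up labels and predictions: the last two samples are masked out, so the unweighted accuracy is 0.5 while the weighted one is 1.0.

```python
import numpy as np
import tensorflow as tf

# Made-up data: argmax predictions are [0, 1, 0, 0], so samples 3 and 4 are wrong.
y_true = np.array([0, 1, 1, 1], dtype=np.int32)
y_pred = np.array([[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.9, 0.1]], dtype=np.float32)
mask = np.array([1.0, 1.0, 0.0, 0.0], dtype=np.float32)  # mask out the wrong ones

unweighted = tf.keras.metrics.SparseCategoricalAccuracy()
unweighted.update_state(y_true, y_pred)

weighted = tf.keras.metrics.SparseCategoricalAccuracy()
weighted.update_state(y_true, y_pred, sample_weight=mask)

print(float(unweighted.result()), float(weighted.result()))
```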
So how does one pass sample_weight during metric computation? Is this the responsibility of model.fit(...) within the Keras framework? I haven't been able to find any example so far.
After some debugging and reading the docs, I found that .compile has a weighted_metrics argument, which I should use instead of metrics=. I confirmed that this fixes my test case in the shared Colab.
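```python
from tensorflow.keras.metrics import SparseCategoricalAccuracy
# weighted_metrics receive the dataset's sample_weight; plain metrics= do not
model.compile(optimizer=Adam(learning_rate=1e-4),
              loss=SparseCategoricalCrossentropy(),
              weighted_metrics=[SparseCategoricalAccuracy()])
```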
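A minimal self-contained sketch of the difference, assuming TensorFlow 2.x (or Keras 3). The "model" here is a hypothetical stand-in that only applies softmax to a 2-D input, so its predictions are fixed and no training is needed. The same accuracy metric is compiled twice, once under metrics and once under weighted_metrics, with the hypothetical names `acc` and `masked_acc`, and evaluated on a dataset of (x, y, sample_weight) tuples.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in model: softmax over the input "logits", no weights.
inputs = tf.keras.Input(shape=(2,))
outputs = tf.keras.layers.Activation("softmax")(inputs)
model = tf.keras.Model(inputs, outputs)

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="acc")],
    weighted_metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="masked_acc")],
)

# Predictions are [0, 1, 0, 0]; the last two labels are wrong but masked out.
x = np.array([[5, 0], [0, 5], [5, 0], [5, 0]], dtype=np.float32)
y = np.array([0, 1, 1, 1], dtype=np.int32)
w = np.array([1, 1, 0, 0], dtype=np.float32)
ds = tf.data.Dataset.from_tensor_slices((x, y, w)).batch(4)

results = model.evaluate(ds, verbose=0, return_dict=True)
print(results["acc"], results["masked_acc"])
```

The metric under metrics counts every sample (0.5), while the one under weighted_metrics uses the third tuple element as weights (1.0).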