Tags: machine-learning, deep-learning, nlp, artificial-intelligence, loss-function

How to use cosine similarity within triplet loss


The triplet loss is defined as follows:

L(A, P, N) = max(‖f(A) - f(P)‖² - ‖f(A) - f(N)‖² + margin, 0)

where A = anchor, P = positive and N = negative are the data samples in the loss, f is the embedding function, and margin is the minimum gap by which the anchor-negative distance should exceed the anchor-positive distance.
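To make the formula concrete, here is a minimal NumPy sketch of the squared-L2 triplet loss for a batch. It assumes f(·) has already been applied, so each row of `anchor`, `positive` and `negative` is an embedding; the function name `triplet_loss_l2` and the toy vectors are hypothetical.

```python
import numpy as np


def triplet_loss_l2(anchor, positive, negative, margin=1.0):
    """Per-triplet loss max(||a - p||^2 - ||a - n||^2 + margin, 0).

    Each argument has shape (batch, dim) and holds embeddings f(.).
    Returns an array of shape (batch,).
    """
    ap = np.sum((anchor - positive) ** 2, axis=1)  # squared anchor-positive distance
    an = np.sum((anchor - negative) ** 2, axis=1)  # squared anchor-negative distance
    return np.maximum(ap - an + margin, 0.0)


a = np.array([[0.0, 0.0], [0.0, 0.0]])
p = np.array([[1.0, 0.0], [2.0, 0.0]])
n = np.array([[3.0, 0.0], [1.0, 0.0]])
print(triplet_loss_l2(a, p, n))  # [0. 4.]
```

The first triplet already satisfies the margin (1 - 9 + 1 < 0), so only the second one, whose negative is closer than its positive, contributes.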

I read somewhere that (1 - cosine_similarity) may be used instead of the L2 distance.
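Substituting the cosine distance d(x, y) = 1 - cos(x, y), with cos(x, y) = x·y / (‖x‖ ‖y‖), for the squared L2 distance in the triplet formula is just algebra on the definitions:

```latex
\begin{aligned}
L_{\cos}(A, P, N)
  &= \max\bigl( [1 - \cos(f(A), f(P))] - [1 - \cos(f(A), f(N))] + \text{margin},\ 0 \bigr) \\
  &= \max\bigl( \cos(f(A), f(N)) - \cos(f(A), f(P)) + \text{margin},\ 0 \bigr)
\end{aligned}
```

Because 1 - cos lies in [0, 2], the gap cos(f(A), f(P)) - cos(f(A), f(N)) is at most 2, so a margin above 2 can never be satisfied and the loss would never reach zero; meaningful margins lie in (0, 2).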

Note that I am using TensorFlow, and its cosine similarity loss (tf.keras.losses.CosineSimilarity) is defined with the opposite sign: it returns a value between -1 and 1, where 0 indicates orthogonality, values closer to -1 indicate greater similarity, and values closer to 1 indicate greater dissimilarity. So it is the negative of the cosine similarity metric.
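A quick way to check this sign convention, assuming TensorFlow 2.x: the `tf.keras.losses.CosineSimilarity` loss returns -cos(x, y), the `tf.keras.metrics.CosineSimilarity` metric returns +cos(x, y), and 1 plus the loss value gives the cosine distance. The vectors are arbitrary examples.

```python
import tensorflow as tf

a = tf.constant([[1.0, 0.0], [1.0, 1.0]])
b = tf.constant([[1.0, 0.0], [0.0, 1.0]])  # cos = 1.0 and cos = 1/sqrt(2)

# Keras *loss*: returns -cos(a, b); reduction="none" keeps one value per row.
loss_fn = tf.keras.losses.CosineSimilarity(axis=1, reduction="none")
per_pair_loss = loss_fn(a, b).numpy()  # approx [-1.0, -0.7071]

# Keras *metric*: returns +cos(a, b), averaged over all rows seen so far.
metric = tf.keras.metrics.CosineSimilarity(axis=1)
metric.update_state(a, b)
mean_similarity = float(metric.result())  # approx 0.8536

# Cosine distance = 1 - cosine similarity = 1 + Keras loss value.
cosine_distance = 1.0 + per_pair_loss  # approx [0.0, 0.2929]
print(per_pair_loss, mean_similarity, cosine_distance)
```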

Any suggestions on how to write my triplet loss with cosine similarity?

Edit

There is a lot of good material in the answers and comments. Based on all the hints, this is working OK for me:

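 import tensorflow as tf

 self.margin = 1
 # Keras' CosineSimilarity *loss* returns -cos(x, y), so lower means more similar.
 self.loss = tf.keras.losses.CosineSimilarity(axis=1)
 ap_distance = self.loss(anchor, positive)  # -cos(anchor, positive), averaged over the batch
 an_distance = self.loss(anchor, negative)  # -cos(anchor, negative), averaged over the batch
 # -cos differs from the cosine distance (1 - cos) by a constant that cancels here.
 loss = tf.maximum(ap_distance - an_distance + self.margin, 0.0)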
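With the default reduction, each `self.loss(...)` call in the snippet above averages over the batch before the hinge is applied, so the margin is enforced on batch means rather than on each triplet. A minimal per-triplet variant, assuming TensorFlow 2.x and embeddings of shape (batch, dim); the function name `cosine_triplet_loss` is hypothetical:

```python
import tensorflow as tf


def cosine_triplet_loss(anchor, positive, negative, margin=1.0):
    """Triplet loss with cosine distance, hinge applied per triplet.

    anchor, positive, negative: float tensors of shape (batch, dim).
    Returns a scalar: the batch mean of the per-triplet losses.
    """
    # Keras' CosineSimilarity loss returns -cos(x, y); reduction="none" keeps
    # one value per row instead of averaging over the batch.
    neg_cos = tf.keras.losses.CosineSimilarity(axis=1, reduction="none")
    ap = neg_cos(anchor, positive)  # -cos(A, P), i.e. (1 - cos) - 1
    an = neg_cos(anchor, negative)  # -cos(A, N), i.e. (1 - cos) - 1
    # The constant -1 offsets cancel in (ap - an), so this equals using the
    # cosine distance 1 - cos for both terms.
    per_triplet = tf.maximum(ap - an + margin, 0.0)
    return tf.reduce_mean(per_triplet)


anchor = tf.constant([[1.0, 0.0], [1.0, 0.0]])
positive = tf.constant([[1.0, 0.0], [0.0, 1.0]])
negative = tf.constant([[0.0, 1.0], [1.0, 0.0]])
print(float(cosine_triplet_loss(anchor, positive, negative, margin=0.5)))  # approx 0.75
```

For this batch the per-triplet version gives about 0.75, whereas applying the hinge after averaging, as in the snippet above, gives 0.5: the easy first triplet partly hides the violated second one.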

I would like to eventually use the TensorFlow Addons loss that @pygeek pointed out, but I haven't figured out how to pass the data to it yet.

Note: to use the cosine similarity metric standalone, one must do something like this:

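import tensorflow as tf

cosine_similarity = tf.keras.metrics.CosineSimilarity()
cosine_similarity.reset_state()  # clear any previously accumulated values
cosine_similarity.update_state(anch_prediction, other_prediction)
# result() is the mean cosine similarity (positive sign) over the batch
similarity = cosine_similarity.result().numpy()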

Resources

PyTorch cosine embedding layer

TensorFlow cosine similarity implementation

TensorFlow triplet loss hard/soft margin


Solution

  • First of all, cosine_distance = 1 - cosine_similarity. Distance and similarity are different quantities, and this is not stated correctly in some of the answers!

    Second, look at how TensorFlow implements the cosine similarity loss (https://github.com/keras-team/keras/blob/v2.9.0/keras/losses.py#L2202-L2272): it returns the negative cosine similarity, which differs from PyTorch, where torch.nn.CosineSimilarity returns the similarity itself.

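    Finally, I suggest you use an existing loss: replace the ‖ … ‖² terms with a cosine distance. In TensorFlow.js this is tf.losses.cosineDistance(...); in Python, the TF 1.x API tf.compat.v1.losses.cosine_distance(...) plays the same role and assumes L2-normalized inputs.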
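In Python TensorFlow, a built-in cosine distance is available as `tf.compat.v1.losses.cosine_distance` (the TF 1.x loss API); it computes 1 - sum(labels * predictions) along `axis` and assumes both inputs are already unit-normalized. A minimal sketch of plugging it into the triplet formula in place of the squared L2 terms, assuming TensorFlow 2.x and embeddings of shape (batch, dim); `triplet_loss_cosine_distance` is a hypothetical name:

```python
import tensorflow as tf

NONE = tf.compat.v1.losses.Reduction.NONE  # keep one value per row


def triplet_loss_cosine_distance(anchor, positive, negative, margin=0.5):
    """max(d(A, P) - d(A, N) + margin, 0) with d = cosine distance, per triplet."""
    # cosine_distance assumes unit-length inputs, so normalize first.
    a = tf.math.l2_normalize(anchor, axis=1)
    p = tf.math.l2_normalize(positive, axis=1)
    n = tf.math.l2_normalize(negative, axis=1)
    # With reduction NONE each call returns shape (batch, 1).
    d_ap = tf.compat.v1.losses.cosine_distance(a, p, axis=1, reduction=NONE)
    d_an = tf.compat.v1.losses.cosine_distance(a, n, axis=1, reduction=NONE)
    per_triplet = tf.maximum(d_ap - d_an + margin, 0.0)
    return tf.reduce_mean(per_triplet)


anchor = tf.constant([[2.0, 0.0], [1.0, 1.0]])
positive = tf.constant([[1.0, 0.0], [-1.0, 0.0]])
negative = tf.constant([[0.0, 3.0], [0.0, 1.0]])
print(float(triplet_loss_cosine_distance(anchor, positive, negative)))  # approx 0.9571
```

Only the second triplet violates the margin here (d(A, P) ≈ 1.707, d(A, N) ≈ 0.293), so the mean over the two triplets is (1.707 - 0.293 + 0.5) / 2 ≈ 0.957.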