Tags: neural-network, ranking, loss-function, evaluation, information-retrieval

Is it possible to use evaluation metrics (like NDCG) as a loss function?


I am working on an Information Retrieval model called DPR, which is basically a neural network (two BERT encoders) that ranks documents given a query. Currently, this model is trained in a binary manner (each document is either relevant or not relevant) and uses a Negative Log Likelihood (NLL) loss. I want to change this binary behavior and create a model that can handle graded relevance (e.g. three grades: relevant, somewhat relevant, not relevant). I have to change the loss function because currently I can assign only one positive target per query (DPR uses PyTorch's NLLLoss), and this is not what I need.
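For context, here is a minimal sketch of what a single-positive NLL objective computes. It assumes, as in the DPR setup described above, that one query is scored by dot product against one positive and several negative passages. The scores are made-up numbers and the function names are hypothetical. In PyTorch this corresponds to `log_softmax` followed by `nn.NLLLoss` with one class index per query, which is exactly why only one positive target fits.

```python
import math

def log_softmax(scores):
    """Numerically stable log-softmax over a list of floats."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]

def single_positive_nll(scores, positive_index):
    """-log softmax(scores)[positive_index]: the positive competes against all negatives."""
    return -log_softmax(scores)[positive_index]

# Dot-product scores of one query against 1 positive + 3 negatives (hypothetical numbers).
scores = [2.0, 0.5, -1.0, 0.0]
loss = single_positive_nll(scores, positive_index=0)
```

The target is a single index, so a "somewhat relevant" passage can only be labeled as either the positive or one of the negatives.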

I was wondering if I could use an evaluation metric like NDCG (Normalized Discounted Cumulative Gain) to calculate the loss. After all, the whole point of a loss function is to measure how far off our prediction is, and NDCG does the same.

So, can I use such metrics in place of a loss function, with some modifications? In the case of NDCG, I think something like subtracting the score from 1 (1 - NDCG_score) might be a good loss function. Is that true?
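One caveat with using 1 - NDCG directly: NDCG depends on the model's scores only through their sort order, so it is piecewise constant in the scores. Its gradient is zero almost everywhere and undefined at ties, so it gives backpropagation nothing to follow. This is why learning-to-rank methods optimize differentiable surrogates instead (listwise losses, or NDCG-aware schemes such as LambdaRank and ApproxNDCG). The sketch below illustrates the flat loss surface. It assumes the common exponential gain `2**rel - 1` and a `log2(rank + 1)` discount (linear gain is another common choice). The labels and scores are hypothetical.

```python
import math

def dcg(relevances):
    """DCG with exponential gain (2^rel - 1) and log2(rank + 1) discount, ranks starting at 1."""
    return sum((2 ** rel - 1) / math.log2(rank + 2) for rank, rel in enumerate(relevances))

def ndcg(scores, labels):
    """NDCG of the ranking obtained by sorting documents by descending score."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ideal = dcg(sorted(labels, reverse=True))
    return dcg([labels[i] for i in order]) / ideal if ideal > 0 else 0.0

labels = [2, 0, 1, 0]           # graded relevance: 2 = relevant, 1 = somewhat, 0 = not
scores = [1.2, 0.9, 0.3, -0.5]  # hypothetical model scores for the same four documents

loss = 1.0 - ndcg(scores, labels)
# Nudging a score without changing the sort order leaves the loss exactly the same,
# i.e. the "gradient" of 1 - NDCG with respect to that score is zero here.
nudged = [1.2, 0.9, 0.31, -0.5]
nudged_loss = 1.0 - ndcg(nudged, labels)
```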

With best regards, Ali.


Solution

  • Yes, this is possible. You would want to apply a listwise learning-to-rank approach instead of the more standard pairwise loss function.

    In pairwise loss, the network is provided with example pairs (rel, non-rel) and the ground-truth label is a binary one (say 1 if the first among the pair is relevant, and 0 otherwise).

    In the listwise approach, however, you provide a list instead of a pair during training, and the ground-truth value (still binary) indicates whether this permutation is indeed the optimal one, e.g. the one that maximizes nDCG. The ranking objective is thus transformed into a classification over permutations.

    For more details, refer to this paper.

    Of course, instead of taking features as input, the network may take BERT vectors of the query and of the documents within a list, similar to ColBERT. Unlike ColBERT, where you feed in vectors from 2 documents (pairwise training), for listwise training you need to feed in vectors from, say, 5 documents.
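To make the pairwise setup concrete, here is a minimal sketch of one common pairwise loss: a RankNet-style logistic (binary cross-entropy) loss on the score difference of a pair. The answer does not name a specific pairwise loss, so this choice is an assumption and the function name is hypothetical. The label is the binary pair label described above.

```python
import math

def pairwise_logistic_loss(score_first, score_second, label):
    """RankNet-style loss for one pair of documents.

    label = 1 if the first document should rank above the second, else 0.
    P(first above second) is modeled as sigmoid(score_first - score_second).
    """
    d = score_first - score_second
    # Binary cross-entropy on sigmoid(d) simplifies to softplus(d) - label * d;
    # softplus is computed in a numerically stable form.
    softplus = max(d, 0.0) + math.log1p(math.exp(-abs(d)))
    return softplus - label * d

# A (relevant, non-relevant) pair scored 2.0 and 0.5 by the model (hypothetical numbers).
correct = pairwise_logistic_loss(2.0, 0.5, label=1)  # small loss: order is right
swapped = pairwise_logistic_loss(2.0, 0.5, label=0)  # large loss: order is wrong
```

Note that each pair only says "A above B", so graded relevance has to be expressed indirectly, by generating pairs between every two documents with different grades.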
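A minimal sketch of the listwise idea with graded labels, using ListMLE as one concrete instance. This is a swapped-in, named technique, not necessarily the method of the paper referred to in the answer. ListMLE scores the ideal permutation (documents sorted by descending grade) under a Plackett-Luce model and minimizes its negative log-likelihood, which is one way to read "classification over permutations". The toy 3-d vectors are hypothetical stand-ins for the query and document encoder outputs (real BERT vectors are 768-d). A list of 5 documents is scored by dot product, as in DPR.

```python
import math

def dot(u, v):
    """Dot-product relevance score between a query vector and a document vector."""
    return sum(a * b for a, b in zip(u, v))

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def listmle_loss(scores, labels):
    """ListMLE: -log P(ideal permutation | scores) under a Plackett-Luce model.

    The ideal permutation sorts documents by descending graded label; Python's
    stable sort breaks label ties by original position (an arbitrary choice).
    """
    ideal = sorted(range(len(labels)), key=lambda i: labels[i], reverse=True)
    ordered = [scores[i] for i in ideal]
    # At each position: -log softmax of the chosen document among those not yet placed.
    return sum(logsumexp(ordered[pos:]) - ordered[pos] for pos in range(len(ordered)))

# Hypothetical toy embeddings: one query and a list of 5 documents.
query = [1.0, 0.0, 0.5]
docs = [[2.0, 0.1, 1.0], [0.2, 1.0, 0.0], [1.0, 0.0, 0.2], [0.0, 0.5, -1.0], [0.5, 0.3, 0.4]]
labels = [2, 0, 1, 0, 1]  # 2 = relevant, 1 = somewhat relevant, 0 = not relevant

scores = [dot(query, d) for d in docs]
loss = listmle_loss(scores, labels)
```

Unlike 1 - NDCG, this loss is smooth in the scores, so gradients flow back into both encoders. A related alternative is ListNet, which compares `softmax(scores)` with a target distribution derived from the grades. In recent PyTorch versions, this can be written with `cross_entropy` using class-probability targets instead of a single index.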