How to check whether the argmax of one tensor matches any argmax of another tensor that has several equal maxima?


In single-label classification, we usually compute the correct predictions as follows:

correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(self.label, 1))

However, I am working on multi-label classification, so I'd like to know how to do this when the label vector contains several ones. What I have so far is below:

 import tensorflow as tf

 a = tf.constant([0.2, 0.4, 0.3, 0.1])
 b = tf.constant([0, 1.0, 1, 0])

 # Collect the indices at which b equals 1.
 empty_tensor = tf.zeros([0])
 for index in range(b.get_shape()[0]):
     empty_tensor = tf.cond(
         tf.equal(b[index], tf.constant(1, dtype=tf.float32)),
         lambda: tf.concat([empty_tensor, tf.constant([index], dtype=tf.float32)], axis=0),
         lambda: empty_tensor)

 # temp is [argmax(a)] if that index is not one of the label indices, else [].
 temp, _ = tf.setdiff1d([tf.argmax(a)], tf.cast(empty_tensor, dtype=tf.int64))
 # output is [argmax(a)] if that index is one of the label indices, else [].
 output, _ = tf.setdiff1d([tf.argmax(a)], tf.cast(temp, dtype=tf.int64))

This gives me the index at which max(preds) occurs, provided there is a 1 at that position in self.label. In the example above it gives [1]; if the argmax does not land on a position labelled 1, I get [].

The issue is that I do not know how to proceed from there, as I would like something like the following

correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(self.label, 1))
self.accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))

which is straightforward for single label classification.

Thanks a lot


Solution

  • I don't think you can achieve this with softmax, so I am assuming that you are using sigmoids for your preds. If you are using sigmoids, each of your outputs will (independently) be between 0 and 1. You can define a threshold for each, perhaps 0.5, and then convert your sigmoid preds into the label encoding (0's and 1's) by doing preds > 0.5. Note that this comparison returns a boolean tensor, so cast it to the label's dtype (for example with tf.cast) before comparing it against the label.

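    If the prediction is [0 1] and the label is [1 1], do you want to report that as completely or partially wrong? I am going to assume the former. In that case, you would remove the tf.argmax call and instead check whether the thresholded preds and the label are exactly the same vector, which would look like tf.reduce_all(tf.equal(preds, label), axis=0). For the latter, tf.equal returns booleans, which tf.reduce_sum cannot add up directly, so cast first: tf.reduce_sum(tf.cast(tf.equal(preds, label), tf.float32), axis=0) counts the positions that match. Both expressions use axis=0 because they treat preds and label as single vectors; for a batch of shape [batch, classes] you would reduce over axis=1 instead.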
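To make the options above concrete, here is a minimal sketch rather than a definitive implementation. It assumes TensorFlow 2.x with eager execution, float32 labels, and a hypothetical batch of 3 examples with 4 classes; the names hard_preds, exact_match, per_label_accuracy and top_hit are made up for illustration. The last part is one possible way to answer the original question directly (is the argmax of preds one of the positions labelled 1?), using a one-hot mask instead of the tf.cond loop.

```python
import tensorflow as tf

# Hypothetical data: sigmoid outputs and multi-hot labels, shape [batch, classes].
preds = tf.constant([[0.2, 0.9, 0.7, 0.1],
                     [0.8, 0.3, 0.6, 0.1],
                     [0.1, 0.2, 0.3, 0.9]])
labels = tf.constant([[0., 1., 1., 0.],
                      [1., 0., 0., 0.],
                      [1., 0., 0., 0.]])
num_classes = 4

# Threshold the sigmoid outputs and cast to the labels' dtype so that
# tf.equal compares tensors of the same type.
hard_preds = tf.cast(preds > 0.5, tf.float32)
matches = tf.equal(hard_preds, labels)  # boolean, shape [batch, classes]

# Option 1: an example counts as correct only if every label matches.
exact_match = tf.reduce_all(matches, axis=1)
exact_accuracy = tf.reduce_mean(tf.cast(exact_match, tf.float32))

# Option 2: partial credit, the fraction of matching labels per example.
per_label_accuracy = tf.reduce_mean(tf.cast(matches, tf.float32), axis=1)

# Option 3: the highest-scoring class must be one of the labelled classes.
top_class = tf.argmax(preds, axis=1)
top_mask = tf.one_hot(top_class, depth=num_classes)  # float32 by default
top_hit = tf.reduce_sum(top_mask * labels, axis=1) > 0
top_accuracy = tf.reduce_mean(tf.cast(top_hit, tf.float32))

print(exact_match.numpy())         # [ True False False]
print(per_label_accuracy.numpy())  # [1.   0.75 0.5 ]
print(top_hit.numpy())             # [ True  True False]
```

Which of the three to report depends on what "correct" should mean for your task: the exact-match version is the strictest, while the argmax-based check is closest to the single-label code in the question.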