Tags: python, tensorflow, one-hot-encoding

How to count in a one hot tensor


I know I can convert a tensor of labels to a one-hot encoding with this command:

one_hot_labels = tf.one_hot(labels, depth=3)

Now I want to count how many examples of class 0, class 1, and class 2 there are in one_hot_labels. What is the easiest way to do that?
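If the original integer labels are still available, the counts can be taken before encoding at all. A minimal NumPy sketch, assuming labels is a 1-D array of class ids in range(3); the sample values are made up for illustration. For tensors, tf.math.bincount plays the same role.

```python
import numpy as np

# Hypothetical integer class ids (0, 1 or 2), one per example
labels = np.array([0, 0, 2])

# bincount tallies occurrences of each non-negative integer;
# minlength=3 keeps a slot for classes that never appear (class 1 here)
counts = np.bincount(labels, minlength=3)
print(counts)  # [2 0 1]
```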

Example:

Input:

one_hot_labels = [[1,0,0],[1,0,0],[0,0,1]]
one_hot_labels.count([1,0,0]) # something like this command

Output:

2
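Python lists do have a count method, so the pseudocode above works on a plain nested list, but NumPy arrays do not. A minimal NumPy sketch of the same idea, comparing every row with the target one-hot vector (the target value is chosen for illustration):

```python
import numpy as np

one_hot_labels = np.array([[1, 0, 0], [1, 0, 0], [0, 0, 1]])
target = np.array([1, 0, 0])  # the one-hot row to look for (class 0)

# Broadcast-compare each row with target, keep rows where every entry matches,
# then count the True values
matches = np.all(one_hot_labels == target, axis=1)
count = int(matches.sum())
print(count)  # 2
```

For well-formed one-hot data this equals the column total for that class.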

Solution

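  • Sum the one-hot tensor along axis 0 (the example axis). Because each row contains a single 1, each column total is the number of examples in that class: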

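    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x (tf.Session was removed in 2.x)

    one_hot_labels = np.array([[1,0,0],[1,0,0],[0,0,1]])
    # Sum down axis 0 (over examples) to get one total per class column
    count_label = tf.reduce_sum(one_hot_labels, axis=0)
    sess = tf.Session()
    sess.run(count_label)
    # array([2, 0, 1])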
    

    To get the count for a single class (class 0 here), index the result:

    count_label = tf.reduce_sum(one_hot_labels, axis=0)[0]
    sess.run(count_label)
    # 2
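For comparison, here is the same reduction sketched in plain NumPy, without TensorFlow. Summing down axis 0 gives one total per class; recovering each row's class id with argmax and tallying the ids with bincount should give the same counts. In TensorFlow 2.x, which runs eagerly and has no tf.Session, tf.reduce_sum(one_hot_labels, axis=0).numpy() returns the counts array directly.

```python
import numpy as np

one_hot_labels = np.array([[1, 0, 0], [1, 0, 0], [0, 0, 1]])

# Column sums: one total per class (axis 0 runs over examples)
per_class = one_hot_labels.sum(axis=0)
print(per_class)  # [2 0 1]

# Alternative: turn each row back into its class id, then tally the ids
class_ids = one_hot_labels.argmax(axis=1)  # [0 0 2]
per_class_alt = np.bincount(class_ids, minlength=one_hot_labels.shape[1])
print(per_class_alt)  # [2 0 1]

# Count for a single class, e.g. class 0
print(per_class[0])  # 2
```

The column sum is the simpler choice; the argmax route only gives the same answer when every row really is one-hot.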