Tags: tensorflow, tensorflow2.0, tensorflow-datasets

How to get samples per class for TensorFlow Dataset


I am using a dataset from TensorFlow Datasets. Is there an easy way to get the number of samples for each class in the dataset? I searched through the Keras API and did not find any ready-to-use function.

Ultimately, I would like to draw a bar plot with the number of samples on the Y axis and the integer class ID on the X axis. The goal is to show how evenly the data is distributed across classes.


Solution

  • With np.fromiter you can create a 1-D array of labels from an iterable object; np.unique with return_counts=True then gives the number of samples per class.

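    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    import tensorflow_datasets as tfds
    
    # as_supervised=True yields (image, label) pairs
    dataset = tfds.load('cifar10', split='train', as_supervised=True)
    
    # Keep only the labels, then count how often each class id occurs
    labels, counts = np.unique(np.fromiter(dataset.map(lambda x, y: y), np.int32),
                               return_counts=True)
    
    # Draw the bars first, then label the axes
    sns.barplot(x=labels, y=counts)
    plt.xlabel('Labels')
    plt.ylabel('Counts')
    plt.show()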
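
If only the per-class totals are needed, the labels can also be tallied as they stream by, without first building a full array. The following is a minimal sketch and not part of the original answer: `count_labels` is a hypothetical helper name, and it assumes the labels arrive as an iterable of integer class ids (for a TensorFlow dataset, something like `dataset.map(lambda x, y: y).as_numpy_iterator()` should provide that).

```python
from collections import Counter


def count_labels(label_iter):
    """Tally integer class ids from any iterable (hypothetical helper)."""
    counts = Counter()
    for label in label_iter:
        # int() accepts plain ints as well as NumPy integer scalars
        counts[int(label)] += 1
    # Return a plain dict ordered by class id so it plots left to right
    return dict(sorted(counts.items()))


if __name__ == "__main__":
    # Stand-in labels; a real run would pass the dataset's label iterator
    print(count_labels([0, 1, 1, 2, 2, 2]))  # {0: 1, 1: 2, 2: 3}
```

The keys and values of the returned dict can then be converted to lists and passed to `sns.barplot` as `x` and `y`.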
    



    Update: You can also collect the labels with an explicit loop:

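    labels = []
    for x, y in dataset:
        # Labels stored as integer class ids (not one-hot encoded)
        labels.append(y.numpy())
    
        # If the labels are one-hot encoded, take the argmax instead:
        # labels.append(np.argmax(y.numpy(), axis=-1))
    
    # np.concatenate assumes the dataset was batched, so each y is a 1-D array.
    # For an unbatched dataset (scalar labels), use np.array(labels) instead.
    labels = np.concatenate(labels, axis=0)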
    

    You can then pass the labels array to np.unique(labels, return_counts=True) and plot the resulting counts in the same way as above.
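
The loop has to cope with three cases: scalar labels, batched labels, and one-hot encoded labels. One way to fold them into a single step is sketched below. This is only an illustration under stated assumptions: `class_counts` is a hypothetical helper, and it expects an iterable of label values that NumPy can convert with `np.asarray` (for example the `y.numpy()` values collected in the loop).

```python
import numpy as np


def class_counts(label_batches, one_hot=False):
    """Return (class_ids, counts) from an iterable of label arrays.

    Hypothetical helper: each element may be a scalar label, a batch of
    integer labels, or (with one_hot=True) one-hot encoded labels.
    """
    flat = []
    for y in label_batches:
        arr = np.asarray(y)
        if one_hot:
            # Collapse the class axis to integer class ids
            arr = arr.argmax(axis=-1)
        # reshape(-1) turns scalars and batches alike into 1-D arrays
        flat.append(arr.reshape(-1))
    if not flat:
        # Nothing to count: return two empty integer arrays
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    return np.unique(np.concatenate(flat), return_counts=True)


# Stand-in data: two batches of integer labels
ids, counts = class_counts([np.array([0, 1]), np.array([1, 2])])
```

The returned `ids` and `counts` arrays line up one to one, so they can go straight into a bar plot as the x and y values.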