Tags: python, tensorflow, filter, size, take

Get the size of a tf.data.Dataset after applying a filter


I wonder how I can get the size (the len) of a dataset after applying a filter. tf.data.experimental.cardinality returns -2, which is not what I am looking for. I want to know how many samples remain after filtering so that I can split the dataset into training and validation sets with take() and skip().

Example:

    import tensorflow as tf

    dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5])
    dataset = dataset.filter(lambda x: x < 4)
    size = tf.data.experimental.cardinality(dataset).numpy()
    # size is -2 here, but I want the real size, which is 3

My dataset contains images and their labels; this is just an illustrative example.
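To see where the -2 comes from, here is a minimal sketch (assuming TensorFlow 2.x with eager execution) that compares the cardinality of the dataset before and after filter(). The variable names before, filtered and after are illustrative only.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5])

# Before filtering, the number of elements is known statically.
before = tf.data.experimental.cardinality(dataset).numpy()

# filter() keeps an element only if the predicate holds at run time,
# so TensorFlow cannot know the resulting size without reading the data.
filtered = dataset.filter(lambda x: x < 4)
after = tf.data.experimental.cardinality(filtered).numpy()

print(before)  # 5
print(after)   # -2
# -2 is the value of the UNKNOWN_CARDINALITY constant.
print(after == tf.data.experimental.UNKNOWN_CARDINALITY)  # True
```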


Solution

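  • According to the tf.data.experimental.cardinality documentation, a cardinality of -2 (tf.data.experimental.UNKNOWN_CARDINALITY) means that TensorFlow is unable to determine the cardinality of the dataset. This happens after filter(), because the number of elements that pass the predicate is only known once the data is actually read. For your example, you can iterate over the dataset and count the elements: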

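    # Read every filtered element once and count them. This loads all elements
    # into memory, so it suits small datasets. The original `dataset` is kept
    # so it can still be split with take() and skip().
    size = len(list(dataset.as_numpy_iterator()))
    print(size)  # 3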
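For larger datasets, such as images with labels, a list of every element may not fit in memory. A minimal sketch, assuming a recent TensorFlow 2.x release where tf.data.experimental.assert_cardinality is available, counts with Dataset.reduce instead and then splits with take() and skip(). The 80/20 ratio and the names train_size, train_ds and val_ds are illustrative assumptions, not part of the question.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5])
dataset = dataset.filter(lambda x: x < 4)

# Count the filtered elements in a single streaming pass; nothing is
# collected into a Python list.
size = int(dataset.reduce(tf.constant(0, tf.int64), lambda count, _: count + 1).numpy())

# Optionally record the now-known size so cardinality() reports it.
dataset = dataset.apply(tf.data.experimental.assert_cardinality(size))

# Hypothetical 80/20 split into training and validation sets.
train_size = int(0.8 * size)
train_ds = dataset.take(train_size)
val_ds = dataset.skip(train_size)

print(size)  # 3
print(tf.data.experimental.cardinality(dataset).numpy())  # 3
print([int(x) for x in train_ds.as_numpy_iterator()])  # [1, 2]
print([int(x) for x in val_ds.as_numpy_iterator()])  # [3]
```

The filter runs again every time train_ds or val_ds is iterated. If the pipeline shuffles before the split, passing reshuffle_each_iteration=False to shuffle() keeps the two subsets from overlapping across epochs.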