Tags: tensorflow, machine-learning, deep-learning, tensorflow-datasets

How to acquire tf.data.Dataset's shape?


I know a dataset has output_shapes, but it only shows something like this:

data_set: DatasetV1Adapter shapes: {item_id_hist: (?, ?), tags: (?, ?), client_platform: (?,), entrance: (?,), item_id: (?,), lable: (?,), mode: (?,), time: (?,), user_id: (?,)}, types: {item_id_hist: tf.int64, tags: tf.int64, client_platform: tf.string, entrance: tf.string, item_id: tf.int64, lable: tf.int64, mode: tf.int64, time: tf.int64, user_id: tf.int64}

How can I get the total number of records in my dataset?


Solution

  • Where the length is known you can call:

    tf.data.experimental.cardinality(dataset)
    

    If this cannot determine the length (it returns tf.data.experimental.UNKNOWN_CARDINALITY rather than a count), keep in mind that a TensorFlow Dataset is, in general, lazily evaluated, so you may have to iterate over every record before you can find the length of the dataset.

    For example, assuming you have eager execution enabled and it's a small 'toy' dataset that fits comfortably in memory, you could just enumerate it into a new list and grab the last index (then add 1 because lists are zero-indexed):

    dataset_length = [i for i,_ in enumerate(dataset)][-1] + 1
    

    Of course this is inefficient at best and, for very large datasets, can fail entirely because the list holds one entry per record and has to fit in memory. In such circumstances I can't see any alternative other than to iterate through the records while keeping a manual count.
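
The two approaches above can be combined. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution enabled; the helper name `dataset_length` is hypothetical and not part of TensorFlow. It asks for the cardinality first and only falls back to a manual count when the cardinality is unknown.

```python
import tensorflow as tf


def dataset_length(dataset):
    """Return the number of elements in a finite tf.data.Dataset.

    Hypothetical helper (not a TensorFlow API); assumes eager execution.
    """
    n = int(tf.data.experimental.cardinality(dataset).numpy())

    # An infinite dataset (e.g. after .repeat()) can never be counted.
    if n == tf.data.experimental.INFINITE_CARDINALITY:
        raise ValueError("dataset is infinite; it has no length")

    # The length is known statically, so no iteration is needed.
    if n != tf.data.experimental.UNKNOWN_CARDINALITY:
        return n

    # Fallback: one full pass over the data with a manual counter.
    # Only the counter is kept in memory, not the records.
    count = 0
    for _ in dataset:
        count += 1
    return count
```

The fallback is typically needed after transformations such as `filter`, whose output size cannot be known without reading the data. Note that it costs a full pass over the dataset, so for expensive input pipelines it may be worth computing the count once and storing it.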