
TensorFlow Datasets: how to get the shape of a dataset's elements?


Consider loading the following dataset from TensorFlow Datasets:

import tensorflow_datasets as tfds

(ds_train, ds_test), ds_info = tfds.load('mnist', split=['train', 'test'],
                                         shuffle_files=True,
                                         as_supervised=True, with_info=True)

However, the tf.data.Dataset.from_generator documentation says:

#https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator
#Warning: SOME ARGUMENTS ARE DEPRECATED: (output_shapes, output_types). They will be removed in a future version. 
#Instructions for updating: Use output_signature instead

but none of the following work:

ds_train.output_shapes
ds_train.output_types
ds_train.output_signature
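Note that output_signature is an argument of tf.data.Dataset.from_generator, not an attribute of the dataset it returns, so there is no ds.output_signature to read back. The sketch below uses a hypothetical MNIST-like generator filled with random data (not the real dataset) to show the signature being declared up front; the resulting dataset then reports it through element_spec.

```python
import numpy as np
import tensorflow as tf


def mnist_like_generator(num_examples=4):
    """Hypothetical generator yielding random MNIST-shaped (image, label) pairs."""
    rng = np.random.default_rng(0)
    for _ in range(num_examples):
        image = rng.integers(0, 256, size=(28, 28, 1), dtype=np.uint8)
        label = np.int64(rng.integers(0, 10))
        yield image, label


# output_signature replaces the deprecated output_types/output_shapes pair.
signature = (
    tf.TensorSpec(shape=(28, 28, 1), dtype=tf.uint8),
    tf.TensorSpec(shape=(), dtype=tf.int64),
)
ds = tf.data.Dataset.from_generator(mnist_like_generator, output_signature=signature)

# The declared signature is reported back through element_spec.
print(ds.element_spec)

# Pull one element to check that it matches the declared signature.
image, label = next(iter(ds))
print(image.shape, label.dtype)
```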

A similar issue was reported in https://github.com/tensorflow/datasets/issues/102, so for now the only thing that works is the temporary workaround

shape_of_data = tf.compat.v1.data.get_output_shapes(ds_train)

which returns

(TensorShape([None, 28, 28, 1]), TensorShape([None]))
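To see how this compat workaround relates to the TF2 API, the sketch below builds a small stand-in dataset with the same element structure as unbatched MNIST (zeros instead of real images, an assumption made to avoid a download) and compares tf.compat.v1.data.get_output_shapes / get_output_types with the shapes and dtypes read from element_spec.

```python
import tensorflow as tf

# Stand-in for the unbatched MNIST split: zeros with the same shapes and dtypes.
images = tf.zeros([8, 28, 28, 1], dtype=tf.uint8)
labels = tf.zeros([8], dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

# TF1-style helpers, still reachable through tf.compat.v1.
old_shapes = tf.compat.v1.data.get_output_shapes(ds)
old_types = tf.compat.v1.data.get_output_types(ds)

# TF2 equivalents: map over the (image, label) structure of element_spec.
new_shapes = tf.nest.map_structure(lambda spec: spec.shape, ds.element_spec)
new_types = tf.nest.map_structure(lambda spec: spec.dtype, ds.element_spec)

print(new_shapes)
print(new_types)
```

The tf.nest.map_structure form keeps working for any nesting (tuples, dicts), so it is a drop-in replacement for the compat helpers.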

The newer tf.data.DatasetSpec also runs, but I could not get a TensorShape out of its result:

tf.data.DatasetSpec(ds_train)

returns

DatasetSpec(<_OptionsDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>, TensorShape([]))

and I could not extract the shapes from it into a variable.
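The DatasetSpec repr above wraps the whole dataset because the first argument of tf.data.DatasetSpec is meant to be an element spec, not a dataset. A minimal sketch, assuming a TF 2.x release where the DatasetSpec.from_value static method and the DatasetSpec.element_spec property are available, again on a zero-filled stand-in dataset:

```python
import tensorflow as tf

# Zero-filled stand-in with MNIST's per-element structure.
images = tf.zeros([8, 28, 28, 1], dtype=tf.uint8)
labels = tf.zeros([8], dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Build the spec from the dataset instead of passing the dataset as element_spec
# (this should match tf.data.DatasetSpec(ds.element_spec)).
spec = tf.data.DatasetSpec.from_value(ds)

# The per-element structure lives on the spec's element_spec property.
image_spec, label_spec = spec.element_spec
image_shape = image_spec.shape  # a TensorShape that can be stored and indexed
print(image_shape, label_spec.shape)
```

In practice this is a detour: spec.element_spec is the same structure that ds.element_spec already exposes directly.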

What is the current, non-deprecated function or attribute for getting the shape of the elements a dataset yields?


Solution

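  • Use Dataset.element_spec. It returns the structure of each element as tf.TypeSpec objects (here a tuple of two tf.TensorSpec, one for the image and one for the label), and each TensorSpec exposes .shape and .dtype: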

    import tensorflow_datasets as tfds
    
    (ds_train, ds_test), ds_info = tfds.load(
        "mnist",
        split=["train", "test"],
        shuffle_files=True,
        as_supervised=True,
        with_info=True,
    )
    
    ds_train.element_spec
    # (TensorSpec(shape=(28, 28, 1), dtype=tf.uint8, name=None),
    #  TensorSpec(shape=(), dtype=tf.int64, name=None))
    
    ds_train.element_spec[0].shape
    # TensorShape([28, 28, 1])
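The shapes above are per element. A leading None dimension, like the one in the get_output_shapes output in the question, is what batch() adds when the final batch may be smaller than the rest. The sketch below (zero-filled stand-in data, batch size 4 chosen arbitrarily) shows how element_spec reflects that, and how drop_remainder=True makes the batch dimension static.

```python
import tensorflow as tf

# Ten zero-filled MNIST-shaped examples as a stand-in for ds_train.
images = tf.zeros([10, 28, 28, 1], dtype=tf.uint8)
labels = tf.zeros([10], dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

# batch() prepends a batch dimension; it is None because the last batch may be short.
batched = ds.batch(4)
# With drop_remainder=True every batch has exactly 4 examples, so the size is static.
batched_static = ds.batch(4, drop_remainder=True)

dynamic_shape = batched.element_spec[0].shape
static_shape = batched_static.element_spec[0].shape
print(dynamic_shape, static_shape)
```

A static batch size is mainly useful when downstream code (for example, some TPU or reshape-heavy models) needs fully defined shapes; otherwise the None dimension is harmless.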