python, tensorflow, tensorflow-datasets

How to work with a NestedVariant object in TensorFlow


I have loaded a dataset with tfds in which the data sits inside a NestedVariant object. How do I extract the values from it?

E.g.:

import tensorflow_datasets as tfds

ds = tfds.load("rlu_control_suite", split="train")
for example in ds.take(1):
    print(example["steps"])

Output:

<tensorflow.python.data.ops.dataset_ops._NestedVariant object at 0x2c3324940>

I can't find the docs on how to work with this object.


Solution

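  • Each example's 'steps' field is itself a nested tf.data.Dataset (one dataset of steps per episode), which is why printing it shows a _NestedVariant object. Use tf.data.Dataset.flat_map to flatten this dataset of datasets into a single dataset of steps: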

    import tensorflow_datasets as tfds
    
    ds = tfds.load("rlu_control_suite", split="train") 
    ds = ds.flat_map(lambda x: x['steps'])
    for example in ds.take(1):
        print(example)
    
    Output (line breaks added for readability):

    {
        'action': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.6250813], dtype=float32)>,
        'discount': <tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
        'is_first': <tf.Tensor: shape=(), dtype=bool, numpy=True>,
        'is_last': <tf.Tensor: shape=(), dtype=bool, numpy=False>,
        'is_terminal': <tf.Tensor: shape=(), dtype=bool, numpy=False>,
        'observation': {
            'position': <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.00654944, -0.99999475, 0.00324196], dtype=float32)>,
            'velocity': <tf.Tensor: shape=(2,), dtype=float32, numpy=array([-0.0186821, 0.00950337], dtype=float32)>
        },
        'reward': <tf.Tensor: shape=(), dtype=float32, numpy=1.6778985e-06>
    }
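
If you want to try the flat_map pattern without downloading rlu_control_suite, here is a minimal sketch on toy data. The names make_episode, STEPS_PER_EPISODE and steps_with_id, and the fields episode_id, reward and is_last, are made up for illustration and are not taken from the real dataset's schema. The second part shows one possible way to keep an episode-level field, which plain flat_map would otherwise drop, by attaching it to every step before flattening.

```python
import tensorflow as tf

STEPS_PER_EPISODE = 3  # hypothetical episode length for this toy data


def make_episode(episode_id):
    """Build one toy episode whose 'steps' field is itself a tf.data.Dataset."""
    offset = tf.cast(episode_id, tf.float32) * 10.0
    steps = tf.data.Dataset.from_tensor_slices({
        "reward": offset + tf.range(STEPS_PER_EPISODE, dtype=tf.float32),
        "is_last": tf.equal(tf.range(STEPS_PER_EPISODE), STEPS_PER_EPISODE - 1),
    })
    return {"episode_id": episode_id, "steps": steps}


# A dataset of episodes, each holding a nested dataset of steps
# (a stand-in for what tfds.load returns in the question).
episodes = tf.data.Dataset.range(2).map(make_episode)

# 1) Plain flattening: every step of every episode, in order.
flat_steps = episodes.flat_map(lambda ep: ep["steps"])
rewards = [float(step["reward"]) for step in flat_steps.as_numpy_iterator()]


# 2) Flattening while carrying an episode-level field along with each step.
def steps_with_id(ep):
    return ep["steps"].map(
        lambda step: {**step, "episode_id": ep["episode_id"]}
    )


tagged_steps = episodes.flat_map(steps_with_id)
episode_ids = [int(step["episode_id"]) for step in tagged_steps.as_numpy_iterator()]

print(rewards)
print(episode_ids)
```

The nested dataset is only touched inside map and flat_map here, where it behaves as a regular tf.data.Dataset, so the sketch does not depend on how a particular TensorFlow version represents it during eager iteration.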