I have loaded a dataset from tfds where the data is in a NestedVariant object. How do I extract the values from there?
E.g.:
import tensorflow_datasets as tfds
ds = tfds.load("rlu_control_suite", split="train")
for example in ds.take(1):
    print(example["steps"])
Output:
<tensorflow.python.data.ops.dataset_ops._NestedVariant object at 0x2c3324940>
I can't find the docs on how to work with this object.
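The "steps" field of each episode is itself a nested tf.data.Dataset (that is what the _NestedVariant object represents), so use tf.data.Dataset.flat_map to flatten the dataset of episodes into a single dataset of steps: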
import tensorflow_datasets as tfds
ds = tfds.load("rlu_control_suite", split="train")
ds = ds.flat_map(lambda x: x['steps'])
for example in ds.take(1):
    print(example)
Output:
{'action': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.6250813], dtype=float32)>,
 'discount': <tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
 'is_first': <tf.Tensor: shape=(), dtype=bool, numpy=True>,
 'is_last': <tf.Tensor: shape=(), dtype=bool, numpy=False>,
 'is_terminal': <tf.Tensor: shape=(), dtype=bool, numpy=False>,
 'observation': {'position': <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.00654944, -0.99999475, 0.00324196], dtype=float32)>,
                 'velocity': <tf.Tensor: shape=(2,), dtype=float32, numpy=array([-0.0186821, 0.00950337], dtype=float32)>},
 'reward': <tf.Tensor: shape=(), dtype=float32, numpy=1.6778985e-06>}
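If you want to see what flat_map does without downloading rlu_control_suite, here is a minimal sketch, assuming TensorFlow 2.x. The data is hypothetical: Dataset.window is used only to build nested per-episode datasets that stand in for what tfds returns under "steps". The last part shows one way to keep episode boundaries instead of flattening: reduce each nested dataset inside map.

```python
import tensorflow as tf

# Hypothetical toy data standing in for rlu_control_suite: six per-step
# rewards split into two "episodes" of three steps each.
rewards = tf.data.Dataset.from_tensor_slices([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])

# window(3) yields one nested tf.data.Dataset per episode; zip wraps each
# one in a dict so the structure mirrors {"steps": <nested dataset>}.
episodes = tf.data.Dataset.zip({"steps": rewards.window(3)})

# flat_map calls the lambda on every episode and concatenates the nested
# datasets it returns into one flat dataset of steps.
flat_steps = episodes.flat_map(lambda episode: episode["steps"])
step_values = list(flat_steps.as_numpy_iterator())

# To keep episode boundaries instead, reduce each nested dataset inside map;
# here the rewards of every episode are summed into one scalar per episode.
episode_returns = episodes.map(
    lambda episode: episode["steps"].reduce(0.0, lambda total, r: total + r)
)
returns = list(episode_returns.as_numpy_iterator())

print(step_values)
print(returns)
```

With the real dataset each step is a dict rather than a scalar, so the reducer would presumably need to read step["reward"], e.g. ds.map(lambda ep: ep["steps"].reduce(0.0, lambda total, step: total + step["reward"])).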