How to slice an existing tf.data dataset into elements of a new dataset


I have a dataset of images and corresponding segmentation masks, and I want to build a new tf.data dataset from an existing one whose elements each consist of two four-dimensional arrays. I create a tf.data.Dataset from the image and mask paths and apply some preprocessing functions to it. After this step, each image and each mask has shape [x, h, w, c], so dataset.as_numpy_iterator() yields two arrays of that shape per element.

Now I want a dataset whose elements are two arrays of shape [h, w, c] each, where every slice along the first dimension becomes a separate element. So if my dataset initially had 10 elements, it should now have 10 * x elements. However, I am not able to slice the elements out of the existing dataset. This is what I have tried:

dataset = tf.data.Dataset.from_tensor_slices((imagepath, maskpath))
dataset = dataset.map(lambda imagepath, maskpath: tf.py_function(preprocessData, 
                                                inp=[imagepath, maskpath], 
                                                Tout=[tf.float64]*2))
datasetnew = tf.data.Dataset.from_tensor_slices(dataset)

The error I get is:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_34/343231121.py in <module>
      3                                                 inp=[flairimg_val, msk_val],
      4                                                 Tout=[tf.float64]*2))
----> 5 datasetnew = tf.data.Dataset.from_tensor_slices(datasetval)
      6 # datasetval = datasetval.map(lambda flairimg_val, msk_val, path: get_2p5D_repre(flairimg_val, msk_val, path))
      7 # datasetval = datasetval.map(lambda flairimg_val, msk_val, path: try_return(flairimg_val, msk_val, path))

/usr/local/lib/python3.8/dist-packages/tensorflow/python/data/ops/dataset_ops.py in from_tensor_slices(tensors)
    758       Dataset: A `Dataset`.
    759     """
--> 760     return TensorSliceDataset(tensors)
    761 
    762   class _GeneratorState(object):

/usr/local/lib/python3.8/dist-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, element)
   3320     element = structure.normalize_element(element)
   3321     batched_spec = structure.type_spec_from_value(element)
-> 3322     self._tensors = structure.to_batched_tensor_list(batched_spec, element)
   3323     self._structure = nest.map_structure(
   3324         lambda component_spec: component_spec._unbatch(), batched_spec)  # pylint: disable=protected-access

/usr/local/lib/python3.8/dist-packages/tensorflow/python/data/util/structure.py in to_batched_tensor_list(element_spec, element)
    362   # pylint: disable=protected-access
    363   # pylint: disable=g-long-lambda
--> 364   return _to_tensor_list_helper(
    365       lambda state, spec, component: state + spec._to_batched_tensor_list(
    366           component), element_spec, element)

/usr/local/lib/python3.8/dist-packages/tensorflow/python/data/util/structure.py in _to_tensor_list_helper(encode_fn, element_spec, element)
    337     return encode_fn(state, spec, component)
    338 
--> 339   return functools.reduce(
    340       reduce_fn, zip(nest.flatten(element_spec), nest.flatten(element)), [])
    341 

/usr/local/lib/python3.8/dist-packages/tensorflow/python/data/util/structure.py in reduce_fn(state, value)
    335   def reduce_fn(state, value):
    336     spec, component = value
--> 337     return encode_fn(state, spec, component)
    338 
    339   return functools.reduce(

/usr/local/lib/python3.8/dist-packages/tensorflow/python/data/util/structure.py in <lambda>(state, spec, component)
    363   # pylint: disable=g-long-lambda
    364   return _to_tensor_list_helper(
--> 365       lambda state, spec, component: state + spec._to_batched_tensor_list(
    366           component), element_spec, element)
    367 

/usr/local/lib/python3.8/dist-packages/tensorflow/python/data/ops/dataset_ops.py in _to_batched_tensor_list(self, value)
   3492   def _to_batched_tensor_list(self, value):
   3493     if self._dataset_shape.ndims == 0:
-> 3494       raise ValueError("Unbatching a dataset is only supported for rank >= 1")
   3495     return self._to_tensor_list(value)
   3496 

ValueError: Unbatching a dataset is only supported for rank >= 1

I am not sure what "rank" means here for a dataset. How can I achieve this?


Solution

  • You're looking for the unbatch method, tf.data.Dataset.unbatch(). From its documentation:

    Splits elements of a dataset into multiple elements.

    For example, if elements of the dataset are shaped [B, a0, a1, ...], where B may vary for each input element, then for each element in the dataset, the unbatched dataset will contain B consecutive elements of shape [a0, a1, ...].
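
As a minimal sketch of how this applies to the question, the snippet below builds a small stand-in dataset and unbatches it. The array sizes (3 elements, x = 4, h = w = 8, c = 1) and the zero/one filled arrays are assumptions made for illustration only; they replace the asker's preprocessing step, which is not shown in full.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the preprocessed dataset: 3 elements, each an
# (image, mask) pair of shape [x, h, w, c] with x = 4, h = w = 8, c = 1.
images = np.zeros((3, 4, 8, 8, 1), dtype=np.float64)
masks = np.ones((3, 4, 8, 8, 1), dtype=np.float64)
dataset = tf.data.Dataset.from_tensor_slices((images, masks))

# unbatch() splits every element along its first dimension, so each
# [x, h, w, c] pair becomes x consecutive [h, w, c] pairs.
unbatched = dataset.unbatch()

num_elements = sum(1 for _ in unbatched)
first_image, first_mask = next(iter(unbatched))

print(num_elements)       # 12 (3 elements * x = 4 slices each)
print(first_image.shape)  # (8, 8, 1)
print(first_mask.shape)   # (8, 8, 1)
```

In the question's pipeline the same call would presumably go right after the map step, i.e. dataset = dataset.unbatch(), instead of passing the dataset to tf.data.Dataset.from_tensor_slices, which expects tensors rather than a Dataset object.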