Tags: python, tensorflow, tensorflow-datasets

tf.data.Dataset: How do I assign a shape to a dataset (with an undefined shape) that is guaranteed to output a certain shape?


I have a TF2 tf.data.Dataset that goes through several map operations, the last of which is tf.image.resize, so every image always comes out with shape (300, 300). In other words, each record is guaranteed to have this shape after all the map operations. However, this shape is not inferred, so the TensorSpec shows <undefined>, <undefined> for those dimensions. A dataset with undefined shapes raises an error when it is passed to a model with a predefined input shape.
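The post does not say which map step drops the shape information. A minimal sketch of one common cause, assuming a preprocessing step wrapped in tf.py_function (whose outputs have no static shape), shows how to spot the problem by inspecting `dataset.element_spec`. The helper names `python_side_resize` and `py_map` are hypothetical placeholders.

```python
import numpy as np
import tensorflow as tf

def python_side_resize(image):
    # Hypothetical placeholder for Python-side preprocessing: it always
    # returns a 300x300x3 float32 array, but TensorFlow cannot see that.
    return np.zeros((300, 300, 3), dtype=np.float32)

def py_map(image, label):
    # tf.py_function cannot infer the shape of what the Python code returns,
    # so the static shape of `image` becomes unknown after this step.
    image = tf.py_function(python_side_resize, inp=[image], Tout=tf.float32)
    return image, label

images = tf.zeros((4, 480, 640, 3))
labels = tf.range(4)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).map(py_map)

# The image component's TensorSpec reports an unknown shape, even though
# every actual element is 300x300x3.
print(dataset.element_spec)
```

Checking `element_spec` after each `map` call is a quick way to find the step where the static shape is lost.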

Some searching turned up the function tf.contrib.data.assert_element_shape and Issue #16052:

dataset = dataset.apply(tf.data.experimental.assert_element_shape(custom_shape))

But this function has been removed in TF2, and the docs do NOT recommend anything to use in place of assert_element_shape. What is its equivalent? Or, how do I assign a shape to a dataset that is guaranteed to output a certain shape?
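One possible TF2 stand-in, not taken from the post, is a map step that calls `tf.ensure_shape` on each component. Unlike `set_shape`, `tf.ensure_shape` also checks the actual shape at runtime. The sketch below wraps it in a hypothetical helper, `ensure_element_shape`, so it can be used with `dataset.apply(...)` in the same way as the old function; it assumes each element is a tuple of tensors and that one expected shape is given per component.

```python
import tensorflow as tf

def ensure_element_shape(expected_shapes):
    """Hypothetical helper mimicking the removed assert_element_shape.

    Assumes dataset elements are tuples and `expected_shapes` holds one
    shape per tuple component.
    """
    def _apply(dataset):
        def _check(*components):
            # tf.ensure_shape sets the static shape and adds a runtime check
            # that fails if an element's real shape disagrees.
            return tuple(
                tf.ensure_shape(component, shape)
                for component, shape in zip(components, expected_shapes)
            )
        return dataset.map(_check)
    return _apply

def py_identity(image, label):
    # Stand-in for a step that loses static shape information.
    image = tf.py_function(lambda x: x, inp=[image], Tout=tf.float32)
    return image, label

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.zeros((4, 300, 300, 3)), tf.range(4))
).map(py_identity)

dataset = dataset.apply(ensure_element_shape(((300, 300, 3), ())))
print(dataset.element_spec)
```

If an element does not match the declared shape, iterating the dataset raises `tf.errors.InvalidArgumentError` instead of silently passing a wrong shape to the model.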


Solution

  • For some reason, calling set_shape inside the same map function that calls tf.image.resize does NOT work:

    # does not work
    def my_map_function(image, label):
        # some image operations here
        image = tf.image.resize(image, size=[300, 300])
        image.set_shape((300, 300, 3))
        return image, label
    

    But when I moved set_shape into a separate map function, it worked:

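    import tensorflow as tf

    # works
    def set_shapes(image, label):
        # Assign the static shapes the pipeline already guarantees
        image.set_shape((300, 300, 3))
        label.set_shape([])  # labels are scalars
        return image, label

    dataset = dataset.map(set_shapes)  # applied after the other map calls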
    

    I'll probably stick with this until a direct equivalent, such as assert_element_shape or a set_element_shape transformation, is added as a separate function.
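A minimal end-to-end sketch of the working approach: the `set_shapes` map from the post is applied after a step that loses the shape (here a tf.py_function stand-in, an assumption), and the batched dataset is then fed to a small Keras model with a fixed input shape. The model itself is a hypothetical placeholder used only to show that the restored shape is accepted.

```python
import tensorflow as tf

def py_preprocess(image, label):
    # Assumed stand-in for a preprocessing step whose output has no
    # static shape (tf.py_function outputs are shape-less).
    image = tf.py_function(lambda x: x, inp=[image], Tout=tf.float32)
    return image, label

def set_shapes(image, label):
    # The separate map function from the post.
    image.set_shape((300, 300, 3))
    label.set_shape([])
    return image, label

dataset = (
    tf.data.Dataset.from_tensor_slices((tf.zeros((4, 300, 300, 3)), tf.range(4)))
    .map(py_preprocess)
    .map(set_shapes)
    .batch(2)
)
# Only the batch dimension is left undefined after batching.
print(dataset.element_spec)

# Hypothetical tiny model with a predefined input shape.
inputs = tf.keras.Input(shape=(300, 300, 3))
x = tf.keras.layers.GlobalAveragePooling2D()(inputs)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)

# Feed only the images to predict().
predictions = model.predict(dataset.map(lambda image, label: image), verbose=0)
```

Placing `set_shapes` as the last map before `batch` keeps the shape assignment in one spot, so later refactoring of the earlier map functions cannot silently undo it.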