I have a TF2 `tf.data` dataset that goes through multiple `map` operations followed by `tf.image.resize`, which always outputs images of shape (300, 300); i.e. every record is guaranteed to have this shape after all the map operations. However, this shape is not inferred automatically, so the `TensorSpec` shows the shape as `<undefined>, <undefined>`. A dataset with undefined shapes throws an error when it is passed to a model with a predefined input shape.
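To make the symptom concrete, here is a minimal sketch of one way a dataset can end up with an undefined image shape. It assumes (hypothetically) that one of the map operations runs Python code through `tf.py_function`, whose output shape TensorFlow cannot infer; the random input images and the `numpy_resize` / `opaque_map` helpers are made up for illustration:

```python
import numpy as np
import tensorflow as tf

# Hypothetical input: four random 400x500 RGB images with integer labels.
images = np.random.rand(4, 400, 500, 3).astype(np.float32)
labels = np.arange(4, dtype=np.int32)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

def numpy_resize(image):
    # Runs eagerly inside tf.py_function, so tf.data cannot see its output shape.
    return tf.image.resize(image, [300, 300]).numpy()

def opaque_map(image, label):
    image = tf.py_function(numpy_resize, [image], tf.float32)
    return image, label

dataset = dataset.map(opaque_map)

# The image spec has an unknown shape, although every element is 300x300x3 at runtime.
print(dataset.element_spec)
```

The label keeps its scalar shape because it never passes through `tf.py_function`; only the image loses its static shape.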
Some searching led me to `tf.contrib.data.assert_element_shape` and Issue #16052:

```python
dataset = dataset.apply(tf.data.experimental.assert_element_shape(custom_shape))
```
However, this function has been removed in TF2, and the docs do not suggest anything to use in place of `assert_element_shape`. What is its equivalent? Or, more generally, how do I assign a shape to a dataset that is guaranteed to output a certain shape?
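One possible TF2 substitute, offered as an assumption rather than an official replacement, is `tf.ensure_shape`: it sets the static shape (so `element_spec` becomes defined) and also checks every element at runtime. Wrapping it in a function that takes and returns a dataset lets it be passed to `dataset.apply` much like `assert_element_shape`. The helper name `ensure_element_shape`, the `(image, label)` element structure, and the toy dataset are hypothetical:

```python
import tensorflow as tf

def ensure_element_shape(image_shape, label_shape):
    """Hypothetical stand-in for assert_element_shape on (image, label) datasets."""
    def _transform(dataset):
        # tf.ensure_shape records the shape statically (so element_spec is defined)
        # and raises InvalidArgumentError at runtime if an element does not match.
        return dataset.map(
            lambda image, label: (
                tf.ensure_shape(image, image_shape),
                tf.ensure_shape(label, label_shape),
            )
        )
    return _transform

# Hypothetical dataset whose image shape is hidden by tf.py_function.
dataset = tf.data.Dataset.range(2).map(
    lambda i: (tf.py_function(lambda _: tf.zeros([300, 300, 3]), [i], tf.float32), i)
)
dataset = dataset.apply(ensure_element_shape((300, 300, 3), ()))
print(dataset.element_spec)
```

Unlike `set_shape`, which only updates the static shape, a mismatched element here makes iteration fail with an error rather than passing through unchecked.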
For some reason, calling `set_shape` inside the same map function that applies `tf.image.resize` does NOT work:
```python
import tensorflow as tf

# does not work
def my_map_function(image, label):
    # some image operations here
    image = tf.image.resize(image, size=[300, 300])
    image.set_shape((300, 300, 3))
    return image, label
```
However, when I set the shapes in a separate map function, it works:
```python
# works
def set_shapes(image, label):
    # set_shape only updates the static shape metadata; no runtime check is done
    image.set_shape((300, 300, 3))
    label.set_shape([])
    return image, label
```
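To check that the separately mapped shapes are enough for a model with a predefined input shape, here is a small end-to-end sketch. The `opaque_resize` step (which uses `tf.py_function` to make the image shape unknown), the zero-filled inputs, and the toy Keras model are all hypothetical; only `set_shapes` comes from the question. Note that this version also fixes the label's shape, which `my_map_function` above does not do:

```python
import tensorflow as tf

def set_shapes(image, label):
    # Restore the static shapes that earlier map steps could not infer.
    image.set_shape((300, 300, 3))
    label.set_shape([])
    return image, label

def opaque_resize(image, label):
    # Hypothetical shape-losing step: tf.py_function output has an unknown shape.
    image = tf.py_function(lambda x: tf.image.resize(x, [300, 300]), [image], tf.float32)
    return image, label

dataset = (
    tf.data.Dataset.from_tensor_slices((tf.zeros([4, 64, 64, 3]), tf.zeros([4])))
    .map(opaque_resize)
    .map(set_shapes)  # separate map step, as in the working version above
    .batch(2)
)

# Toy model with a fixed input shape (hypothetical architecture).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(300, 300, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1),
])
predictions = model.predict(dataset, verbose=0)
print(dataset.element_spec)
print(predictions.shape)
```

`model.predict` uses only the image part of each `(image, label)` batch, so the labels are ignored here.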
I'll probably stick with this until a direct `assert_element_shape` or `set_element_shape` equivalent is added as a standalone function.