tf.data.Dataset.window returns a new dataset whose elements are themselves datasets, and the elements of those nested datasets are windows of the desired size. If you have a dataset (say, Dataset.range(10)) and want a dataset of windows like [0 1 2] [1 2 3] ... [7 8 9], there's a trick to do that with window plus flat_map:
>>> d = tf.data.Dataset.range(10).window(3, shift=1, drop_remainder=True).flat_map(lambda x: x.batch(3))
>>> print(list(d))
[<tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 1, 2])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 3, 4])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([3, 4, 5])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([4, 5, 6])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([5, 6, 7])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([6, 7, 8])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([7, 8, 9])>]
However, the flat_map causes the dataset to lose its cardinality information:
>>> d.cardinality()
<tf.Tensor: shape=(), dtype=int64, numpy=-2>
(-2 is UNKNOWN_CARDINALITY; see Tensorflow 2.0: flat_map() to flatten Dataset of Dataset returns cardinality -2)
I would like to create a dataset of such windows while retaining the cardinality information. One slight annoyance of working with datasets of unknown cardinality is that Keras training progress bars need to run for one epoch before they can produce an ETA. I tried .take(n_windows), where I calculate n_windows myself, but that still returned a dataset with UNKNOWN_CARDINALITY.
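A minimal reproduction of the take() attempt (TF 2.x), showing that take() keeps the unknown cardinality rather than restoring it:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10).window(3, shift=1, drop_remainder=True)
ds = ds.flat_map(lambda w: w.batch(3))

# flat_map discards the statically known cardinality
print(ds.cardinality().numpy())  # -2, i.e. tf.data.UNKNOWN_CARDINALITY

# take() computes its cardinality from the input cardinality
# (the minimum of its count and the input), so unknown stays unknown
ds = ds.take(8)
print(ds.cardinality().numpy())  # still -2
```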
Is there some way to window a dataset without losing cardinality information?
The main issue is that cardinality is computed statically, so the cardinality of a flat_map operation cannot be computed. You can refer to this issue.
The solution, since you know the relation between the flat_map inputs and outputs, is to set the cardinality yourself using tf.data.experimental.assert_cardinality.
Here is an example of how to restore the window cardinality:
import tensorflow as tf
ds = tf.data.Dataset.range(10)
print("Original cardinality -> ", ds.cardinality().numpy())
# Output:
# Original cardinality -> 10
ds = ds.window(3, shift=1, drop_remainder=True)
# cardinality at this point is still known.
# as drop_remainder is true, window cardinality will be <= original cardinality
window_cardinality = ds.cardinality()
print("window cardinality ->", window_cardinality.numpy())
# Output:
# window cardinality -> 8
ds = ds.flat_map(lambda x: x.batch(3))
# after flat_map the inferred cardinality is lost.
print("flat cardinality ->", ds.cardinality().numpy())
# Output:
# flat cardinality -> -2
# as the flat_map relation is 1:1, we can set the cardinality back to the window cardinality.
ds = ds.apply(tf.data.experimental.assert_cardinality(window_cardinality))
print("dataset cardinality ->", ds.cardinality().numpy())
print("length of dataset ->", len(list(ds)))
# Output:
# dataset cardinality -> 8
# length of dataset -> 8
for idx, x in ds.enumerate():
    print(f"{idx} -> {x}")
# Output:
# 0 -> [0 1 2]
# 1 -> [1 2 3]
# 2 -> [2 3 4]
# 3 -> [3 4 5]
# 4 -> [4 5 6]
# 5 -> [5 6 7]
# 6 -> [6 7 8]
# 7 -> [7 8 9]
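Since the window size and shift fully determine the 1:1 relation, the steps above can be wrapped in a small helper (a sketch; windowed_dataset is a name invented here, not a TensorFlow API):

```python
import tensorflow as tf

def windowed_dataset(ds, size, shift=1):
    # Sliding windows of `size` elements, with cardinality preserved.
    windows = ds.window(size, shift=shift, drop_remainder=True)
    n = windows.cardinality()  # still statically known at this point
    flat = windows.flat_map(lambda w: w.batch(size))
    # restore the cardinality that flat_map discarded (1:1 relation)
    return flat.apply(tf.data.experimental.assert_cardinality(n))

ds = windowed_dataset(tf.data.Dataset.range(10), 3)
print(len(ds))  # 8 -- len() now works, and Keras can show an ETA from epoch 1
```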