Tags: tensorflow, keras, tensorflow-datasets

Why does the same Tensorflow model work with a list of arrays but doesn't work with tf.data.Dataset unbatched?


I have the following simple set up:

import tensorflow as tf

def make_mymodel():
  class MyModel(tf.keras.models.Model):

    def __init__(self):
      super(MyModel, self).__init__()

      self.model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(1, 2)),
        tf.keras.layers.Dense(1)
      ])

    def call(self, x):
      return self.model(x)

  mymodel = MyModel()

  return mymodel


model = make_mymodel()

X = [[[1, 1]],
     [[2, 2]],
     [[10, 10]],
     [[20, 20]],
     [[50, 50]]]
y = [1, 2, 10, 20, 50]

# ds_n_X = tf.data.Dataset.from_tensor_slices(X)
# ds_n_Y = tf.data.Dataset.from_tensor_slices(y)
# ds = tf.data.Dataset.zip((ds_n_X, ds_n_Y))
#
# for input, label in ds:
#   print(input.numpy(), label.numpy())

loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False)

model.build((1, 2))
model.compile(optimizer='adam',
              loss=loss_fn)
model.summary()

model.fit(X, y, epochs=10)

print(model.predict([
  [[25, 25]]
]))

This works fine (although I get strange predictions), but when I uncomment the ds lines and change model.fit(X, y, epochs=10) to model.fit(ds, epochs=10), I get the following error:

Traceback (most recent call last):
  File "example_dataset.py", line 51, in <module>
    model.fit(ds, epochs=10)
  ...
ValueError: slice index 0 of dimension 0 out of bounds. for '{{node strided_slice}} = StridedSlice[Index=DT_INT32, T=DT_INT32, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](Shape, strided_slice/stack, strided_slice/stack_1, strided_slice/stack_2)' with input shapes: [0], [1], [1], [1] and with computed input tensors: input[1] = <0>, input[2] = <1>, input[3] = <1>.

The error goes away when I run model.fit(ds.batch(2), epochs=10) (i.e. after adding a batch instruction to the dataset).

I expected to be able to use a list of arrays and a tf.data.Dataset interchangeably, but for some reason I need to add a batch dimension to the dataset before I can fit on it. Is this expected behavior, or am I conceptually missing something?


Solution

  • Because the model expects its input with a leading batch dimension, i.e. (batch_dim, *input_dims). For your data, each input fed to the model should have shape (None, 1, 2).

    Let's explore the shapes of your data, both as an array and as a dataset. When you define your input as an array, its shape is:

    >>> import numpy as np
    >>> print(np.array(X).shape)
    (5, 1, 2)
    

    This is compatible with what the model expects: when you pass a plain array, model.fit batches it for you along the first axis (32 samples per batch by default). But when you build a dataset with from_tensor_slices, each element is a single slice with the first axis removed:

    >>> for input, label in ds.take(1):
    ...     print(input.numpy().shape)
    (1, 2)
    
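    As a quicker check, the dataset's element_spec property reports the per-element shapes without iterating (the int32 dtypes below assume the integer literals from the question):

    >>> ds.element_spec
    (TensorSpec(shape=(1, 2), dtype=tf.int32, name=None),
     TensorSpec(shape=(), dtype=tf.int32, name=None))
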

    This is incompatible with what the model expects, because fit assumes a tf.data.Dataset already yields batches and does not batch it for you. If we batch the dataset:

    >>> ds = ds.batch(1)
    >>> for input, label in ds.take(1):
    ...     print(input.numpy().shape)
    (1, 1, 2)
    

    Then the dataset can be passed to model.fit() and training works as expected.
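
    Putting it together, a minimal fixed version of the pipeline might look like this (a sketch reusing make_mymodel from the question; from_tensor_slices accepts a (features, labels) tuple directly, which is equivalent to zipping the two datasets, and the batch size of 2 is arbitrary):

    import tensorflow as tf

    X = [[[1, 1]], [[2, 2]], [[10, 10]], [[20, 20]], [[50, 50]]]
    y = [1, 2, 10, 20, 50]

    # Slicing a (features, labels) tuple yields (x, y) pairs, the same
    # pairs the zip of two datasets produced in the question.
    ds = tf.data.Dataset.from_tensor_slices((X, y))

    # Restore the batch dimension the model expects: (batch, 1, 2).
    ds = ds.batch(2)

    model = make_mymodel()  # model definition from the question
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.BinaryCrossentropy(from_logits=False))
    model.fit(ds, epochs=10)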