Tags: python, tensorflow, tensorflow-datasets

How to use TensorFlow's Dataset.from_generator


I'm trying to fit a model on a dataset built with 'tf.data.Dataset.from_generator', but the fit fails.

Here is the code that builds the dataset:

import tensorflow as tf

cd_gen = CordicDatasetFT(14)  # custom generator object, defined elsewhere
cos = (tf.TensorSpec(shape=(14, 3), dtype=tf.float32, name=None),
       tf.TensorSpec(shape=(14, 3), dtype=tf.float32, name=None))
cds = tf.data.Dataset.from_generator(cd_gen, output_signature=cos)
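
For context, CordicDatasetFT is not shown in the question. from_generator only requires its first argument to be a callable that returns a generator whose elements match output_signature, so a minimal hypothetical stand-in (the class body and random payload below are placeholders, not the asker's actual code) could look like this:

import numpy as np
import tensorflow as tf

class CordicDatasetFT:
    # Hypothetical sketch: a callable that returns a generator of
    # (x, y) pairs matching the (14, 3) float32 output_signature.
    def __init__(self, n):
        self.n = n

    def __call__(self):
        while True:  # endless stream; model.fit will need steps_per_epoch
            x = np.random.rand(self.n, 3).astype(np.float32)
            y = np.random.rand(self.n, 3).astype(np.float32)
            yield x, y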

It seems to be ready to train my model:

print(type(cds))
cds_tst = cds.batch(512)

for batch_it in cds_tst:
    x, y = batch_it
    y_pre = model.predict(x)
    print(y_pre.shape)
    print("step")
    break
<class 'tensorflow.python.data.ops.flat_map_op._FlatMapDataset'>
[CordicDatasetFT]: call
16/16 [==============================] - 0s 7ms/step
(512, 14, 3)
step
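
For reference, the unbatched and batched datasets report different element specs (element_spec is a standard tf.data.Dataset attribute), which is worth checking before training:

print(cds.element_spec)      # (TensorSpec(shape=(14, 3), ...), TensorSpec(shape=(14, 3), ...))
print(cds_tst.element_spec)  # (TensorSpec(shape=(None, 14, 3), ...), TensorSpec(shape=(None, 14, 3), ...))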

But if I try to fit:

history=model.fit(cds, epochs=1)

I get this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[19], line 4
      1 #model.fit(ds, validation_data=ds, batch_size=512, epochs=20, steps_per_epoch=256, validation_steps=32)
      2 #history=model.fit(cds, batch_size=512, epochs=75, steps_per_epoch=256)
      3 print(type(cds))
----> 4 history=model.fit(cds, epochs=1)

File /opt/conda/lib/python3.10/site-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:926, in Function._call(self, *args, **kwds)
    923   self._lock.release()
    924   # In this case we have created variables on the first call, so we run the
    925   # defunned version which is guaranteed to never create variables.
--> 926   return self._no_variable_creation_fn(*args, **kwds)  # pylint: disable=not-callable
    927 elif self._variable_creation_fn is not None:
    928   # Release the lock early so that multiple threads can perform the call
    929   # in parallel.
    930   self._lock.release()

TypeError: 'NoneType' object is not callable

Where am I going wrong?


Solution

  • I found the solution: the dataset must be batched, because model.fit expects a tf.data.Dataset to already yield whole batches, and the unbatched cds does not:

    cds = tf.data.Dataset.from_generator(cd_gen, output_signature=cos)
    cds = cds.batch(512)  # batch the dataset before handing it to model.fit
    print(cds.element_spec)

    (TensorSpec(shape=(None, 14, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 14, 3), dtype=tf.float32, name=None))
    

    Then I needed to adjust the fit call, setting steps_per_epoch to give each epoch a fixed length:

    history = model.fit(cds, epochs=1, steps_per_epoch=256)

    256/256 [==============================] - 35s 135ms/step - loss: 5.0874 - mean_squared_error: 5.0874
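
    If the generator never terminates, steps_per_epoch is what bounds each epoch. A common further tweak (a sketch under that assumption, not part of the run measured above) is to chain prefetch onto the pipeline so the next batches are prepared while the model trains:

    cds = (tf.data.Dataset.from_generator(cd_gen, output_signature=cos)
           .batch(512)
           .prefetch(tf.data.AUTOTUNE))  # overlap data generation with training
    history = model.fit(cds, epochs=1, steps_per_epoch=256)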