
How to get batch_size in call() function in TF2?


I'm trying to get the batch size inside the call() function of a TF2 model. However, I can't, because every method I know returns either None or a Tensor instead of a concrete number.

Here is a short example:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
    
    def call(self, x):
        print(len(x))
        print(x.shape)
        print(tf.size(x))
        print(np.shape(x))
        print(x.get_shape())
        print(x.get_shape().as_list())
        print(tf.rank(x))
        print(tf.shape(x))
        print(tf.shape(x)[0])
        print(tf.shape(x)[1])        
        return tf.random.uniform((2, 10))


m = MyModel()
m.compile(optimizer="Adam", loss="sparse_categorical_crossentropy", metrics=['accuracy'])
m.fit(np.array([[1,2,3,4], [5,6,7,8]]), np.array([0, 1]), epochs=1)

The output is:

Tensor("my_model_26/strided_slice:0", shape=(), dtype=int32)
(None, 4)
Tensor("my_model_26/Size:0", shape=(), dtype=int32)
(None, 4)
(None, 4)
[None, 4]
Tensor("my_model_26/Rank:0", shape=(), dtype=int32)
Tensor("my_model_26/Shape_2:0", shape=(2,), dtype=int32)
Tensor("my_model_26/strided_slice_1:0", shape=(), dtype=int32)
Tensor("my_model_26/strided_slice_2:0", shape=(), dtype=int32)

1/1 [==============================] - 0s 1ms/step - loss: 3.1796 - accuracy: 0.0000e+00

In this example I fed a (2, 4) NumPy array as input and a (2,) array as the target. As you can see, though, I cannot get the batch size inside call().

I need it because in my real model I have to iterate over the tensors batch_size times, and the batch size is dynamic.

For example, if the dataset has 10 samples and the batch size is 3, the last batch contains only 1 sample, so I have to know the batch size dynamically.
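The batch arithmetic above can be checked with a small tf.data sketch (an illustration only, not the asker's real input pipeline). Dataset.range(10).batch(3) yields batches of 3, 3, 3 and 1 elements, and the dataset's static spec reports the batch dimension as None precisely because the last batch is smaller:

```python
import tensorflow as tf

# 10 samples split into batches of 3: the final batch holds the single leftover sample.
dataset = tf.data.Dataset.range(10).batch(3)
batch_sizes = [int(tf.shape(batch)[0]) for batch in dataset]
print(batch_sizes)                            # [3, 3, 3, 1]
print(dataset.element_spec.shape.as_list())   # [None]: batch size is not static

# drop_remainder=True discards the partial batch, making the batch dimension static.
fixed = tf.data.Dataset.range(10).batch(3, drop_remainder=True)
print(fixed.element_spec.shape.as_list())     # [3]
```

Dropping the remainder avoids the dynamic batch size entirely, but it silently discards the last sample, so it only fits when that loss is acceptable.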

Can anyone help me?


  • TensorFlow 2.3.3
  • CUDA 10.2
  • Python 3.6.9

Solution

  • This happens because you're using TensorFlow (unavoidable, since Keras is now part of TensorFlow), and with TensorFlow you need to be aware of how your dynamic (eager) Python code is "compiled" into a static graph.

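    In short, when you call fit(), Keras runs your call method inside a function that is (under the hood) decorated with @tf.function, so call is traced into a graph instead of being executed eagerly.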

    This decorator:

    1. Traces the Python function's execution
    2. Converts the Python operations into TensorFlow operations (e.g. if a > b, where a and b are tensors, becomes tf.cond(tf.greater(a, b), something, something_else))
    3. Creates a tf.Graph (the static graph)
    4. Executes the static graph just created.
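Step 2 can be seen in a minimal sketch (the function name positive_part_of_sum is hypothetical): a plain Python if on a tensor condition works inside @tf.function because AutoGraph rewrites it into tf.cond, and the branch is then chosen at run time by the graph, not by Python:

```python
import tensorflow as tf


@tf.function
def positive_part_of_sum(x):
    s = tf.reduce_sum(x)
    # `s > 0` is a tensor, so AutoGraph converts this `if` into tf.cond:
    # both branches are placed in the graph and one is selected at run time.
    if s > 0:
        result = s
    else:
        result = tf.zeros_like(s)
    return result


print(float(positive_part_of_sum(tf.constant([1.0, 2.0]))))    # 3.0
print(float(positive_part_of_sum(tf.constant([-1.0, -2.0]))))  # 0.0
```

Both calls reuse the same traced graph (same dtype and shape), yet they take different branches, which is only possible because the condition became a graph operation.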

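    All your print calls are executed only during the first step (tracing the Python execution), which is why you see their output only once even though you train your model. At tracing time the batch dimension is not known yet, so every static shape reports it as None, and calls such as len, tf.size, tf.rank and tf.shape print symbolic graph tensors instead of values.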
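A minimal sketch of this trace-once behavior (the function double and the list traces are hypothetical): Python side effects such as print or appending to a list run only while the function is traced, whereas tf.print is a graph operation and runs on every call:

```python
import tensorflow as tf

traces = []  # plain Python list, used here only to count how often tracing happens


@tf.function
def double(x):
    traces.append(1)                       # Python side effect: runs only while tracing
    print("print: traced with", x.shape)   # Python print: also only at trace time
    tf.print("tf.print: running with", tf.shape(x))  # graph op: runs on every call
    return x * 2


for _ in range(3):
    double(tf.constant([1.0, 2.0]))  # same dtype and shape, so no retracing
```

Running this prints the Python print line once and the tf.print line three times, mirroring what happens to the print calls in the question's call method during fit().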

    That said, to get the runtime (dynamic) shape of a tensor you must use tf.shape(x); the batch size is then simply batch_size = tf.shape(x)[0].
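The difference between the static shape (x.shape, fixed at trace time) and the dynamic shape (tf.shape(x), computed at run time) shows up in a short sketch, assuming a hypothetical function batch_size_of with a [None, 4] input signature so that it is traced only once for any batch size:

```python
import tensorflow as tf

static_shapes = []  # records the static shape seen while tracing


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
def batch_size_of(x):
    static_shapes.append(x.shape)  # trace time: the batch dimension is unknown (None)
    return tf.shape(x)[0]          # run time: the actual batch size as a scalar tensor


print(int(batch_size_of(tf.zeros([2, 4]))))  # 2
print(int(batch_size_of(tf.zeros([1, 4]))))  # 1
print(static_shapes[0].as_list())            # [None, 4]
```

The single trace only ever sees [None, 4], yet the returned tensor carries the correct batch size for each call.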

    Note that if you want to see the shape, you can't use Python's print; you must use tf.print, which becomes part of the graph and therefore runs at every step.

    import numpy as np
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras.models import Model
    
    
    class MyModel(Model):
        def __init__(self):
            super(MyModel, self).__init__()
    
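        def call(self, x):
    
            shape = tf.shape(x)  # dynamic shape, evaluated at every step
            batch_size = shape[0]  # scalar int32 tensor holding the batch size
    
            tf.print(shape, batch_size)  # graph op, so it prints at every step
    
            # Use the dynamic batch size instead of hard-coding 2, so that a smaller
            # last batch still produces an output whose first dimension matches.
            return tf.random.uniform((batch_size, 10))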
    
    
    m = MyModel()
    m.compile(
        optimizer="Adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
    )
    m.fit(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), np.array([0, 1]), epochs=1)
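Since the question mentions iterating over the batch, here is a minimal sketch of that, assuming the per-sample work is something simple like summing each sample's features (the function per_sample_sums and its [None, 4] input signature are hypothetical). It loops over tf.range(batch_size) and collects results in a tf.TensorArray; because batch_size is a tensor, AutoGraph turns the loop into a tf.while_loop, so one graph handles both a full batch of 3 and a final batch of 1:

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
def per_sample_sums(x):
    batch_size = tf.shape(x)[0]  # dynamic batch size (scalar int32 tensor)
    results = tf.TensorArray(tf.float32, size=batch_size)
    # Looping over tf.range(batch_size) becomes a tf.while_loop, so the
    # number of iterations is decided at run time, not at trace time.
    for i in tf.range(batch_size):
        # Hypothetical per-sample work: sum the features of sample i.
        results = results.write(i, tf.reduce_sum(x[i]))
    return results.stack()


print(per_sample_sums(tf.ones([3, 4])).numpy())  # [4. 4. 4.]
print(per_sample_sums(tf.ones([1, 4])).numpy())  # [4.]
```

When the per-sample work can be expressed with batched operations (or tf.map_fn), that is usually simpler and faster than an explicit loop; the loop form is mainly useful when each step genuinely depends on the sample index.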
    

    More information about static and dynamic shapes: https://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/

    More info about the tf.function behavior: https://pgaleone.eu/tensorflow/tf.function/2019/03/21/dissecting-tf-function-part-1/

    Note: I wrote these articles.