Determine batch size during `tensorflow.keras` Custom Class `call` method


I already asked this question here, but I thought StackOverflow would have more traffic/people that might know the answer.

I'm building a custom Keras layer similar to an example found here. I want the call method inside the class to know the batch size of the input data flowing through it, but inputs.shape shows (None, 3) during model prediction. Here's a concrete example:

I initialize a simple data set like this:

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model

# Create fake data to use for model testing
n = 1000
np.random.seed(123)
x1 = np.random.random(n)
x2 = np.random.normal(0, 1, size=n)
x3 = np.random.lognormal(0, 1, size=n)

X = pd.DataFrame(np.concatenate([
    np.reshape(x1, (-1, 1)),
    np.reshape(x2, (-1, 1)),
    np.reshape(x3, (-1, 1)),
], axis=1))

Then I define a custom class to test/show what I'm talking about:

class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)

    def get_config(self):
        config = super(TestClass, self).get_config()
        return config

    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)

        print(inputs)
        record_count, n = inputs.shape
        print(f'inputs.shape = {inputs.shape}')

        return inputs

Then, when I create a simple model and force it to do a forward pass...

input_layer = layers.Input(3)
test = TestClass()(input_layer)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025)
model = Model(input_layer, test)
model.compile(loss='mse', optimizer=optimizer, metrics=['mae', 'mse'])
model.predict(X.loc[:9, :])

... I get this output printed to the screen

model.predict(X.loc[:9, :])
Tensor("model_1/Cast:0", shape=(None, 3), dtype=float32)
inputs.shape = (None, 3)
1/1 [==============================] - 0s 28ms/step
Out[34]: 
array([[ 0.5335418 ,  0.7788839 ,  0.64132416],
       [ 0.2924202 , -0.08321562,  0.412311  ],
       [ 0.5118007 , -0.6822934 ,  1.1782378 ],
       [ 0.03780456, -0.19350041,  0.7637337 ],
       [ 0.86494124, -3.196387  ,  4.8535166 ],
       [ 0.26708454, -0.49397194,  0.91296834],
       [ 0.49734482, -1.6618049 ,  0.50054324],
       [ 0.8563762 ,  0.7956695 ,  0.29466265],
       [ 0.7682351 ,  0.86538637,  0.6633331 ],
       [ 0.85322225,  0.868021  ,  0.1776046 ]], dtype=float32)

You can see that during the model.predict call, inputs.shape prints (None, 3), even though the call method clearly returns an output with shape (10, 3). How can I capture that 10 inside the call method in this example?
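The gap between the printed (None, 3) and the real 10 is the difference between a tensor's static shape, fixed when TensorFlow traces the code into a graph, and its dynamic shape, known only when the graph runs. The sketch below shows this with a standalone tf.function instead of a Keras layer, to keep it small; it assumes the layer's call is traced the same way inside model.predict. The input_signature leaves the batch dimension as None, as layers.Input(3) does. report_batch_size and trace_time_shapes are hypothetical names.

```python
import tensorflow as tf

# Appended by plain Python code, which runs only while TensorFlow traces the function.
trace_time_shapes = []

# The batch dimension is left unspecified (None), as with layers.Input(3).
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def report_batch_size(inputs):
    print("print at trace time:", inputs.shape)              # static shape, printed once
    trace_time_shapes.append(tuple(inputs.shape))
    tf.print("tf.print at run time:", tf.shape(inputs)[0])   # dynamic size, printed on every call
    return tf.shape(inputs)[0]

sizes = [int(report_batch_size(tf.zeros([rows, 3]))) for rows in (10, 4, 7)]
print(sizes)              # [10, 4, 7]
print(trace_time_shapes)  # [(None, 3)]
```

Python's print and the list append run once, during tracing, and only see (None, 3); tf.print and tf.shape run each time the graph executes and see the actual row count.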

UPDATE 1

When I use tf.shape as suggested in the current answer, I can print the value to the screen, but I get an error when I try to capture that value in a variable.

class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)

    def get_config(self):
        config = super(TestClass, self).get_config()
        return config

    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
        record_count, n = tf.shape(inputs)
        tf.print("Dynamic batch size", tf.shape(inputs)[0])
        return inputs

This code causes an error on the record_count, ... line.

Traceback (most recent call last):
  File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3378, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-22-104d812c32e6>", line 1, in <module>
    test = TestClass()(input_layer)
  File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 692, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: Exception encountered when calling layer "test_class_4" (type TestClass).
in user code:
    File "<ipython-input-21-2dec1d5b9547>", line 12, in call  *
        record_count, n = tf.shape(inputs)
    OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
Call arguments received by layer "test_class_4" (type TestClass):
  • inputs=tf.Tensor(shape=(None, 3), dtype=float32)

I tried decorating the call method with @tf.function, but I get the same error.

UPDATE 2

I tried a couple of other things and found that, oddly, TensorFlow doesn't seem to like the tuple assignment. It works fine if it's coded like this instead:

class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)

    def get_config(self):
        config = super(TestClass, self).get_config()
        return config

    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
        shape = tf.shape(inputs)
        record_count = shape[0]
        n = shape[1]
        tf.print("Dynamic batch size", tf.shape(inputs)[0])
        return inputs
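A likely reason the tuple assignment fails: unpacking iterates over the tensor returned by tf.shape, and the traceback above states that iterating over a symbolic tf.Tensor is not allowed in graph execution, which would also explain why adding @tf.function did not help. Indexing, as in the working version, avoids iteration. Another option, sketched below under the assumption that the rank of inputs is known when the layer is traced, is tf.unstack, which returns an ordinary Python list that can be unpacked. AppendRowIndex is a hypothetical layer that then uses record_count, a scalar tensor, to build a row-index column, showing that the captured value works in later ops.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

class AppendRowIndex(layers.Layer):
    """Hypothetical layer: appends each row's position within its batch as a new column."""

    def call(self, inputs):
        # `record_count, n = tf.shape(inputs)` would iterate over a symbolic tensor.
        # tf.unstack returns a Python list instead; its length (2) is inferred from
        # the static shape of tf.shape(inputs), since inputs has rank 2.
        record_count, n = tf.unstack(tf.shape(inputs))  # n (feature count) is unused here
        # record_count is a scalar int32 tensor, evaluated per batch at run time.
        row_index = tf.cast(tf.range(record_count), inputs.dtype)
        return tf.concat([inputs, row_index[:, tf.newaxis]], axis=1)

input_layer = layers.Input(shape=(3,))
model = Model(input_layer, AppendRowIndex()(input_layer))

out = model.predict(np.zeros((10, 3), dtype="float32"), verbose=0)
print(out.shape)   # (10, 4)
print(out[:, -1])  # [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
```

Because the index restarts for every batch, predicting on more rows than predict's batch_size would give several 0-based runs, one per batch.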

Solution

  • TL;DR --> Use tf.shape(inputs)[0] to capture the dynamic batch size in the call method, or specify a static batch size when you create the model (via the batch_size argument of layers.Input).


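    Under the hood, Keras does not run call eagerly here: when the functional model is built, call runs on a symbolic input, and model.predict wraps its prediction step in tf.function, so call (invoked through Layer.__call__) is traced into a graph. That is why print and .shape do not work as expected.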

    With tf.function, Python code is traced and converted into native TensorFlow operations. Those operations are recorded in a static graph, which is just an instance of tf.Graph, and it is that graph, not the Python code, that gets executed.

    Python's print runs only during that first tracing step, so it is not the right way to print values in graph mode (code decorated with tf.function). tf.print, in contrast, is added to the graph as an operation and runs every time the graph executes.

    The batch dimension is only known at run time, which is why the static shape (inputs.shape) reports it as None. tf.shape(inputs)[0] is evaluated at run time, so it gives you the actual batch size of each batch.

    If you really want to see that 10 in call:

    class TestClass(tf.keras.layers.Layer):
        def __init__(self, **kwargs):
            super(TestClass, self).__init__(**kwargs)
    
        def get_config(self):
            config = super(TestClass, self).get_config()
            return config
    
        def call(self, inputs: tf.Tensor):
            if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
                inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
            tf.print("Dynamic batch size", tf.shape(inputs)[0])
            return inputs
    

    Running:

    input_layer = layers.Input(3)
    test = TestClass()(input_layer)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025)
    model = Model(input_layer, test)
    model.compile(loss='mse', optimizer=optimizer, metrics=['mae', 'mse'])
    model.predict(X.loc[:9, :])
    

    Will return:

    Dynamic batch size 10
    1/1 [==============================] - 0s 65ms/step
    array([[ 6.9646919e-01, -1.0032653e-02,  3.7556963e+00],
           [ 2.8613934e-01, -8.4564441e-01,  9.9685013e-01],
           [ 2.2685145e-01,  9.1146064e-01,  6.5008003e-01],
           [ 5.5131477e-01, -1.3744969e+00,  8.6379850e-01],
           [ 7.1946895e-01, -5.4706562e-01,  3.1904945e+00],
           [ 4.2310646e-01, -7.5526608e-05,  5.2649558e-01],
           [ 9.8076421e-01, -1.2116680e-01,  7.4064606e-01],
           [ 6.8482971e-01, -2.0085855e+00,  5.3138912e-01],
           [ 4.8093191e-01, -9.2064655e-01,  8.1520426e-01],
           [ 3.9211753e-01,  1.6823435e-01,  1.2382457e+00]], dtype=float32)
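The TL;DR's other option, a static batch size, can be sketched as follows, assuming it refers to the batch_size argument of layers.Input. With it, the static shape that call sees while Keras builds the model is (10, 3) instead of (None, 3). StaticShapeRecorder is a hypothetical layer name. Whether that static dimension also survives into the graph that model.predict traces may depend on the Keras version and on how the data is batched, and a fixed batch size means every batch must have exactly that many rows, so tf.shape(inputs)[0] remains the more robust choice.

```python
from tensorflow.keras import layers, Model

class StaticShapeRecorder(layers.Layer):
    """Hypothetical layer that records the static shape seen by `call`."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seen_shapes = []

    def call(self, inputs):
        # Plain Python: runs while Keras traces `call`, not on every batch.
        self.seen_shapes.append(tuple(inputs.shape))
        return inputs

# batch_size fixes the first (batch) dimension of the model's input.
input_layer = layers.Input(shape=(3,), batch_size=10)
recorder = StaticShapeRecorder()
model = Model(input_layer, recorder(input_layer))

print(recorder.seen_shapes[0])  # (10, 3)
```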