Search code examples
Tags: tensorflow, batchsize

Error using batch_size inside custom TensorFlow layer


I want to broadcast a vector to a batch of vectors and concatenate them. I'm running into issues with using the batch size in a custom layer. At build time, the batch size (the first dimension of the first input) is None. At call time, the batch size is an int.

Here's the code I'm working with.

import sys
import numpy as np
import tensorflow as tf

from tensorflow.keras.layers import Layer, Input, Concatenate, UpSampling2D
from tensorflow.keras import Model

# Print the library versions so the environment can be reproduced (output listed below).
print(f"Python Version: {sys.version}")
print(f"Numpy Version: {np.version.version}")
print(f"Tensorflow Version: {tf.version.VERSION}")

class MyLayer(Layer):
    """Concatenate a vector y onto every row of a batch x (question version).

    This version fails at model construction; see the NOTE(review) comments
    in ``call`` for why.
    """
    def __init__(self, **kwargs):
        super(MyLayer, self).__init__(**kwargs)

        # Joins x and the repeated y along axis 1 (the feature axis).
        self.concat = Concatenate(axis = 1)

    def call(self, inputs):
        """Repeat y once per row of x and concatenate it onto x.

        Args:
            inputs: pair ``(x, y)``; y is presumably a single row such as
                shape (1, n) -- TODO confirm.
        """
        x, y = inputs
        # tf.shape returns a tensor, so batchSize is a scalar int32 Tensor,
        # not a Python int (the error message below shows it as a tf.Tensor).
        batchSize = tf.shape(x)[0]
        # NOTE(review): a Tensor is never None, so this fallback never fires.
        batchSize = 1 if batchSize is None else batchSize
        yShape = tf.shape(y)

        y = tf.reshape(y, (1, 1, yShape[1], 1))
        # NOTE(review): UpSampling2D's `size` must be a tuple of Python ints;
        # passing the tensor batchSize raises the ValueError quoted below.
        y = UpSampling2D(size = (batchSize, 1))(y)
        y = tf.reshape(y, (batchSize, yShape[1]))
        return self.concat([x, y])

inputsX = Input(shape = (4, ), dtype = tf.float32)
inputsY = Input(shape = (5, ), dtype = tf.float32)
# Calling the layer on symbolic Input tensors runs MyLayer.call during model
# construction; the call arguments in the error below are these symbolic
# (None, 4) / (None, 5) inputs, so the ValueError is raised here.
outputs = MyLayer()([inputsX, inputsY])
model = Model(inputs = [inputsX, inputsY], outputs = outputs)
# NOTE(review): Input(shape=(5,)) already declares y as (None, 5); the (1, 5)
# passed here is not reflected in the error trace, which reports (None, 5).
model.build(input_shape = ((None, 4), (1, 5)))
model.summary()

# Example data: 3 rows of x and a single row of y; the intended result is 3x9.
pointsX = tf.constant(np.reshape(range(12), [3, 4]), dtype = tf.float32)
pointsY = tf.constant(np.reshape(range(5),  [1, 5]), dtype = tf.float32)
print(model((pointsX, pointsY)))

The versions are the following.

Python Version: 3.8.10 (default, Jun 22 2022, 20:18:18) 
[GCC 9.4.0]
Numpy Version: 1.23.4
Tensorflow Version: 2.11.0

I expect to see a 3-by-9 tensor with the values from pointsX in the left four columns and the values from pointsY in the right five columns of each row. Instead, I get the following error.

ValueError: The `size` argument must be a tuple of 2 integers. Received: (<tf.Tensor 'my_layer/strided_slice:0' shape=() dtype=int32>, 1)including element Tensor("my_layer/strided_slice:0", shape=(), dtype=int32) of type <class 'tensorflow.python.framework.ops.Tensor'>


Call arguments received by layer "my_layer" (type MyLayer):
  • inputs=['tf.Tensor(shape=(None, 4), dtype=float32)', 'tf.Tensor(shape=(None, 5), dtype=float32)']

How can I perform this broadcast and concatenate operation?


Solution

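  • I found a low-level function that does what I need: `tf.broadcast_to`. Unlike `UpSampling2D`, whose `size` argument must be a tuple of Python integers, it accepts a tensor-valued target shape, so the dynamic batch size from `tf.shape(x)[0]` can be used directly.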

    import sys
    import numpy as np
    import tensorflow as tf
    
    from tensorflow.keras.layers import Layer, Input, Concatenate, UpSampling2D
    from tensorflow.keras import Model
    
    # Report the library versions in use, one per line, for reproducibility.
    print("\n".join([
        f"Python Version: {sys.version}",
        f"Numpy Version: {np.version.version}",
        f"Tensorflow Version: {tf.version.VERSION}",
    ]))
    
    class MyLayer(Layer):
        """Append one shared row vector to every row of a batch.

        ``call`` takes ``(x, y)``. It repeats y's single row once per row of x
        using ``tf.broadcast_to``, which (unlike ``UpSampling2D``) accepts a
        tensor-valued batch size. It then concatenates the result onto x
        along axis 1.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # Joins x and the tiled y along axis 1.
            self.concat = Concatenate(axis = 1)

        def call(self, inputs):
            """Return ``concat([x, tiled_y])`` for ``inputs = (x, y)``.

            y is assumed to contain exactly one row of ``tf.shape(y)[1]``
            values (e.g. shape (1, n)); otherwise the reshape below fails.
            """
            batch, row = inputs
            # Both sizes are scalar int32 tensors, resolved when the graph runs.
            rowCount = tf.shape(batch)[0]
            width = tf.shape(row)[1]

            # Collapse y to a single (1, width) row, then repeat it rowCount times.
            singleRow = tf.reshape(row, (1, width))
            tiled = tf.broadcast_to(singleRow, [rowCount, width])
            return self.concat([batch, tiled])
    
    inputsX = Input(shape = (4, ), dtype = tf.float32)
    inputsY = Input(shape = (5, ), dtype = tf.float32)
    outputs = MyLayer()([inputsX, inputsY])
    model = Model(inputs = [inputsX, inputsY], outputs = outputs)
    # NOTE(review): the functional Model is already built from its Input layers;
    # this call is kept from the question and is presumably a no-op -- confirm.
    model.build(input_shape = ((None, 4), (1, 5)))
    model.summary()

    # Run the model on the question's example data to confirm the fix.
    # Expected 3x9 result (each row of pointsX followed by pointsY):
    # [[ 0.  1.  2.  3.  0.  1.  2.  3.  4.]
    #  [ 4.  5.  6.  7.  0.  1.  2.  3.  4.]
    #  [ 8.  9. 10. 11.  0.  1.  2.  3.  4.]]
    pointsX = tf.constant(np.reshape(range(12), [3, 4]), dtype = tf.float32)
    pointsY = tf.constant(np.reshape(range(5),  [1, 5]), dtype = tf.float32)
    print(model((pointsX, pointsY)))