I want to broadcast a vector to a batch of vectors and concatenate. I'm running into issues with using the batch_size in a custom layer. At build time, the batch_size (first dimension of the first input) is None. At call time, the batch_size is a int.
Here's the code I'm working with.
import sys
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer, Input, Concatenate, UpSampling2D
from tensorflow.keras import Model
print(f"Python Version: {sys.version}")
print(f"Numpy Version: {np.version.version}")
print(f"Tensorflow Version: {tf.version.VERSION}")
class MyLayer(Layer):
def __init__(self, **kwargs):
super(MyLayer, self).__init__(**kwargs)
self.concat = Concatenate(axis = 1)
def call(self, inputs):
x, y = inputs
batchSize = tf.shape(x)[0]
batchSize = 1 if batchSize is None else batchSize
yShape = tf.shape(y)
y = tf.reshape(y, (1, 1, yShape[1], 1))
y = UpSampling2D(size = (batchSize, 1))(y)
y = tf.reshape(y, (batchSize, yShape[1]))
return self.concat([x, y])
inputsX = Input(shape = (4, ), dtype = tf.float32)
inputsY = Input(shape = (5, ), dtype = tf.float32)
outputs = MyLayer()([inputsX, inputsY])
model = Model(inputs = [inputsX, inputsY], outputs = outputs)
model.build(input_shape = ((None, 4), (1, 5)))
model.summary()
pointsX = tf.constant(np.reshape(range(12), [3, 4]), dtype = tf.float32)
pointsY = tf.constant(np.reshape(range(5), [1, 5]), dtype = tf.float32)
print(model((pointsX, pointsY)))
The versions are the following.
Python Version: 3.8.10 (default, Jun 22 2022, 20:18:18)
[GCC 9.4.0]
Numpy Version: 1.23.4
Tensorflow Version: 2.11.0
I expect to see a 3-by-9 tensor with values from pointsX in the left four columns and values from pointsY on the right five columns of each row. Instead, I get the following error.
ValueError: The `size` argument must be a tuple of 2 integers. Received: (<tf.Tensor 'my_layer/strided_slice:0' shape=() dtype=int32>, 1)including element Tensor("my_layer/strided_slice:0", shape=(), dtype=int32) of type <class 'tensorflow.python.framework.ops.Tensor'>
Call arguments received by layer "my_layer" (type MyLayer):
• inputs=['tf.Tensor(shape=(None, 4), dtype=float32)', 'tf.Tensor(shape=(None, 5), dtype=float32)']
How can I perform this broadcast and concatenate operation?
I found a low level function that does what I need.
import sys
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer, Input, Concatenate, UpSampling2D
from tensorflow.keras import Model
print(f"Python Version: {sys.version}")
print(f"Numpy Version: {np.version.version}")
print(f"Tensorflow Version: {tf.version.VERSION}")
class MyLayer(Layer):
def __init__(self, **kwargs):
super(MyLayer, self).__init__(**kwargs)
self.concat = Concatenate(axis = 1)
def call(self, inputs):
x, y = inputs
batchSize = tf.shape(x)[0]
yShape = tf.shape(y)
y = tf.reshape(y, (1, yShape[1]))
y = tf.broadcast_to(y, [batchSize, yShape[1]])
return self.concat([x, y])
inputsX = Input(shape = (4, ), dtype = tf.float32)
inputsY = Input(shape = (5, ), dtype = tf.float32)
outputs = MyLayer()([inputsX, inputsY])
model = Model(inputs = [inputsX, inputsY], outputs = outputs)
model.build(input_shape = ((None, 4), (1, 5)))
model.summary()