Tags: python, tensorflow, tensorflow2.0, keras-layer, tf.keras

How to use batch size to create a tensor within a custom TensorFlow Layer


I'm creating a custom TF layer, and inside it I need to create a tensor of ones with something like this:

import tensorflow as tf
from tensorflow.keras.layers import Layer

class MyLayer(Layer):
  def __init__(self, config, **kwargs):
    super(MyLayer, self).__init__(**kwargs)
    ...

  def call(self, x):
    B, T, C = x.shape.as_list()
    ...
    ones = tf.ones((B, T, C))
    ...
    # output projection
    y = ...
    return y

The problem is that B (the batch size) is None when the layer is evaluated, which causes tf.ones to fail with the following error:


ValueError: in user code:

    <ipython-input-69-f3322a54c05c>:29 call  *
        ones = tf.ones((B, T, C))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper  **
        return target(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py:3080 ones
        shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/profiler/trace.py:163 wrapped
        return func(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:1535 convert_to_tensor
        ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/constant_op.py:356 _tensor_shape_tensor_conversion_function
        "Cannot convert a partially known TensorShape to a Tensor: %s" % s)

    ValueError: Cannot convert a partially known TensorShape to a Tensor: (None, 8, 128)

How can I get this working?
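For context: inside a Keras model, the layer is first called on a symbolic input whose batch dimension is not known yet, so x.shape.as_list() returns None for B. The static shape only holds what is known at trace time, whereas tf.shape(x) returns the shape as a tensor computed at runtime. A minimal sketch of the difference, assuming TF 2.x and using tf.function with an input_signature whose first dimension is None (the function name inspect_shapes is hypothetical):

```python
import tensorflow as tf

static_shapes = []  # filled in once, while the function is being traced


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 8, 128], dtype=tf.float32)])
def inspect_shapes(x):
    # Static shape: only what is known at trace time, so the batch size is None
    static_shapes.append(x.shape.as_list())
    # Dynamic shape: an int32 tensor evaluated at runtime, so the batch size is known
    return tf.shape(x)


dynamic_shape = inspect_shapes(tf.zeros((4, 8, 128)))
print(static_shapes[0])                # [None, 8, 128]
print(dynamic_shape.numpy().tolist())  # [4, 8, 128]
```

The list append is a Python side effect, so it runs only during tracing; that is why it captures the static shape with None rather than the concrete batch size of 4.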


Solution

  • If you just want a tensor with the same shape as x, you can use tf.ones_like, for example:

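    import tensorflow as tf
    from tensorflow.keras.layers import Layer

    class MyLayer(Layer):

      ...

      def call(self, x):
        # ones has the same (possibly dynamic) shape and dtype as x
        ones = tf.ones_like(x)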
    
        ...
    
        # output projection
        y = ...
        return y
    

    This doesn't need to know the shape of x until runtime.

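    In general, however, you might need to know the shape of the input before runtime. In that case you can implement the build() method in your layer: it takes input_shape as a parameter and is called automatically the first time the layer is called on an input (not when the model is compiled). Note that the batch dimension in input_shape is usually still None at that point.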

    Example copied from the Keras documentation:

    import tensorflow as tf
    from tensorflow import keras

    class Linear(keras.layers.Layer):
        def __init__(self, units=32):
            super(Linear, self).__init__()
            self.units = units
    
        def build(self, input_shape):
            self.w = self.add_weight(
                shape=(input_shape[-1], self.units),
                initializer="random_normal",
                trainable=True,
            )
            self.b = self.add_weight(
                shape=(self.units,), initializer="random_normal", trainable=True
            )
    
        def call(self, inputs):
            return tf.matmul(inputs, self.w) + self.b
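Applying this to the original tf.ones((B, T, C)) case: a common approach (not taken from the answer above) is to read the batch and time sizes from tf.shape(x) inside call(), because build() typically still receives None for the batch dimension. A minimal sketch, assuming TF 2.x with tf.keras and a static channel dimension; the layer name OnesOffsetLayer and the attribute seen_input_shape are hypothetical, and x + ones stands in for the elided computation:

```python
import tensorflow as tf
from tensorflow import keras


class OnesOffsetLayer(keras.layers.Layer):
    """Hypothetical layer that adds a (B, T, C) tensor of ones to its input."""

    def build(self, input_shape):
        # Runs once, automatically, on the first call. The batch dimension
        # is usually still None here, so B cannot be taken from input_shape.
        self.seen_input_shape = tuple(input_shape)
        super().build(input_shape)

    def call(self, x):
        # tf.shape(x) is evaluated at runtime, so B is known even when
        # the static x.shape[0] is None.
        dynamic_shape = tf.shape(x)
        B, T = dynamic_shape[0], dynamic_shape[1]
        C = x.shape[-1]  # assumes the channel size is static (128 in the question)
        ones = tf.ones((B, T, C), dtype=x.dtype)
        return x + ones  # placeholder for the rest of the layer's computation


inputs = keras.Input(shape=(8, 128))
layer = OnesOffsetLayer()
outputs = layer(inputs)  # first call triggers build()
model = keras.Model(inputs, outputs)

print(layer.seen_input_shape)              # (None, 8, 128)
print(model(tf.zeros((4, 8, 128))).shape)  # (4, 8, 128)
```

Mixing scalar tensors and Python ints in the shape tuple works because tf.ones converts the tuple into a shape tensor; if the last dimension could also be unknown, dynamic_shape[2] can be used in place of x.shape[-1].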