Tags: python, tensorflow, keras, nlp, tensorflow2.0

Why does this tf.keras model behave differently than expected on sliced inputs?


I'm coding a Keras model which, given (mini-)batches of tensors, applies the same layer to each of their elements. To give a bit of context, the input consists of groups (of fixed size) of strings, which must be encoded one by one by an encoding layer. Thus the input shape, including the (mini-)batch dimension, is (None, n_sentences_per_sample, ), where n_sentences_per_sample is a fixed value known a priori.
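For concreteness, here is a minimal sketch of what such a batch looks like (the sentence strings and the value n_sentences_per_sample = 3 are made up for illustration):

import tensorflow as tf

# A toy (mini-)batch: 1 sample of n_sentences_per_sample = 3 strings,
# i.e. shape (batch, n_sentences_per_sample) == (1, 3)
toy_batch = tf.constant([["first sentence", "second sentence", "third sentence"]])
print(toy_batch.shape)  # (1, 3)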

To do so, I use this custom function when creating the model in the Functional API:

from typing import Callable, Union
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K

def _branch_execute(self, layer_in: keras.layers.Layer,
                    sublayer: Union[keras.layers.Layer, Callable], **args) -> keras.layers.Layer:
    instance_cnt = layer_in.shape[1]
    # Slice out each of the instance_cnt elements along axis 1
    sliced_inputs = [tf.keras.layers.Lambda(lambda x: x[:, i])(layer_in) for i in range(instance_cnt)]
    branch_layers = [sublayer(**{**{'layer_in': sliced_inputs[i]}, **args}) for i in range(instance_cnt)]
    # Re-insert the instance axis and concatenate back to (batch, instance_cnt, ...)
    expand_layer = tf.keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))
    expanded_layers = [expand_layer(branch_layers[i]) for i in range(instance_cnt)]
    concated_layer = tf.keras.layers.Concatenate(axis=1)(expanded_layers)
    return concated_layer

which I use in this way:

model_input = keras.layers.Input(shape=(self.max_sents, ),
                                 dtype=tf.string,
                                 )
sentences_embedded = self._branch_execute(model_input, self._get_nnlm_128)
model = keras.models.Model(model_input, sentences_embedded)

where self._get_nnlm_128() is just a function which returns the result of applying a cached pretrained embedding layer to the input, i.e.

import tensorflow_hub as hub

def _get_nnlm_128(self, layer_in: keras.layers.Layer, trainable: bool = False):
    # Build the TF Hub encoder once and cache it, so all branches share the same weights
    if 'nnlm_128_layer_shared' not in self.shared_layers_cache:
        self.shared_layers_cache['nnlm_128_layer_shared'] = {
            'encoder': hub.KerasLayer("https://tfhub.dev/google/nnlm-en-dim128-with-normalization/2", trainable=trainable)
        }
    shared_layers = self.shared_layers_cache['nnlm_128_layer_shared']
    encoder = shared_layers['encoder'](layer_in)
    return encoder
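For reference, this TF Hub module maps a 1-D batch of strings directly to 128-dimensional embeddings; a minimal standalone sketch (sentences made up):

import tensorflow as tf
import tensorflow_hub as hub

encoder = hub.KerasLayer(
    "https://tfhub.dev/google/nnlm-en-dim128-with-normalization/2",
    trainable=False)
embeddings = encoder(tf.constant(["hello world", "another sentence"]))
print(embeddings.shape)  # (2, 128)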

The problem I have is as follows:

  1. If I call self._branch_execute(input_tensor, self._get_nnlm_128), where input_tensor is just a well-shaped tensor, it works perfectly;
  2. If I call model (whether directly or through .predict(), whether compiled or not) on the same input_tensor, I get a repeated result for every sentence in the sample (weirdly, it is the output corresponding to the LAST sentence, repeated - see below).

Just as an example (though I have the same issue with every possible input), let us consider an input_tensor composed of 7 sentences (7 strings), reshaped to (1, 7, ) to include the minibatch axis. The result of 1) is

[[[ 0.216900051 0.037066862 0.163929373 ... 0.050420273 0.082906663 0.059960182],
  [ 0.531883411 -0.000807280 0.107559107 ... -0.079948671 -0.020143294 0.007032406],
  ...
  [ 0.15044811  0.00890037  0.10413752 ... -0.05391502 -0.12199926 -0.13466084]]]

where I get 7 vectors/embeddings of size 128, all different from each other, as expected. The result of 2), oddly enough, is

[[[ 0.15044811  0.00890037  0.10413752 ... -0.05391502 -0.12199926 -0.13466084],  
  [ 0.15044811  0.00890037  0.10413752 ... -0.05391502 -0.12199926 -0.13466084], 
  [ 0.15044811  0.00890037  0.10413752 ... -0.05391502 -0.12199926 -0.13466084],
  [ 0.15044811  0.00890037  0.10413752 ... -0.05391502 -0.12199926 -0.13466084],
  [ 0.15044811  0.00890037  0.10413752 ... -0.05391502 -0.12199926 -0.13466084],
  [ 0.15044811  0.00890037  0.10413752 ... -0.05391502 -0.12199926 -0.13466084],
  [ 0.15044811  0.00890037  0.10413752 ... -0.05391502 -0.12199926 -0.13466084]]]

where I get the same vector 7 times (as I said, it always corresponds to the last sentence's embedding, repeated for all the sentences). These results are taken from an actual run.
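For concreteness, a minimal sketch of the two calls inside the class (the sentence strings are made up; input_tensor matches the (1, 7, ) shape above):

input_tensor = tf.constant(
    [["s one", "s two", "s three", "s four", "s five", "s six", "s seven"]])  # shape (1, 7)

out_direct = self._branch_execute(input_tensor, self._get_nnlm_128)  # case 1: (1, 7, 128), 7 distinct rows
out_model = model.predict(input_tensor)                              # case 2: (1, 7, 128), last row repeated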

Among the many trials I made, I tried to output model_input from the model, which works great, i.e. it corresponds to the input strings. The embedding model is taken directly from TensorFlow Hub, so it should not have problems. Additionally, the same behavior is observed with any other embedding layer, whether custom or pretrained. I therefore think the problem may be in the _branch_execute() function, but I have no idea what the issue could be, given that it works correctly when used alone. Maybe it has to do with some peculiar broadcasting behavior inside Keras models, but I don't know how to test for that, let alone how to solve it.

I would appreciate any suggestion you may have about why this issue occurs and how to fix it. I'm not an expert in TensorFlow, so maybe I'm just misjudging something (if so, forgive me!). I'll be glad to share more info as needed to help solve the problem. Thanks a lot :)


Solution

  • I finally came to the conclusion that the problem was in the line

    sliced_inputs = [tf.keras.layers.Lambda(lambda x: x[:, i])(layer_in) for i in range(instance_cnt)]
    

    which apparently does not work as expected (I'm running TensorFlow 2.4.0, and I got the same issue with TensorFlow 2.5.0-nightly as well). I simply replaced the Lambda layer with a custom layer that does exactly the same thing, i.e.

    class Slicer(keras.layers.Layer):
        """Selects the i-th element along axis 1 of the input tensor."""

        def __init__(self, i, **kwargs):
            # Each layer instance stores its own index, so nothing is
            # shared between the branches.
            self.i = i
            super(Slicer, self).__init__(**kwargs)

        def call(self, inputs, **kwargs):
            return inputs[:, self.i]

    which I then used in the _branch_execute() function just like this

    def _branch_execute(self, layer_in: keras.layers.Layer,
                        sublayer: Union[keras.layers.Layer, Callable], **args) -> keras.layers.Layer:
        instance_cnt = layer_in.shape[1]
        sliced_inputs = [Slicer(i)(layer_in) for i in range(instance_cnt)]
        branch_layers = [sublayer(**{**{'layer_in': sliced_inputs[i]}, **args}) for i in range(instance_cnt)]
        expand_layer = tf.keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))
        expanded_layers = [expand_layer(branch_layers[i]) for i in range(instance_cnt)]
        concated_layer = tf.keras.layers.Concatenate(axis=1)(expanded_layers)
        return concated_layer
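
    To check the fix in isolation, a tiny made-up model that only slices can be verified directly (hypothetical example, not from the original post):

    check_in = keras.layers.Input(shape=(3, ), dtype=tf.string)
    check_out = [Slicer(i)(check_in) for i in range(3)]
    check_model = keras.models.Model(check_in, check_out)
    # Each output now returns its own slice instead of the last one repeated.
    print(check_model.predict(tf.constant([["a", "b", "c"]])))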
    

    I'm not sure whether this is the best way to solve the problem, but it seems pretty neat and it works well.
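
    A plausible explanation (not verified here) is Python's late binding of closures: every lambda created in the list comprehension closes over the same variable i, so if Keras re-traces the Lambda layers when the model is called, each of them sees the final value of i. Under that assumption, a sketch that keeps the Lambda but freezes the index with a default argument would be:

    # Hypothetical variant: `idx=i` captures the current loop value inside each
    # lambda, so the branches no longer share the closure variable `i`.
    sliced_inputs = [tf.keras.layers.Lambda(lambda x, idx=i: x[:, idx])(layer_in)
                     for i in range(instance_cnt)]

    The Lambda layer's arguments parameter (e.g. arguments={'idx': i}) provides the same eager binding.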

    Since this is probably an unexpected behavior of the Lambda layer, I'll open an issue on the TensorFlow GitHub and post the reply here for reference.