Tags: python tensorflow keras repeat tensor

Repeat specific elements of a tensor using Keras


I am using a 4-D tensor of shape=(N, 2, 127, 52)

I used:

tf.keras.backend.repeat_elements(tensor, 2, axis=3)

This doubles the size of the last axis from 52 to 104 by repeating each value:

shape=(N, 2, 127, 104)
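To make "repeating each value" concrete, here is a minimal sketch on a toy tensor. It uses tf.repeat, which for this call behaves like repeat_elements (each element is copied next to itself) and does not depend on the tf.keras.backend module; tf.tile is shown only for contrast, because it copies the whole axis as a block instead.

```python
import tensorflow as tf

# A tiny stand-in for the real tensor: shape (1, 1, 1, 3)
x = tf.constant([[[[1.0, 2.0, 3.0]]]])

# tf.repeat copies each element in place along axis 3
repeated = tf.repeat(x, repeats=2, axis=3)

# tf.tile copies the entire axis as one block, for contrast
tiled = tf.tile(x, multiples=[1, 1, 1, 2])

print(repeated[0, 0, 0].numpy().tolist())  # [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
print(tiled[0, 0, 0].numpy().tolist())     # [1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
```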

Now I want to do the same, but only for the last 10 elements along axis 3, giving:

shape=(N, 2, 127, 114)
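"The same" can be read two ways, and both give a last axis of 114: append a copy of the last 10 values to the end (what the tf.concat answer does), or repeat each of the last 10 values in place, the way repeat_elements does for the whole axis. A minimal sketch of both, assuming an arbitrary batch size N = 3; the second reading relies on tf.repeat accepting one repeat count per position along the axis.

```python
import tensorflow as tf

N = 3  # arbitrary batch size for this sketch
x = tf.random.normal((N, 2, 127, 104))

# Reading 1: append a copy of the last 10 values to the end of axis 3
appended = tf.concat([x, x[..., -10:]], axis=3)

# Reading 2: repeat each of the last 10 values in place;
# one repeat count per position (94 ones, then 10 twos)
repeats = [1] * 94 + [2] * 10
in_place = tf.repeat(x, repeats=repeats, axis=3)

print(appended.shape)  # (3, 2, 127, 114)
print(in_place.shape)  # (3, 2, 127, 114)
```

The shapes match, but the ordering differs: in the first result the extra values sit at positions 104 to 113, while in the second each of the last 10 values is immediately followed by its copy.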

I am also looking for a way to add an extra "column" by inserting a zero vector in the middle of the last axis, resulting in:

shape=(N, 2, 127, 115)

How can I do this?


Solution

  • I think tf.concat, combined with slicing, is a simple way to do this:

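    import tensorflow as tf
    
    N = 2
    tensor = tf.random.normal((N, 2, 127, 52))
    
    # (N, 2, 127, 104): repeat each value along the last axis
    tensor = tf.repeat(tensor, 2, axis=3)
    
    # (N, 2, 127, 114): append a copy of the last 10 values of the last axis
    tensor = tf.concat([tensor, tensor[..., -10:]], axis=-1)
    
    # (N, 2, 127, 115): insert a zero column in the middle of the last axis
    middle = tf.shape(tensor)[-1] // 2
    zeros = tf.zeros_like(tensor[..., :1])  # shape (N, 2, 127, 1), same dtype
    tensor = tf.concat([tensor[..., :middle], zeros, tensor[..., middle:]], axis=-1)
    
    print(tensor.shape)
    
    (2, 2, 127, 115)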
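Since the question is about Keras, one option is to wrap the three steps in a custom layer so they can be used inside a model. This is a sketch, not part of the original answer: the class name RepeatTailAndPad and its tail parameter are hypothetical, and it assumes the TensorFlow backend so that tf ops can run inside call.

```python
import tensorflow as tf

class RepeatTailAndPad(tf.keras.layers.Layer):
    """Hypothetical layer: (N, 2, 127, 52) -> (N, 2, 127, 115)."""

    def __init__(self, tail=10, **kwargs):
        super().__init__(**kwargs)
        self.tail = tail  # how many trailing values to append again

    def call(self, x):
        x = tf.repeat(x, 2, axis=3)                       # 52 -> 104
        x = tf.concat([x, x[..., -self.tail:]], axis=-1)  # 104 -> 114
        middle = x.shape[-1] // 2                         # static width, 57 here
        zeros = tf.zeros_like(x[..., :1])                 # one zero column
        return tf.concat([x[..., :middle], zeros, x[..., middle:]], axis=-1)  # 114 -> 115

inp = tf.random.normal((4, 2, 127, 52))
out = RepeatTailAndPad()(inp)
print(out.shape)  # (4, 2, 127, 115)
```

Using tf.zeros_like on a one-column slice avoids hard-coding the batch size and other dimensions, which matters when N is only known at run time.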