Tags: tensorflow, tensorflow2.0, convolution

Generate indices equivalent to convolution kernel inputs in TensorFlow


I want to add some preprocessing to the data that goes into each 2D convolution kernel in TensorFlow. Suppose the input is NHWC of size (N, H, W, 5) and the convolution kernel is of size (3, 3) with stride (1, 1) and padding “valid” (I am going to apply the convolution later, as explained below). I first want to obtain a tensor of size (N, H-2, W-2, 5*9), where the indices along dims 1 and 2 represent the spatial locations of the convolution windows, and dim 3 has size 5*9 = 45, with each index corresponding to one input cell/channel that goes into the convolution at that location. In other words, all kernel inputs are arranged along the channel dimension.

After obtaining this tensor, I want to apply some transformation along the channel dimension (dim 3) and then call the convolution with a kernel of size (1, 1) and stride (1, 1). In essence, I want to first collect all inputs that will go into each convolution kernel, do something with them, and then call the convolution (its kernel is (1, 1) because all kernel inputs have already been collected along the channel axis).

Is there an operation in TensorFlow that would help me generate the indices for these windows, which I could perhaps later use with tf.gather_nd()? Any other ideas or solutions are welcome. Thanks!
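For the index-generation part of the question, here is a minimal sketch (not from the original thread) that builds the window coordinates explicitly and looks them up with `tf.gather_nd(..., batch_dims=1)`. The helper names `window_indices` and `prepare_conv_input_gather` are hypothetical. The sketch assumes eager execution and a fully known input shape, and it lays out each window's values in (row, column, channel) order, the same order you get by reshaping a (3, 3, C) slice.

```python
import tensorflow as tf

def window_indices(h, w, k=3, stride=1):
    # Hypothetical helper. Returns an int32 tensor of shape (h_out, w_out, k*k, 2):
    # for every "valid" k x k window, the (row, col) of each input cell it
    # covers, listed row-major over the window.
    h_out = (h - k) // stride + 1
    w_out = (w - k) // stride + 1
    top = tf.range(h_out) * stride             # first row of each window
    left = tf.range(w_out) * stride            # first column of each window
    di, dj = tf.meshgrid(tf.range(k), tf.range(k), indexing="ij")
    di = tf.reshape(di, [-1])                  # (k*k,) row offsets
    dj = tf.reshape(dj, [-1])                  # (k*k,) column offsets
    rows = tf.broadcast_to(top[:, None, None] + di, [h_out, w_out, k * k])
    cols = tf.broadcast_to(left[None, :, None] + dj, [h_out, w_out, k * k])
    return tf.stack([rows, cols], axis=-1)

def prepare_conv_input_gather(batch, k=3, stride=1):
    # Hypothetical helper. Assumes batch has a fully known NHWC shape.
    n, h, w, c = batch.shape
    idx = window_indices(h, w, k, stride)      # (h_out, w_out, k*k, 2)
    # Repeat the indices for every image; with batch_dims=1, gather_nd
    # looks up each (row, col) pair inside the matching image.
    idx = tf.tile(idx[None], [n, 1, 1, 1, 1])  # (n, h_out, w_out, k*k, 2)
    patches = tf.gather_nd(batch, idx, batch_dims=1)  # (n, h_out, w_out, k*k, c)
    return tf.reshape(patches, [n, idx.shape[1], idx.shape[2], k * k * c])

x = tf.reshape(tf.range(2 * 5 * 6 * 3, dtype=tf.float32), [2, 5, 6, 3])
print(prepare_conv_input_gather(x).shape)  # (2, 3, 4, 27)
```

The index tensor depends only on the spatial size, kernel size and stride, so it could be computed once and reused for every batch.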


Solution

  • I am pasting my solution below:

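    import tensorflow as tf

    N = 32   # batch size
    H = 101  # input height
    W = 101  # input width
    C = 5    # input channels
    K = 3    # kernel size (K x K)

    # input batch data: the values 0, 1, 2, ... laid out in NHWC order
    batch = tf.reshape(
        tf.range(N*H*W*C, dtype=tf.float32),
        [N, H, W, C])

    print(f"Batch shape: {batch.shape}")
    print(f"Batch tensor size: {tf.size(batch).numpy()}")
    print()

    def prepare_conv_input(batch, k=K, stride=2):
        # Collect every k x k window ("valid" padding) and flatten it into
        # the channel axis: output shape (N, H_out, W_out, k*k*C).
        # Note: stride=2 reproduces the output shown below; the (1, 1) stride
        # from the question corresponds to stride=1, giving (32, 99, 99, 45).
        n, h, w, c = batch.shape

        i_to_stack = []
        for i in range(0, h - k + 1, stride):

            j_to_stack = []
            for j in range(0, w - k + 1, stride):

                # one k x k x C window whose top-left corner is (i, j)
                inp = tf.slice(
                    batch,
                    begin=[0, i, j, 0],
                    size=[n, k, k, c])

                # flatten the window in (row, column, channel) order
                j_to_stack.append(
                    tf.reshape(
                        inp,
                        (n, k*k*c)))

            i_to_stack.append(
                tf.stack(j_to_stack, 1))

        return tf.stack(i_to_stack, 1)

    result = prepare_conv_input(batch)

    print(f"Prepared shape: {result.shape}")
    print(f"Prepared tensor size: {tf.size(result).numpy()}")
    print()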
    

    And the result is:

    Batch shape: (32, 101, 101, 5)
    Batch tensor size: 1632160
    
    Prepared shape: (32, 50, 50, 45)
    Prepared tensor size: 3600000
    

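    The data size increases (from 1,632,160 to 3,600,000 elements) because neighbouring windows overlap: with a (3, 3) kernel and a step of 2, adjacent windows share one row or column, so those input values are copied into more than one window.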
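The double Python loop above builds one small slice per window. A vectorized alternative, offered here as a sketch rather than as the author's method, is `tf.image.extract_patches`, which produces the (N, H_out, W_out, k*k*C) layout in a single op. The sketch also checks the idea from the question: with no preprocessing, a (1, 1) convolution over the patches, using the 3 x 3 kernel reshaped to (1, 1, 9*C, F), should reproduce a regular 3 x 3 “valid” convolution. The function names `prepare_conv_input_patches` and `conv_via_patches` and the small test shapes are illustrative assumptions.

```python
import tensorflow as tf

def prepare_conv_input_patches(batch, k=3, stride=1):
    # One "valid" k x k window per output position, flattened into the
    # channel axis in (row, column, channel) order: (N, H_out, W_out, k*k*C).
    return tf.image.extract_patches(
        images=batch,
        sizes=[1, k, k, 1],
        strides=[1, stride, stride, 1],
        rates=[1, 1, 1, 1],
        padding="VALID")

def conv_via_patches(batch, kernel, stride=1):
    # kernel: (k, k, C, F), the weights a regular k x k convolution would use.
    k, _, c, f = kernel.shape
    patches = prepare_conv_input_patches(batch, k, stride)
    # ... per-window preprocessing along the last axis would go here ...
    # Reshape the k x k kernel into a 1 x 1 kernel over k*k*C channels;
    # its row-major (row, column, channel) order matches the patch layout.
    kernel_1x1 = tf.reshape(kernel, [1, 1, k * k * c, f])
    # The window stride was already applied by extract_patches, so the
    # 1 x 1 convolution itself uses stride 1.
    return tf.nn.conv2d(patches, kernel_1x1, strides=1, padding="VALID")

# Sanity check on a small random input: without preprocessing, the
# 1 x 1 convolution over patches should match a regular 3 x 3 convolution.
x = tf.random.normal([2, 7, 8, 5])
kernel = tf.random.normal([3, 3, 5, 4])
for s in (1, 2):
    expected = tf.nn.conv2d(x, kernel, strides=s, padding="VALID")
    got = conv_via_patches(x, kernel, stride=s)
    print(s, got.shape, float(tf.reduce_max(tf.abs(got - expected))))
```

The printed maximum differences should be at the level of float32 rounding, since both paths sum the same products in a different order.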