Tags: python, machine-learning, keras, max-pooling

merge multiple keras max pooling layers


I am new to Keras.

My goal is to have a total of 4 max pooling layers. All of them take the same input with shape (N, 256). The first layer does global max pooling and gives 1 output. The second layer, with pool size N / 2 and stride N / 2, gives 2 outputs. The third gives 4 outputs and the fourth gives 8 outputs. Here is my code.

    import numpy as np
    from keras.models import Sequential
    from keras.layers import Input, MaxPooling2D, Merge

    test_x = np.random.rand(N, 256, 1)

    model = Sequential()

    input1 = Input(shape=test_x.shape, name='input1')
    input2 = Input(shape=test_x.shape, name='input2')
    input3 = Input(shape=test_x.shape, name='input3')
    input4 = Input(shape=test_x.shape, name='input4')

    max1 = MaxPooling2D(pool_size=(N, 256), strides=N)(input1)
    max2 = MaxPooling2D(pool_size=(N // 2, 256), strides=N // 2)(input2)
    max3 = MaxPooling2D(pool_size=(N // 4, 256), strides=N // 4)(input3)
    max4 = MaxPooling2D(pool_size=(N // 8, 256), strides=N // 8)(input4)

    mrg = Merge(mode='concat')([max1, max2, max3, max4])

After creating the 4 max pooling layers, I try to merge them together, but Keras gives this error.

ValueError: Dimension 1 in both shapes must be equal, but are 4 and 8 for 'merge_1/concat' (op: 'ConcatV2') with input shapes: [?,1,1,1], [?,2,1,1], [?,4,1,1], [?,8,1,1], [] and with computed input tensors: input[4] = <3>.

How can I solve this issue? Is merging the correct way to achieve my goal in Keras?


Solution

  • For concatenation, all dimensions must have the same number of elements, except for the concat dimension itself.

    As you can see, your pooled results have these shapes:

    (?, 1, 1, 1)    
    (?, 2, 1, 1)    
    (?, 4, 1, 1)    
    (?, 8, 1, 1)    
    

    Naturally, the only possible way to concatenate them is along the second axis (axis=1):

    mrg = Concatenate(axis=1)([max1, max2, max3, max4])
    

    But notice that (unless you have specific reasons and know exactly what you're doing) this will result in a very weird image, since you're concatenating along a spatial dimension, not a channel dimension. A minimal end-to-end sketch of the corrected model follows below.
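
For reference, here is a minimal runnable sketch of the fix built with the Keras 2 functional API. The value N = 64 is only illustrative, and a single shared input is used instead of the four separate inputs from the question, since every branch pools the same tensor:

    import numpy as np
    from keras.models import Model
    from keras.layers import Input, MaxPooling2D, Concatenate, Flatten

    N = 64  # illustrative length; any multiple of 8 works for this example

    inp = Input(shape=(N, 256, 1), name='shared_input')

    # Four pooling branches giving 1, 2, 4 and 8 values along the first axis
    max1 = MaxPooling2D(pool_size=(N, 256), strides=(N, 256))(inp)
    max2 = MaxPooling2D(pool_size=(N // 2, 256), strides=(N // 2, 256))(inp)
    max3 = MaxPooling2D(pool_size=(N // 4, 256), strides=(N // 4, 256))(inp)
    max4 = MaxPooling2D(pool_size=(N // 8, 256), strides=(N // 8, 256))(inp)

    # Concatenate along the only axis whose sizes differ
    mrg = Concatenate(axis=1)([max1, max2, max3, max4])  # shape (None, 15, 1, 1)

    model = Model(inputs=inp, outputs=mrg)
    model.summary()

    # Alternative: flatten each branch first, so the merged result is a plain
    # feature vector of length 1 + 2 + 4 + 8 = 15 instead of a "spatial" map
    flat = Concatenate()([Flatten()(t) for t in (max1, max2, max3, max4)])

Note that the functional Model API is used here instead of Sequential, because a multi-branch graph like this cannot be expressed as a plain stack of layers.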