Tags: keras, deep-learning, conv-neural-network, keras-layer, tf.keras

What does keras.layers.concatenate do?


I came across the following code and was wondering what exactly keras.layers.concatenate does in this case.

Best Guess:

  1. In fire_module(), y learns from every individual pixel (kernel_size=1)
  2. y1 learns from every individual pixel of y's activation map (kernel_size=1)
  3. y3 learns from 3x3 areas of y's activation map (kernel_size=3)
  4. concatenate puts y1 and y3 together, so the total number of filters is now the sum of the filters in y1 and y3 (a quick shape check follows this list)
  5. The concatenation therefore combines per-pixel learning with 3x3-area learning, both on top of an activation map that was itself learned per pixel. Does this make the model better?
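
For point 4, here's the quick shape check I mean (toy values; the 16-channel input is just a made-up stand-in for the squeezed map y):

from tensorflow.keras.layers import Input, Conv2D, concatenate

y  = Input(shape=(144, 144, 16))    # made-up stand-in for y after the squeeze
y1 = Conv2D(filters=16, kernel_size=1, padding='same')(y)
y3 = Conv2D(filters=16, kernel_size=3, padding='same')(y)
print(concatenate([y1, y3]).shape)  # (None, 144, 144, 32): 16 + 16 filters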

Any help is greatly appreciated.

from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Input, MaxPooling2D, concatenate)

bnmomemtum = 0.9  # assumed value; the snippet I found doesn't show where this is defined

def fire(x, squeeze, expand):
    y  = Conv2D(filters=squeeze, kernel_size=1, activation='relu', padding='same')(x)
    y  = BatchNormalization(momentum=bnmomemtum)(y)
    y1 = Conv2D(filters=expand//2, kernel_size=1, activation='relu', padding='same')(y)
    y1 = BatchNormalization(momentum=bnmomemtum)(y1)
    y3 = Conv2D(filters=expand//2, kernel_size=3, activation='relu', padding='same')(y)
    y3 = BatchNormalization(momentum=bnmomemtum)(y3)
    return concatenate([y1, y3])

def fire_module(squeeze, expand):
    return lambda x: fire(x, squeeze, expand)
x = Input(shape=[144, 144, 3])
y = BatchNormalization(center=True, scale=False)(x)
y = Activation('relu')(y)
y = Conv2D(kernel_size=5, filters=16, padding='same', use_bias=True, activation='relu')(y)  # (y), not (x), so the two layers above aren't bypassed
y = BatchNormalization(momentum=bnmomemtum)(y)

y = fire_module(16, 32)(y)
y = MaxPooling2D(pool_size=2)(y)

Edit:

To be a little more specific: why not do this instead?

# why not this?
def fire(x, squeeze, expand):
    y = Conv2D(filters=squeeze, kernel_size=1, activation='relu', padding='same')(x)
    y = BatchNormalization(momentum=bnmomemtum)(y)
    y = Conv2D(filters=expand//2, kernel_size=1, activation='relu', padding='same')(y)
    y = BatchNormalization(momentum=bnmomemtum)(y)
    y = Conv2D(filters=expand//2, kernel_size=3, activation='relu', padding='same')(y)
    y = BatchNormalization(momentum=bnmomemtum)(y)
    return y


Solution

  • Citing @parsethis, who explained concatenation in another Stack Overflow question: this is what happens if a is concatenated with b (the results are joined together):

        a          b       concatenate([a, b])
      a b c      g h i         a b c g h i
      d e f      j k l         d e f j k l

    The documentation says that it simply returns a tensor containing the concatenation of all inputs, provided the inputs have the same shape except along the concatenation axis (i.e. the same height or width, depending on the axis).
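
    A minimal sketch of that same table in code (assuming TF 2.x eager mode; tf.keras.layers.concatenate defaults to the last axis, which is the channel axis in your fire module):

    import tensorflow as tf
    from tensorflow.keras.layers import concatenate

    a = tf.constant([[1., 2., 3.],
                     [4., 5., 6.]])     # the 2x3 matrix "a" above
    b = tf.constant([[7., 8., 9.],
                     [10., 11., 12.]])  # the 2x3 matrix "b" above
    c = concatenate([a, b], axis=-1)    # rows are joined side by side
    print(c.numpy())
    # [[ 1.  2.  3.  7.  8.  9.]
    #  [ 4.  5.  6. 10. 11. 12.]]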

    What happens in your case looks like this:

         Y
        / \
      Y1   Y3
        \ /
    concatenate([Y1, Y3])
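
    Regarding the edit ("why not this?"): in the sequential version the 3x3 convolution consumes the output of the 1x1 convolution, so only one path survives and the module outputs expand//2 channels; the concatenated version keeps both a 1x1 view and a 3x3 view of the same squeezed input, for expand channels in total. A minimal sketch of the difference (toy shapes assumed):

    from tensorflow.keras.layers import Input, Conv2D

    squeeze, expand = 16, 32
    y = Input(shape=(144, 144, squeeze))   # stand-in for the squeezed map

    # sequential version: the 1x1 view is consumed by the 3x3 conv
    s = Conv2D(expand//2, kernel_size=1, padding='same')(y)
    s = Conv2D(expand//2, kernel_size=3, padding='same')(s)
    print(s.shape)  # (None, 144, 144, 16): only expand//2 channels,
                    # versus (None, 144, 144, 32) from concatenate([y1, y3])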
    

    I hope that was clear enough.