I came across the following code and was wondering what exactly keras.layers.concatenate
does in this case.
Best Guess:
1. In fire, y learns based on every pixel (kernel_size=1)
2. y1 learns based on every pixel of the activation map of y
3. y3 learns based on an area of 3x3 pixels of the activation map of y
4. concatenate puts y1 and y3 together, meaning the total number of filters is now the sum of the filters in y1 and y3
Any help is greatly appreciated.
def fire(x, squeeze, expand):
    y = Conv2D(filters=squeeze, kernel_size=1, activation='relu', padding='same')(x)
    y = BatchNormalization(momentum=bnmomemtum)(y)
    y1 = Conv2D(filters=expand//2, kernel_size=1, activation='relu', padding='same')(y)
    y1 = BatchNormalization(momentum=bnmomemtum)(y1)
    y3 = Conv2D(filters=expand//2, kernel_size=3, activation='relu', padding='same')(y)
    y3 = BatchNormalization(momentum=bnmomemtum)(y3)
    return concatenate([y1, y3])

def fire_module(squeeze, expand):
    return lambda x: fire(x, squeeze, expand)

x = Input(shape=[144, 144, 3])
y = BatchNormalization(center=True, scale=False)(x)
y = Activation('relu')(y)
# Note: this layer is applied to x, so the two results above are not used.
y = Conv2D(kernel_size=5, filters=16, padding='same', use_bias=True, activation='relu')(x)
y = BatchNormalization(momentum=bnmomemtum)(y)
y = fire_module(16, 32)(y)
y = MaxPooling2D(pool_size=2)(y)
To be a little more specific, why not have this:
# why not this?
def fire(x, squeeze, expand):
    y = Conv2D(filters=squeeze, kernel_size=1, activation='relu', padding='same')(x)
    y = BatchNormalization(momentum=bnmomemtum)(y)
    y = Conv2D(filters=expand//2, kernel_size=1, activation='relu', padding='same')(y)
    y = BatchNormalization(momentum=bnmomemtum)(y)
    y = Conv2D(filters=expand//2, kernel_size=3, activation='relu', padding='same')(y)
    y = BatchNormalization(momentum=bnmomemtum)(y)
    return y
Quoting @parsethis, who explained concatenation in this Stack Overflow question, this is what happens when a is concatenated with b (the results are joined together):
  a        b          c
a b c    g h i    a b c g h i
d e f    j k l    d e f j k l
The documentation says that it simply returns a tensor containing the concatenation of all inputs, provided they have the same shape except along the concatenation axis (for example, the same height and width when joining along the channel axis).
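As a rough illustration of the a/b/c picture, here is a minimal pure-Python sketch that joins two small grids along their last axis. The helper name concat_last_axis is made up for this example and plain lists stand in for tensors; it is not the Keras implementation, only the same idea on a toy scale.

```python
def concat_last_axis(a, b):
    """Join two 2-D grids (lists of rows) side by side, like axis=-1."""
    # Every axis except the one being joined must match: here, the row count.
    if len(a) != len(b):
        raise ValueError(f"row counts differ: {len(a)} vs {len(b)}")
    return [row_a + row_b for row_a, row_b in zip(a, b)]


a = [list("abc"), list("def")]
b = [list("ghi"), list("jkl")]
c = concat_last_axis(a, b)

for row in c:
    print(" ".join(row))
# a b c g h i
# d e f j k l
```

Each row of c is the row of a followed by the matching row of b, so the joined axis grows from 3 to 6 entries while the number of rows stays at 2.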
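What happens in your case looks like this:

        y
       / \
     y1   y3
       \ /
   concatenate

y1 and y3 are both computed from y in parallel, and their outputs are then stacked along the channel axis.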
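To make the difference from the sequential version in the question concrete, here is a small sketch that only tracks output shapes, with no Keras needed. The helper names (conv_same, concat_channels, fire_shape, sequential_shape) are made up for this example, and it assumes stride-1 convolutions with padding='same' as in the question, plus the fact that BatchNormalization leaves the shape unchanged.

```python
def conv_same(shape, filters):
    """Shape after a stride-1 Conv2D with padding='same': only channels change."""
    height, width, _ = shape
    return (height, width, filters)


def concat_channels(shapes):
    """Shape after concatenating along the last (channel) axis."""
    height, width, _ = shapes[0]
    # All inputs must agree on every axis except the channel axis.
    if any(s[:2] != (height, width) for s in shapes):
        raise ValueError("spatial dimensions must match")
    return (height, width, sum(s[2] for s in shapes))


def fire_shape(shape, squeeze, expand):
    y = conv_same(shape, squeeze)    # 1x1 squeeze
    y1 = conv_same(y, expand // 2)   # 1x1 branch, reads y
    y3 = conv_same(y, expand // 2)   # 3x3 branch, also reads y
    return concat_channels([y1, y3])


def sequential_shape(shape, squeeze, expand):
    y = conv_same(shape, squeeze)    # 1x1 squeeze
    y = conv_same(y, expand // 2)    # 1x1
    y = conv_same(y, expand // 2)    # 3x3 reads the 1x1 output, not the squeeze output
    return y


inp = (144, 144, 16)  # shape after the first 16-filter convolution in the question
print(fire_shape(inp, 16, 32))        # (144, 144, 32)
print(sequential_shape(inp, 16, 32))  # (144, 144, 16)
```

With concatenate, the next layer receives both the 1x1 and the 3x3 feature maps (32 channels in total). In the sequential version the 3x3 convolution only sees the output of the 1x1 convolution, and the block ends with 16 channels.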
I hope that was clear enough.