Tags: python, theano, keras

Zero-padding for ResNet shortcut connections when the number of channels increases


I would like to implement ResNet in Keras with shortcut connections that pad with zero entries when the feature/channel dimensions mismatch, as described in the original paper:

When the dimensions increase (dotted line shortcuts in Fig. 3), we consider two options: (A) The shortcut still performs identity mapping, with extra zero entries padded for increasing dimensions ... http://arxiv.org/pdf/1512.03385v1.pdf
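Written out for a single channels-first feature map, option A amounts to the formula below. This is a sketch, not the paper's own notation: it assumes zero-based indices and the stride-2 subsampling the paper applies when a shortcut crosses feature maps of two sizes (the solution code further down gets the same effect with a 1x1 `MaxPooling2D` using `strides=(2,2)`). Here s is the shortcut output and F(x) the residual branch.

```latex
\begin{aligned}
&x \in \mathbb{R}^{C \times H \times W}, \quad
 s \in \mathbb{R}^{C' \times \lceil H/2 \rceil \times \lceil W/2 \rceil}, \quad C' > C, \\
&s_{c,i,j} =
\begin{cases}
x_{c,\,2i,\,2j}, & 0 \le c < C, \\
0, & C \le c < C',
\end{cases} \\
&y = \mathcal{F}(x) + s.
\end{aligned}
```

In the usual ResNet stage transition C' = 2C, so exactly half of the shortcut's channels are zeros.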

However, I wasn't able to implement it, and I can't find an answer on the web or in the source code. All the implementations I found use the 1x1 convolution trick for shortcut connections when the dimensions mismatch.

The layer I would like to implement would basically concatenate the input tensor with an all-zeros tensor to compensate for the dimension mismatch.
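The layer described above can be illustrated outside Keras with plain NumPy. This is only a sketch of the tensor arithmetic, not Keras code: it assumes channels-first input of shape (N, C, H, W), stride-2 subsampling, and an output with exactly twice as many channels; `option_a_shortcut` is a hypothetical name.

```python
import numpy as np


def option_a_shortcut(x):
    """Sketch of a ResNet option-A shortcut for channels-first input (N, C, H, W).

    Subsamples rows and columns with stride 2, then doubles the channel
    count by concatenating an all-zeros tensor along the channel axis.
    """
    # Stride-2 subsampling: keep every other row and column.
    x = x[:, :, ::2, ::2]
    # Zeros with exactly the same shape, so no batch size has to be known up front.
    zeros = np.zeros_like(x)
    # Concatenate along the channel axis (axis=1 in channels-first layout).
    return np.concatenate([x, zeros], axis=1)


x = np.arange(2 * 3 * 4 * 4, dtype=np.float32).reshape(2, 3, 4, 4)
y = option_a_shortcut(x)
print(y.shape)  # (2, 6, 2, 2)
```

Using `zeros_like` rather than a zeros tensor of a hard-coded shape is what keeps the padding independent of the batch size.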

The idea would be something like this, but I could not get it working:

def zero_pad(x, shape):
    return K.concatenate([x, K.zeros(shape)], axis=1)
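One likely reason this attempt fails is that `K.zeros(shape)` creates a tensor with a fixed, fully specified shape, including the batch size, which is normally unknown when the model is built. The NumPy sketch below (illustration only; in a Keras model the same logic would have to use backend ops inside a `Lambda` layer) derives the number of zero channels from the input itself and also handles increases other than doubling. It assumes channels-first input; `pad_channels` is a hypothetical helper name.

```python
import numpy as np


def pad_channels(x, out_channels):
    """Zero-pad a channels-first tensor (N, C, H, W) up to out_channels channels.

    The number of zero channels is computed from the input's own shape,
    so the batch size N never has to be supplied as a fixed shape.
    """
    extra = out_channels - x.shape[1]
    if extra < 0:
        raise ValueError("out_channels must be >= the input channel count")
    # Pad only the end of the channel axis; batch and spatial axes are untouched.
    return np.pad(x, ((0, 0), (0, extra), (0, 0), (0, 0)), mode="constant")


x = np.ones((5, 16, 8, 8), dtype=np.float32)
y = pad_channels(x, 32)
print(y.shape)  # (5, 32, 8, 8)
```

`mode="constant"` pads with zeros by default, so the original channels come first and the added ones are all zero.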

Does anyone have an idea how to implement such a layer?

Thanks a lot


Solution

  • The question was answered on GitHub: https://github.com/fchollet/keras/issues/2608

    It would be something like this:

    from keras.layers.convolutional import MaxPooling2D
    from keras.layers.core import Lambda
    from keras import backend as K
    
    
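    def zeropad(x):
        # Append an all-zeros tensor of the same shape along the channel axis
        # (axis=1, Theano "th" dim ordering), doubling the number of channels.
        y = K.zeros_like(x)
        return K.concatenate([x, y], axis=1)
    
    
    def zeropad_output_shape(input_shape):
        # Output shape for the Lambda layer: (batch, 2 * channels, rows, cols).
        shape = list(input_shape)
        assert len(shape) == 4
        shape[1] *= 2
        return tuple(shape)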
    
    
    def shortcut(input_layer, nb_filters, output_shape, zeros_upsample=True):
        # TODO: Figure out why zeros_upsample doesn't work in Theano
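        if zeros_upsample:
            # A 1x1 pooling window with stride 2 just subsamples rows and
            # columns, halving the spatial size of the shortcut.
            x = MaxPooling2D(pool_size=(1, 1),
                             strides=(2, 2),
                             border_mode='same')(input_layer)
            x = Lambda(zeropad, output_shape=zeropad_output_shape)(x)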
        else:
            # Options B, C in ResNet paper...