
Graph disconnected when layers added on top of ResNet network


I am trying to change the input shape of the ResNet50 network: I need inputs with more than 3 channels. The ResNet application works when you specify the input shape without loading the ImageNet weights, but I would like to use the ImageNet weights to avoid a long training phase.

I am aware that the ImageNet weights are for an input shape with three channels, but in theory, by cutting off the head of the network and adding a new input layer, this should work.

I tried to remove the head layers, but I get an error saying the number of input channels is different from 3:

ValueError: number of input channels does not match corresponding dimension of filter, 6 != 3

    from keras.layers import Input, ZeroPadding2D
    from keras.models import Model
    import keras

    model = keras.applications.resnet50.ResNet50(include_top=False,
                input_shape=(200, 200, 3), weights='imagenet')
    model.layers.pop(0)  # try to drop the original input layer
    model.layers.pop(0)  # ...and the original zero-padding layer
    model.layers.pop()   # drop the last layer
    X_input = Input((200, 200, 6), name='input_1')
    X = ZeroPadding2D((3, 3), name='conv1_pad')(X_input)
    model = Model(inputs=X_input, outputs=model(X))
    model.summary()

I think it should be possible to change the number of channels of the input shape while still using the ImageNet weights, but the method I tried seems to be wrong.


Solution

  • I'm not sure a Keras model supports list operations on its layers; popping layers doesn't seem to make the model forget its expected input size.

    You could initialize a new ResNet with your input shape and manually load the ImageNet weights into all layers except the first few, which expect 3 channels in their input tensor.

    Borrowing a few lines from keras.applications.resnet50, that would look something like this:

    import h5py
    import keras
    from keras_applications.resnet50 import WEIGHTS_PATH_NO_TOP
    
    # Build a ResNet50 with a 6-channel input and no pretrained weights.
    input_tensor = keras.Input((200, 200, 6))
    resnet = keras.applications.ResNet50(
        input_tensor=input_tensor, weights=None, include_top=False
    )
    
    # Download the no-top ImageNet weights file.
    weights_path = keras.utils.get_file(
        'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
        WEIGHTS_PATH_NO_TOP,
        cache_subdir='models',
        md5_hash='a268eb855778b3df3c7506639542a6af')
    
    # Copy pretrained weights by layer name, skipping the first three
    # layers (input, zero-padding, and the conv that expects 3 channels).
    with h5py.File(weights_path, 'r') as f:
        for layer in resnet.layers[3:]:
            if layer.name in f:
                layer.set_weights(f[layer.name].values())
    
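    If you also want to reuse the pretrained first-convolution weights instead of training that layer from scratch, one common trick (not part of the answer above, just an illustration) is to tile the 3-channel kernel across the extra channels and rescale so activations keep roughly the same magnitude. A numpy-only sketch, using a random array in place of the real pretrained kernel:

```python
import numpy as np

# Stand-in for the pretrained conv1 kernel, shape (7, 7, 3, 64):
# height x width x input channels x filters.
w3 = np.random.rand(7, 7, 3, 64).astype("float32")

# Tile the 3-channel kernel to cover 6 input channels, then rescale
# by 3/6 so the summed response stays roughly the same magnitude.
w6 = np.concatenate([w3, w3], axis=2) * (3 / 6)

print(w6.shape)  # (7, 7, 6, 64)
```

    You could then `set_weights` on the first conv layer with `w6` (plus its bias, if any) before fine-tuning.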

    With that said, the kind of transfer learning you are trying to do is not very common, and I'm really curious whether it works. Can you please update if it actually converged faster?