I'm trying to implement a 3D convnet followed by an LSTM layer for sequence generation, using 3D images as input, in Keras with the TensorFlow backend.
I would like to start training with the weights of an existing pre-trained model in order to avoid the usual issues with random initialization.
To start with a basic example, I took VGG-16 and implemented a "3D" version of the network (without the FC layers):
from keras.layers import Input, Conv3D, MaxPooling3D

img_input = Input((100, 80, 80, 3))

x = Conv3D(64, (3, 3, 3), activation='relu', padding='same', name='block1_conv1')(img_input)
x = Conv3D(64, (3, 3, 3), activation='relu', padding='same', name='block1_conv2')(x)
x = MaxPooling3D((1, 2, 2), strides=(1, 2, 2), name='block1_pool')(x)

x = Conv3D(128, (3, 3, 3), activation='relu', padding='same', name='block2_conv1')(x)
x = Conv3D(128, (3, 3, 3), activation='relu', padding='same', name='block2_conv2')(x)
x = MaxPooling3D((1, 2, 2), strides=(1, 2, 2), name='block2_pool')(x)

x = Conv3D(256, (3, 3, 3), activation='relu', padding='same', name='block3_conv1')(x)
x = Conv3D(256, (3, 3, 3), activation='relu', padding='same', name='block3_conv2')(x)
x = Conv3D(256, (3, 3, 3), activation='relu', padding='same', name='block3_conv3')(x)
x = MaxPooling3D((1, 2, 2), strides=(1, 2, 2), name='block3_pool')(x)

x = Conv3D(512, (3, 3, 3), activation='relu', padding='same', name='block4_conv1')(x)
x = Conv3D(512, (3, 3, 3), activation='relu', padding='same', name='block4_conv2')(x)
x = Conv3D(512, (3, 3, 3), activation='relu', padding='same', name='block4_conv3')(x)
x = MaxPooling3D((1, 2, 2), strides=(1, 2, 2), name='block4_pool')(x)

x = Conv3D(512, (3, 3, 3), activation='relu', padding='same', name='block5_conv1')(x)
x = Conv3D(512, (3, 3, 3), activation='relu', padding='same', name='block5_conv2')(x)
x = Conv3D(512, (3, 3, 3), activation='relu', padding='same', name='block5_conv3')(x)
x = MaxPooling3D((1, 2, 2), strides=(1, 2, 2), name='block5_pool')(x)
So I would like to know how I can load the weights of the pretrained VGG-16 into each one of the 100 slices (my 3D images are composed of 100 RGB slices of 80x80).
Any advice you can give me would be useful.
Thanks
This depends on what you are looking to do in your application. If you just want to process the 3D image slice by slice, then a TimeDistributed VGG16 network (Conv2D instead of Conv3D) would be the way to go.
The model then becomes something like this for each layer you defined above:
from keras.layers import Input, Conv2D, MaxPooling2D, TimeDistributed

img_input = Input((100, 80, 80, 3))

x = TimeDistributed(Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1', trainable=False))(img_input)
x = TimeDistributed(Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2', trainable=False))(x)
x = TimeDistributed(MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool', trainable=False))(x)
...
...
Note that I have included the option trainable=False here. This is useful if you only want to train the deeper layers and freeze the lower layers with the well-trained weights of VGG.
To load the VGG weights for the model, you can then use the load_weights function of Keras.
model.load_weights(filepath, by_name=True)
If you give the layers that you do not want to train the same names as the corresponding layers in VGG16, then those layers can simply be loaded by name here.
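One hedged caveat (not from the original answer): depending on your Keras version, by_name matching may not reach layers that are wrapped in TimeDistributed, because the wrapper carries its own name. In that case, an explicit per-layer copy from keras.applications.VGG16 is a more robust alternative. A minimal sketch, assuming the inner layers keep the VGG16 layer names and the graph above has been built into a Model called model:

from keras.applications import VGG16
from keras.layers import TimeDistributed

# Copy ImageNet weights into every wrapped layer whose name matches a VGG16 layer.
vgg = VGG16(weights='imagenet', include_top=False)

for layer in model.layers:
    if isinstance(layer, TimeDistributed):
        inner = layer.layer                      # the wrapped Conv2D / MaxPooling2D layer
        try:
            source = vgg.get_layer(inner.name)   # e.g. 'block1_conv1'
        except ValueError:
            continue                             # no counterpart in VGG16
        if source.get_weights():                 # pooling layers have no weights to copy
            inner.set_weights(source.get_weights())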
However, spatio-temporal feature learning is something that can potentially be done much better by using 3D ConvNets. If this is the basis of your application, then you cannot directly import VGG16 weights into a Conv3D model, because the number of parameters in each layer increases: the filters go from 3x3 to 3x3x3, for example.
You could still load the weights into the model layer by layer, by deciding which 3x3 slice of the 3x3x3 kernel is most suitable for initialization with VGG16 weights. The set_weights() function takes a list of numpy arrays as input (the kernel weights and the bias, respectively). You can extract each layer's weights from VGG16, construct an equivalent Conv3D weight array from them, and feed it to your Conv3D model.
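For example, here is a minimal sketch of one possible "inflation" scheme (an assumption, not the only option): replicate each 3x3 VGG16 kernel along the new depth axis and divide by the depth, so that an input that is constant along the depth dimension produces roughly the same activations as the original 2D filter. model_3d is a hypothetical name for the Conv3D graph from your question wrapped in a Model:

import numpy as np
from keras.applications import VGG16
from keras.models import Model

vgg = VGG16(weights='imagenet', include_top=False)
model_3d = Model(img_input, x)   # img_input / x from the Conv3D snippet in the question

def inflate_conv_weights(conv2d_layer, depth=3):
    # kernel: (3, 3, in_ch, out_ch) -> (depth, 3, 3, in_ch, out_ch), rescaled by depth
    kernel, bias = conv2d_layer.get_weights()
    kernel3d = np.repeat(kernel[np.newaxis, ...], depth, axis=0) / depth
    return [kernel3d, bias]

# Copy every VGG16 conv layer into its same-named Conv3D counterpart.
for vgg_layer in vgg.layers:
    if vgg_layer.get_weights():   # skips the input and pooling layers, which have no weights
        model_3d.get_layer(vgg_layer.name).set_weights(inflate_conv_weights(vgg_layer))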
But I would encourage you to look at existing literature and models for processing 3D images to see if those can give you a better initialization via transfer learning.
For example, C3D is one such popular model. ShapeNet and Pascal3D are popular 3D datasets.
This discussion on how to process video data might also give you better insights on how to proceed.
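Finally, since your goal is a convnet followed by an LSTM for sequence generation: here is a minimal sketch of one common way to put an LSTM on top of the TimeDistributed backbone above. The layer sizes and num_classes are illustrative assumptions, not part of the original answer:

from keras.layers import TimeDistributed, Flatten, LSTM, Dense
from keras.models import Model

num_classes = 10   # hypothetical output size

# Collapse each slice's feature map to a vector, then run an LSTM over the
# 100-slice sequence and predict one output per step.
features = TimeDistributed(Flatten())(x)                 # (batch, 100, feature_dim)
seq = LSTM(256, return_sequences=True)(features)         # 256 units is an arbitrary choice
outputs = TimeDistributed(Dense(num_classes, activation='softmax'))(seq)

model = Model(img_input, outputs)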