Add extra output to existing Chainer network


Let's say I create a simple fully connected network:

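import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import Sequential

# Example sizes, chosen only for illustration
n_in, n_hidden, n_out = 10, 32, 3

model = Sequential(
    L.Linear(n_in, n_hidden),
    F.relu,
    L.Linear(n_hidden, n_hidden),
    F.relu,
    L.Linear(n_hidden, n_out)
)

# Compute the forward pass on a dummy float32 batch of 4 samples
x = np.random.rand(4, n_in).astype(np.float32)
y = model(x)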

I want to train this model with n_out outputs and then, once it is trained, add extra outputs before fine-tuning the network.

I have found ways to remove the last layer and train a new one in its place, but that is not what I want: the weights of the existing outputs should be kept, and only the weights of the new outputs should be initialized randomly.


Solution

  • You could add another linear layer, L.Linear(n_hidden, n_extra_out), alongside the existing ones instead of removing anything, where n_extra_out is the number of additional outputs. Take the output of the last F.relu and pass it to both the pretrained final linear layer and the new layer, then join the two results with F.concat. To get at that intermediate output, it is probably easier to replace the Sequential object with a chainer.Chain subclass, similar to this example: https://github.com/chainer/chainer/blob/master/examples/mnist/train_mnist.py#L16. The pretrained weights stay untouched, and the new layer's weights get Chainer's default random initialization.