python, neural-network, chainer

Optimization Target must be a link


I have an autoencoder model of 4 linear layers written using chainer.Chain. Running the optimizer.setup line in the Trainer section gives me the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-9-a2aabc58d467> in <module>()
      8 
      9 optimizer = optimizers.AdaDelta()
---> 10 optimizer.setup(sda)
     11 
     12 train_iter = iterators.SerialIterator(train_data,batchsize)

/usr/local/lib/python3.6/dist-packages/chainer/optimizer.py in setup(self, link)
    415         """
    416         if not isinstance(link, link_module.Link):
--> 417             raise TypeError('optimization target must be a link')
    418         self.target = link
    419         self.t = 0

TypeError: optimization target must be a link

The link to class StackedAutoEncoder is as follows: StackAutoEncoder link

The link to class NNBase which is used to write class AutoEncoder is as follows: NNBase link

model = chainer.Chain(
    enc1=L.Linear(1764, 200),
    enc2=L.Linear(200, 30),
    dec2=L.Linear(30, 200),
    dec1=L.Linear(200, 1764)
)


sda = StackedAutoEncoder(model, gpu=0)
sda.set_order(('enc1', 'enc2'), ('dec2', 'dec1'))
sda.set_optimizer(Opt.AdaDelta)
sda.set_encode(encode)
sda.set_decode(decode)

from chainer import iterators, training, optimizers
from chainer import Link, Chain, ChainList

optimizer = optimizers.AdaDelta()
optimizer.setup(sda)

train_iter = iterators.SerialIterator(train_data,batchsize)
valid_iter = iterators.SerialIterator(test_data,batchsize)

updater = training.StandardUpdater(train_iter,optimizer)
trainer = training.Trainer(updater,(epoch,"epoch"),out="result")

from chainer.training import extensions
trainer.extend(extensions.Evaluator(valid_iter, sda, device=gpu))

A Chain is made of Links, so why does the optimizer not recognize sda, which is StackedAutoEncoder(model), as a link?


Solution

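  • StackedAutoEncoder inherits from NNBase, which inherits only from object, so neither class is a chainer.Chain (or any other chainer.Link). Passing a Chain to its constructor does not make the wrapper itself a Link, so the isinstance check in optimizer.setup fails.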

    You can refer to the official examples for how to define your own network. For example, the MNIST example defines MLP as follows:

    class MLP(chainer.Chain):
    
        def __init__(self, n_units, n_out):
            super(MLP, self).__init__()
            with self.init_scope():
                # the size of the inputs to each layer will be inferred
                self.l1 = L.Linear(None, n_units)  # n_in -> n_units
                self.l2 = L.Linear(None, n_units)  # n_units -> n_units
                self.l3 = L.Linear(None, n_out)  # n_units -> n_out
    
        def forward(self, x):
            h1 = F.relu(self.l1(x))
            h2 = F.relu(self.l2(h1))
            return self.l3(h2)