deep-learning, torch

Build network with shortcut using torch


I have a network with two inputs, X and Y.

X is concatenated with Y and passed through the network to get result1. At the same time, X is concatenated with result1 as a shortcut.

It's easy if there is only one input.

branch = nn.Sequential()
branch:add(....) --some layers
net = nn.Sequential()
net:add(nn.ConcatTable():add(nn.Identity()):add(branch))
net:add(...) -- combine the input with the branch output, e.g. nn.JoinTable
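
For example, with one input the shortcut can be completed like this (the nn.Linear/nn.ReLU branch, the sizes, and the final nn.JoinTable are just illustrative choices):

require 'nn'

nx, nh = 10, 20                       -- illustrative sizes
branch = nn.Sequential()
branch:add(nn.Linear(nx, nh))         -- some layers
branch:add(nn.ReLU())
net = nn.Sequential()
net:add(nn.ConcatTable():add(nn.Identity()):add(branch))
net:add(nn.JoinTable(1))              -- concatenate X with branch(X)
output = net:forward(torch.randn(nx)) -- tensor of size nx + nh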

But when it comes to two inputs I don't know how to do it. Also, nngraph is not allowed. Does anyone know how to do this?


Solution

  • You can use the table modules; have a look at this page: https://github.com/torch/nn/blob/master/doc/table.md

    net = nn.Sequential()
    triple = nn.ParallelTable()
    duplicate = nn.ConcatTable()
    duplicate:add(nn.Identity())
    duplicate:add(nn.Identity())
    triple:add(duplicate)
    triple:add(nn.Identity())
    net:add(triple)
    net:add(nn.FlattenTable())
    -- at this point the network transforms {X,Y} into {X,X,Y}
    separate = nn.ConcatTable()
    separate:add(nn.SelectTable(1))
    separate:add(nn.NarrowTable(2,2))
    net:add(separate)
    -- now you get {X,{X,Y}}
    parallel_XY = nn.ParallelTable()
    parallel_XY:add(nn.Identity()) -- preserves X
    parallel_XY:add(...) -- whatever you want to do from {X,Y}
    net:add(parallel_XY)
    -- now you have {X, result1}; finish the shortcut however you like,
    -- e.g. nn.JoinTable(1) to concatenate them
    net:add(...) -- whatever you want to do from {X, result1}
    
    output = net:forward({X,Y})
    

    The idea is to start with {X, Y}, duplicate X, and then do your operations. This is admittedly a bit convoluted; nngraph exists precisely to make this kind of wiring easier.
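
    For illustration, here is one way the placeholders could be filled in. The nn.JoinTable/nn.Linear body, the tensor sizes, and the final nn.JoinTable shortcut are assumptions made for this sketch, not the only option:

    require 'nn'

    -- illustrative sizes: X has nx elements, Y has ny, result1 has nh
    nx, ny, nh = 10, 5, 20

    -- the part applied to {X, Y}: concatenate and run a small MLP (an assumption)
    body = nn.Sequential()
    body:add(nn.JoinTable(1))            -- {X, Y} -> [X; Y]
    body:add(nn.Linear(nx + ny, nh))
    body:add(nn.ReLU())

    net = nn.Sequential()
    triple = nn.ParallelTable()
    duplicate = nn.ConcatTable()
    duplicate:add(nn.Identity())
    duplicate:add(nn.Identity())
    triple:add(duplicate)                -- X -> {X, X}
    triple:add(nn.Identity())            -- Y -> Y
    net:add(triple)
    net:add(nn.FlattenTable())           -- {{X, X}, Y} -> {X, X, Y}
    separate = nn.ConcatTable()
    separate:add(nn.SelectTable(1))
    separate:add(nn.NarrowTable(2, 2))
    net:add(separate)                    -- {X, X, Y} -> {X, {X, Y}}
    parallel_XY = nn.ParallelTable()
    parallel_XY:add(nn.Identity())       -- preserves X
    parallel_XY:add(body)                -- {X, Y} -> result1
    net:add(parallel_XY)                 -- now {X, result1}
    net:add(nn.JoinTable(1))             -- shortcut: concatenate X with result1

    X, Y = torch.randn(nx), torch.randn(ny)
    output = net:forward({X, Y})         -- tensor of size nx + nh

    If X and result1 had the same size, the final nn.JoinTable(1) could be replaced with nn.CAddTable() for an additive (residual-style) shortcut instead of a concatenation.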