
torch/nn - Joining arrays of Tensors element-wise


The subject of this question is joining tensors for neural networks with the torch/nn and torch/nngraph libraries for Lua. I started coding in Lua a few weeks ago, so my experience is minimal. In the text below, I refer to Lua tables as arrays.

Context

I am working with a recurrent neural network for speech recognition. At some point in the network there are N arrays, each holding M Tensors.

a = {a1, a2, ..., aM},
b = {b1, b2, ..., bM}, 
... N times

Where ai and bi are tensors and {} represents an array.

What I need is to join all those arrays element-wise, so that the output is an array of M Tensors where output[i] is the result of joining the i-th Tensors of all N arrays along the second dimension.

output = {z1, z2, ..., zM}

Example

| | is used to delimit Tensors

x = {|1 1|, |2 2|}
     |1 1|  |2 2|
     Tensors of size 2x2

y = {|3 3 3|, |4 4 4|}
     |3 3 3|  |4 4 4|
     Tensors of size 2x3
        |
        | Join{x,y}
        \/
z = {|1 1 3 3 3|, |2 2 4 4 4|}
     |1 1 3 3 3|  |2 2 4 4 4|
     Tensors of size 2x5

So the first 2x2 Tensor of x was joined with the first 2x3 Tensor of y along the second dimension, and likewise for the second Tensor of each array, giving z, an array of 2x5 Tensors.
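For reference, the operation in this example can be written outside nn as a plain loop over the element index using `torch.cat`, which accepts a table of tensors plus a dimension. This is a minimal sketch: `joinElementwise` is a hypothetical helper name, and because it bypasses nn there is no backward pass, so it is only useful for checking what the network should produce.

```lua
require 'torch'

-- Hypothetical helper (not part of torch): joins the i-th tensor of every
-- array along dimension `dim`. Plain torch, so no gradients are computed.
local function joinElementwise(arrays, dim)
   local out = {}
   for i = 1, #arrays[1] do          -- M: number of tensors per array
      local ith = {}
      for n = 1, #arrays do          -- N: number of arrays
         ith[n] = arrays[n][i]
      end
      out[i] = torch.cat(ith, dim)   -- concatenate the i-th tensors
   end
   return out
end

local x = {torch.Tensor{{1,1},{1,1}}, torch.Tensor{{2,2},{2,2}}}
local y = {torch.Tensor{{3,3,3},{3,3,3}}, torch.Tensor{{4,4,4},{4,4,4}}}
local z = joinElementwise({x, y}, 2)
print(z[1])  -- 2x5, both rows 1 1 3 3 3
print(z[2])  -- 2x5, both rows 2 2 4 4 4
```

The loop reads M from the first array on every call, so a varying M needs no special handling here.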

Problem

Now, this is a basic concatenation, but I can't find a module in the torch/nn library that does it. I could of course write my own module, but if an existing module already does this I would rather use it.

The only existing module I know of that joins a table is (obviously) JoinTable. It takes one array of Tensors and joins them together, whereas I want to join several arrays of Tensors element-wise.

Also, as we feed input to our network, the number of Tensors in the N arrays varies, so M from the context above is not constant.

Idea

What I thought I could do in order to use JoinTable is convert my arrays into Tensors and then apply JoinTable to the N converted Tensors. But then I would need one module to do that conversion and another one to convert the result back to an array in order to feed it to the next layers of the network.
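This idea can be sketched with stock nn containers, under two assumptions: your nn version ships `nn.MapTable` and `nn.Unsqueeze`, and all Tensors inside one array share the same size (as in the example, where every Tensor of x is 2x2). Each array is stacked into an M x h x w Tensor, the N stacked Tensors are joined along dimension 3 (the original join dimension shifted by one), and `nn.SplitTable(1)` turns the result back into a table. `tableToTensor` is a hypothetical helper name.

```lua
require 'nn'

-- Assumes nn.MapTable and nn.Unsqueeze exist in the installed nn version.
-- Hypothetical helper: table of M (h x w) tensors -> one (M x h x w) tensor.
local function tableToTensor()
   return nn.Sequential()
      :add(nn.MapTable(nn.Unsqueeze(1)))  -- each h x w -> 1 x h x w
      :add(nn.JoinTable(1))               -- M of (1 x h x w) -> M x h x w
end

local net = nn.Sequential()
   :add(nn.ParallelTable()
      :add(tableToTensor())               -- stacks x
      :add(tableToTensor()))              -- stacks y
   :add(nn.JoinTable(3))    -- {M x 2 x 2, M x 2 x 3} -> M x 2 x 5
   :add(nn.SplitTable(1))   -- M x 2 x 5 -> table of M (2 x 5) tensors

local x = {torch.Tensor{{1,1},{1,1}}, torch.Tensor{{2,2},{2,2}}}
local y = {torch.Tensor{{3,3,3},{3,3,3}}, torch.Tensor{{4,4,4},{4,4,4}}}
local out = net:forward({x, y})
print(out[1])
print(out[2])
```

Since MapTable and SplitTable adapt to the table length at run time, a varying M should not require rebuilding the network. The equal-size assumption is the main limitation of this route.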

Last resort

Write a new module that iterates over all given arrays and concatenates them element-wise. It is certainly doable, but the whole purpose of this post is to avoid writing smelly modules. It seems odd to me that such a module doesn't already exist.
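A minimal sketch of such a module, for comparison. The class name `nn.ElementwiseJoin` is hypothetical (it is not part of torch/nn). It keeps one `nn.JoinTable` per element index so the backward pass reuses JoinTable's own gradient code, and it creates those lazily, so a varying M is handled.

```lua
require 'nn'

-- Hypothetical module, not part of torch/nn.
local ElementwiseJoin, parent = torch.class('nn.ElementwiseJoin', 'nn.Module')

function ElementwiseJoin:__init(dimension)
   parent.__init(self)
   self.dimension = dimension
   self.joins = {}      -- one nn.JoinTable per element index i
   self.output = {}
   self.gradInput = {}
end

-- Collects the i-th tensor of each of the N input arrays.
local function ith(input, i)
   local t = {}
   for n = 1, #input do t[n] = input[n][i] end
   return t
end

function ElementwiseJoin:updateOutput(input)
   self.output = {}
   for i = 1, #input[1] do
      self.joins[i] = self.joins[i] or nn.JoinTable(self.dimension)
      self.output[i] = self.joins[i]:updateOutput(ith(input, i))
   end
   return self.output
end

function ElementwiseJoin:updateGradInput(input, gradOutput)
   self.gradInput = {}
   for n = 1, #input do self.gradInput[n] = {} end
   for i = 1, #input[1] do
      -- JoinTable splits gradOutput[i] back into one slice per input tensor
      local g = self.joins[i]:updateGradInput(ith(input, i), gradOutput[i])
      for n = 1, #input do self.gradInput[n][i] = g[n] end
   end
   return self.gradInput
end

local x = {torch.Tensor{{1,1},{1,1}}, torch.Tensor{{2,2},{2,2}}}
local y = {torch.Tensor{{3,3,3},{3,3,3}}, torch.Tensor{{4,4,4},{4,4,4}}}
local join = nn.ElementwiseJoin(2)
local z = join:forward({x, y})
print(z[1])
print(z[2])
```

The module has no parameters, so only `updateOutput` and `updateGradInput` are overridden; `nn.Module`'s default `accGradParameters` does nothing.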


Conclusion

In the end I went with the Last resort: I wrote a new module that iterates over all given arrays and concatenates them element-wise.

However, the answer by @fmguler achieves the same thing without writing a new module.


Solution

  • You can do it with nn.SelectTable and nn.JoinTable like this:

    require 'nn'
    
    -- two arrays (Lua tables) of M = 2 Tensors each
    x = {torch.Tensor{{1,1},{1,1}}, torch.Tensor{{2,2},{2,2}}}
    y = {torch.Tensor{{3,3,3},{3,3,3}}, torch.Tensor{{4,4,4},{4,4,4}}}
    
    -- res[i] = i-th Tensor of x joined with i-th Tensor of y along dimension 2
    res = {}
    res[1] = nn.JoinTable(2):forward({nn.SelectTable(1):forward(x),nn.SelectTable(1):forward(y)})
    res[2] = nn.JoinTable(2):forward({nn.SelectTable(2):forward(x),nn.SelectTable(2):forward(y)})
    
    print(res[1])
    print(res[2])
    

    If you want this done in a module, wrap it in nngraph:

    require 'nngraph'
    
    x = {torch.Tensor{{1,1},{1,1}}, torch.Tensor{{2,2},{2,2}}}
    y = {torch.Tensor{{3,3,3},{3,3,3}}, torch.Tensor{{4,4,4},{4,4,4}}}
    
    xi = nn.Identity()()
    yi = nn.Identity()()
    res = {}
    -- you can loop over the element index i here to build res[1..M]
    res[1] = nn.JoinTable(2)({nn.SelectTable(1)(xi),nn.SelectTable(1)(yi)})
    res[2] = nn.JoinTable(2)({nn.SelectTable(2)(xi),nn.SelectTable(2)(yi)})
    module = nn.gModule({xi,yi},res)
    
    --test like this
    result = module:forward({x,y})
    print(result)
    print(result[1])
    print(result[2])
    
    --gives the result
    th> print(result)
    {
      1 : DoubleTensor - size: 2x5
      2 : DoubleTensor - size: 2x5
    }
    
    th> print(result[1])
     1  1  3  3  3
     1  1  3  3  3
    [torch.DoubleTensor of size 2x5]
    
    th> print(result[2])
     2  2  4  4  4
     2  2  4  4  4
    [torch.DoubleTensor of size 2x5]
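For an arbitrary number of arrays, the hard-coded lines in the graph version can be generated in a loop, as its comment suggests. A sketch, with the hypothetical helper `buildElementwiseJoin`: the graph is fixed once built, so when M changes between inputs you would build (and cache, e.g. in a table keyed by M) one graph per M. It assumes N ≥ 2 and M ≥ 2, since nngraph handles single-input and single-output graphs differently from table-valued ones.

```lua
require 'nngraph'

-- Hypothetical helper: a gModule that joins N input arrays of M tensors
-- each, element-wise along dimension `dim`. Assumes N >= 2 and M >= 2.
local function buildElementwiseJoin(N, M, dim)
   local inputs = {}
   for n = 1, N do inputs[n] = nn.Identity()() end   -- one node per array
   local outputs = {}
   for i = 1, M do
      local parts = {}
      for n = 1, N do parts[n] = nn.SelectTable(i)(inputs[n]) end
      outputs[i] = nn.JoinTable(dim)(parts)          -- z_i
   end
   return nn.gModule(inputs, outputs)
end

local x = {torch.Tensor{{1,1},{1,1}}, torch.Tensor{{2,2},{2,2}}}
local y = {torch.Tensor{{3,3,3},{3,3,3}}, torch.Tensor{{4,4,4},{4,4,4}}}
local joiner = buildElementwiseJoin(2, 2, 2)
local result = joiner:forward({x, y})
print(result[1])
print(result[2])
```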