
Lua commands: what do they do?


I am unfamiliar with Lua, but the author of the article used it.

Can you help me understand what these two lines do?

What does replicate(x, batch_size) do?

What does x = x:resize(x:size(1), 1):expand(x:size(1), batch_size) do?

The original source code can be found here: https://github.com/wojzaremba/lstm/blob/master/data.lua


Solution

  • This basically boils down to simple maths and looking up a few functions in the Torch manual.

    replicate(x, batch_size), as defined in https://github.com/wojzaremba/lstm/blob/master/data.lua:

    -- Stacks replicated, shifted versions of x_inp
    -- into a single matrix of size x_inp:size(1) x batch_size.
    local function replicate(x_inp, batch_size)
       local s = x_inp:size(1)
       local x = torch.zeros(torch.floor(s / batch_size), batch_size)
       for i = 1, batch_size do
         local start = torch.round((i - 1) * s / batch_size) + 1
         local finish = start + x:size(1) - 1
         x:sub(1, x:size(1), i, i):copy(x_inp:sub(start, finish))
       end
       return x
    end
    

    This code is using the Torch framework.

    x_inp:size(1) returns the size of dimension 1 of the Torch tensor (a potentially multi-dimensional matrix) x_inp.

    See https://cornebise.com/torch-doc-template/tensor.html#toc_18

    So x_inp:size(1) gives you the number of rows in x_inp, and x_inp:size(2) would give you the number of columns.

    local x = torch.zeros(torch.floor(s / batch_size), batch_size)
    

    creates a new two-dimensional tensor filled with zeros and binds a local reference to it, named x. The number of rows is calculated from s (x_inp's row count) and batch_size. For your example input it turns out to be floor(11/2) = floor(5.5) = 5.

    The number of columns in your example is 2 as batch_size is 2.


    So, simply put, x is the 5x2 matrix

    0 0
    0 0
    0 0
    0 0
    0 0
    

    The following lines copy x_inp's contents into x.

    for i = 1, batch_size do
      local start = torch.round((i - 1) * s / batch_size) + 1
      local finish = start + x:size(1) - 1
      x:sub(1, x:size(1), i, i):copy(x_inp:sub(start, finish))
    end
    

    In the first run, start evaluates to 1 and finish to 5, since x:size(1) is of course the number of rows of x, which is 5 (1 + 5 - 1 = 5). In the second run, start evaluates to 6 and finish to 10.

    So the first 5 rows of x_inp (your first batch) are copied into the first column of x, and the second batch is copied into the second column of x.

    x:sub(1, x:size(1), i, i) is the sub-tensor of x: rows 1 to 5, columns 1 to 1 in the first run, and rows 1 to 5, columns 2 to 2 in the second run (in your example). So it is nothing more than the first and second column of x.

    See https://cornebise.com/torch-doc-template/tensor.html#toc_42

    :copy(x_inp:sub(start, finish))
    

    copies the elements from x_inp into the columns of x.

    So to summarize: you take an input tensor and split it into batches, which are stored in a tensor with one column per batch.

    So with x_inp

    0
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    

    and batch_size = 2

    x is

    0 5
    1 6
    2 7
    3 8
    4 9
    
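    For readers without a Torch setup, the same batching logic can be sketched in plain Python. This is an illustrative re-implementation, not the author's code: it mirrors the Lua function with 0-based indexing, and uses a 10-element input so that no rounding ties arise in the start computation.

    ```python
    def replicate(x_inp, batch_size):
        """Stack shifted slices of x_inp side by side, one column per batch.
        Mirrors the Lua replicate() above, with 0-based Python indexing."""
        s = len(x_inp)
        rows = s // batch_size                 # torch.floor(s / batch_size)
        x = [[0] * batch_size for _ in range(rows)]
        for i in range(batch_size):
            start = round(i * s / batch_size)  # 0-based equivalent of the Lua start
            for r in range(rows):
                x[r][i] = x_inp[start + r]
        return x

    print(replicate(list(range(10)), 2))
    # [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    ```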

    Further:

    local function testdataset(batch_size)
      local x = load_data(ptb_path .. "ptb.test.txt")
      x = x:resize(x:size(1), 1):expand(x:size(1), batch_size)
      return x
    end
    

    is another function, one that loads some data from a file. Its x is not related to the x above; they just happen to share a name.

    Let's use a simple example:

    x being

    1
    2
    3
    4
    

    and batch_size = 4

    x = x:resize(x:size(1), 1):expand(x:size(1), batch_size)
    

    First, x is resized to 4x1; see https://cornebise.com/torch-doc-template/tensor.html#toc_36

    Then it is expanded to 4x4 by repeating its single column 4 times. Note that expand does not allocate new memory; all four columns share the same underlying storage.

    Resulting in x being the tensor

    1 1 1 1
    2 2 2 2
    3 3 3 3
    4 4 4 4
    

    See https://cornebise.com/torch-doc-template/tensor.html#toc_49
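    The resize-then-expand step can likewise be mimicked in plain Python. This is only a sketch of the observable effect (Torch's expand shares storage rather than copying, which a list of lists does not capture):

    ```python
    def resize_expand(x, batch_size):
        """Treat the 1-D list x as an n x 1 column and repeat that column
        batch_size times, imitating x:resize(n, 1):expand(n, batch_size)."""
        return [[v] * batch_size for v in x]

    print(resize_expand([1, 2, 3, 4], 4))
    # [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]
    ```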