Tags: for-loop, lua, iteration, torch

Parallel iteration in Lua


I would like to loop through multiple tables in parallel in Lua. I could just do:

for i = 1, #table1 do
  pprint(table1[i])
  pprint(table2[i])
end

But I'd rather use something like Python's zip:

for elem1, elem2 in zip(table1, table2) do
  pprint(elem1)
  pprint(elem2)
end

Is there such a thing in standard Lua (or at least in whatever comes packaged with Torch)?


Solution

  • If you want something in Lua that's similar to a Python function, you should look at Penlight first. For this specific case there is the seq.zip function in the pl.seq module. Penlight appears to be installed together with Torch, but you can also get it via LuaRocks (which is itself bundled with at least one Torch distribution).

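    Anyway, the seq.zip function in Penlight only supports zipping two sequences. Here is a version that should behave more like Python's zip, i.e. it accepts more (or fewer) than two sequences. Each sequence is passed as a table holding the generic-for triple (iterator function, state, initial control value), e.g. {ipairs(t)}, optionally preceded by a number that selects which of the iterator's return values to yield (the default, -1, means the last one):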

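    local zip
    do
      local unpack = table.unpack or unpack  -- Lua 5.1 vs. 5.2+

      -- Returns the new control value plus the selected value produced by
      -- the iterator (select with a negative i counts from the end), or
      -- nothing at all if the iterator is exhausted.
      local function zip_select( i, var1, ... )
        if var1 then
          return var1, select( i, var1, ... )
        end
      end
    
      -- Each argument is a table { [i,] f, s, var } wrapping a generic-for
      -- triple such as {ipairs(t)}, optionally preceded by a selector number.
      function zip( ... )
        local iterators = { n=select( '#', ... ), ... }
        for i = 1, iterators.n do
          assert( type( iterators[i] ) == "table",
                  "you have to wrap the iterators in a table" )
          -- default selector: yield the last value (e.g. the element, not
          -- the index, for ipairs)
          if type( iterators[i][1] ) ~= "number" then
            table.insert( iterators[i], 1, -1 )
          end
        end
        return function()
          local results = {}
          for i = 1, iterators.n do
            local it = iterators[i]
            -- advance iterator i: it[2] is f, it[3] is s, it[4] the control value
            it[4], results[i] = zip_select( it[1], it[2]( it[3], it[4] ) )
            -- stop as soon as the shortest sequence is exhausted
            if it[4] == nil then return nil end
          end
          return unpack( results, 1, iterators.n )
        end
      end
    end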
    
    -- example code (assumes that this file is called "zip.lua"):
    local t1 = { 2, 4, 6, 8, 10, 12, 14 }
    local t2 = { "a", "b", "c", "d", "e", "f" }
    for a, b, c in zip( {ipairs( t1 )}, {ipairs( t2 )}, {io.lines"zip.lua"} ) do
      print( a, b, c )
    end
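If you only need to zip plain array-like tables rather than arbitrary iterators, a simpler approach is to index every table in step. The following is a minimal sketch under that assumption: zip_tables is a hypothetical helper name (not part of Penlight or standard Lua), and it assumes each argument is a sequence without holes. Like Python's zip, it accepts any number of tables and stops at the end of the shortest one.

```lua
-- Hypothetical helper, not part of Penlight or the Lua standard library.
-- Assumes every argument is a sequence (array-like table without holes).
local unpack = table.unpack or unpack  -- Lua 5.1 vs. 5.2+

local function zip_tables(...)
  local tables = { n = select("#", ...), ... }

  -- The shortest table decides how many rows are produced (0 if no tables).
  local len = tables.n > 0 and math.huge or 0
  for k = 1, tables.n do
    len = math.min(len, #tables[k])
  end

  local i = 0
  return function()
    i = i + 1
    if i > len then return nil end  -- a nil first value ends the for loop
    local row = {}
    for k = 1, tables.n do
      row[k] = tables[k][i]
    end
    return unpack(row, 1, tables.n)
  end
end

-- Usage: t2 is shorter, so this loop body runs six times.
local t1 = { 2, 4, 6, 8, 10, 12, 14 }
local t2 = { "a", "b", "c", "d", "e", "f" }
for a, b in zip_tables(t1, t2) do
  print(a, b)
end
```

Because it indexes the tables directly instead of driving iterators, this sketch cannot zip sources such as io.lines; the iterator-based zip above handles that case.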