Tags: multithreading, lua, torch, luajit

Torch out of memory in thread when using torch.serialize twice


I'm trying to add a parallel dataloader to the torch-dataframe in order to add torchnet compatibility. I've used the tnt.ParallelDatasetIterator and changed it so that:

  1. A basic batch is loaded outside the threads
  2. The batch is serialized and sent to the thread
  3. In the thread, the batch is deserialized and the batch data is converted to tensors
  4. The tensors are returned in a table that has the input and target keys in order to match the tnt.Engine setup.

The problem occurs the second time enqueue is called, with the error: .../torch_distro/install/bin/luajit: not enough memory. I'm currently only working with MNIST using an adapted mnist-example. The enqueue loop now looks like this (with debugging memory output):

-- `samplePlaceholder` stands in for samples which have been
-- filtered out by the `filter` function
local samplePlaceholder = {}

-- The enqueue function runs the main loop
local idx = 1
local function enqueue()
  while idx <= size and threads:acceptsjob() do
    local batch, reset = self.dataset:get_batch(batch_size)

    if (reset) then
      idx = size + 1
    else
      idx = idx + 1
    end

    if (batch) then
      local serialized_batch = torch.serialize(batch)

      -- In the parallel section only the to_tensor is run in parallel
      --  though this should be the computationally expensive operation
      threads:addjob(
        function(argList)
          io.stderr:write("\n Start");
          io.stderr:write("\n 1: " ..tostring(collectgarbage("count")))
          local origIdx, serialized_batch, samplePlaceholder = unpack(argList)

          io.stderr:write("\n 2: " ..tostring(collectgarbage("count")))
          local batch = torch.deserialize(serialized_batch)
          serialized_batch = nil

          collectgarbage()
          collectgarbage()

          io.stderr:write("\n 3: " .. tostring(collectgarbage("count")))
          batch = transform(batch)

          io.stderr:write("\n 4: " .. tostring(collectgarbage("count")))
          local sample = samplePlaceholder
          if (filter(batch)) then
            sample = {}
            sample.input, sample.target = batch:to_tensor()
          end
          io.stderr:write("\n 5: " ..tostring(collectgarbage("count")))

          collectgarbage()
          collectgarbage()
          io.stderr:write("\n 6: " ..tostring(collectgarbage("count")))

          io.stderr:write("\n End \n");
          return {
            sample,
            origIdx
          }
        end,
        function(argList)
          sample, sampleOrigIdx = unpack(argList)
        end,
        {idx, serialized_batch, samplePlaceholder}
      )
    end
  end
end

I've sprinkled collectgarbage calls throughout and also tried to remove any objects that aren't needed. The memory output (collectgarbage("count") reports kilobytes) is rather straightforward:

 Start
 1: 374840.87695312
 2: 374840.94433594
 3: 372023.79101562
 4: 372023.85839844
 5: 372075.41308594
 6: 372023.73632812
 End 

The function that loops over the enqueue is the non-ordered iterator function, which is trivial (the memory error is thrown at the second enqueue):

iterFunction = function()
  while threads:hasjob() do
    enqueue()
    threads:dojob()
    if threads:haserror() then
      threads:synchronize()
    end
    enqueue()

    if table.exact_length(sample) > 0 then
      return sample
    end
  end
end

Solution

  • So the problem was torch.serialize, where the function defined in the set-up coupled the entire dataset to the function. When adding:

    serialized_batch = nil
    collectgarbage()
    collectgarbage()
    

    The problem was resolved. I further wanted to know what was taking up so much space, and the culprit turned out to be that I had defined the function in an environment containing a large dataset that got intertwined with the function, massively increasing its size. Here is the original definition of the data:

    local mnist = require 'mnist'
    local dataset = mnist[mode .. 'dataset']()
    
    -- PROBLEMATIC LINE BELOW --
    local ext_resource = dataset.data:reshape(dataset.data:size(1),
      dataset.data:size(2) * dataset.data:size(3)):double()
    
    -- Create a Dataframe with the label. The actual images will be loaded
    --  as an external resource
    local df = Dataframe(
      Df_Dict{
        label = dataset.label:totable(),
        row_id = torch.range(1, dataset.data:size(1)):totable()
      })
    
    -- Since the mnist package already has taken care of the data
    --  splitting we create a single subsetter
    df:create_subsets{
      subsets = Df_Dict{core = 1},
      class_args = Df_Tbl({
        batch_args = Df_Tbl({
          label = Df_Array("label"),
          data = function(row)
            return ext_resource[row.row_id]
          end
        })
      })
    }
    

    It turns out that removing the line that I highlighted reduces the memory usage from 358 MB down to 0.0008 MB! The code that I used for measuring this was:

    local mem = {}
    table.insert(mem, collectgarbage("count"))
    
    local ser_data = torch.serialize(batch.dataset)
    table.insert(mem, collectgarbage("count"))
    
    local ser_retriever = torch.serialize(batch.batchframe_defaults.data)
    table.insert(mem, collectgarbage("count"))
    
    local ser_raw_retriever = torch.serialize(function(row)
      return ext_resource[row.row_id]
    end)
    table.insert(mem, collectgarbage("count"))
    
    local serialized_batch = torch.serialize(batch)
    table.insert(mem, collectgarbage("count"))
    
    for i=2,#mem do
      print(i-1, (mem[i] - mem[i-1])/1024)
    end
    

    This originally produced the output:

    1   0.0082607269287109  
    2   358.23344707489 
    3   0.0017471313476562  
    4   358.90182781219 
    

    and after the fix:

    1   0.0094480514526367  
    2   0.00080204010009766 
    3   0.00090408325195312 
    4   0.010146141052246
    

    I tried using setfenv for the function, but it didn't resolve the issue. There is still a performance penalty for sending the serialized data to the thread, but the main problem is resolved, and without the expensive data retriever the function is considerably smaller.
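
    For context, below is a minimal, self-contained sketch (not the code from the question; the names heavy, pool and worker_resource are made up for illustration) of the upvalue problem and one common workaround with the threads package: build the heavy resource once per worker thread in the pool's init function, so each job only needs to serialize lightweight arguments.

    -- Minimal sketch with illustrative names: a closure that captures a large
    -- tensor as an upvalue becomes huge once serialized, whereas a resource
    -- created per worker thread never crosses the serialization boundary.
    local torch = require 'torch'
    local threads = require 'threads'

    -- Closure with a big upvalue: torch.serialize drags the whole tensor along.
    local heavy = torch.DoubleTensor(2000, 784):zero()  -- roughly 12 MB
    local capturing = function(row_id) return heavy[row_id] end
    print(#torch.serialize(capturing) / (1024 * 1024), "MB for the closure")

    -- Workaround: create the resource inside each worker instead of capturing it.
    local pool = threads.Threads(
      2,
      function()
        -- runs once per worker; worker_resource is a thread-local global
        torch = require 'torch'
        worker_resource = torch.DoubleTensor(2000, 784):zero()
      end
    )

    pool:addjob(
      function(row_id)
        -- only row_id had to be serialized for this job
        return worker_resource[row_id]:sum()
      end,
      function(s) print("row sum:", s) end,
      1
    )
    pool:synchronize()

    With this pattern only the small per-job arguments go through torch.serialize, which is essentially what dropping the captured ext_resource achieves for the data retriever above.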