Tags: for-loop, matrix, lua, deep-learning, torch

Fast way to initialize a tensor in torch7


I need to initialize a 3D tensor with an index-dependent function in torch7, i.e.

func = function(i,j,k)  -- i, j, k are the indices of an element in the tensor
    return i*j*k        -- do operations within func that depend on i, j, k
end

then I initialize a 3D tensor A like this:

for i=1,A:size(1) do
    for j=1,A:size(2) do
        for k=1,A:size(3) do
            A[{i,j,k}] = func(i,j,k)
        end
    end
end

But this code runs very slowly, and I found it takes up 92% of the total running time. Is there a more efficient way to initialize a 3D tensor in torch7?


Solution

  • See the documentation for Tensor:apply:

    These functions apply a function to each element of the tensor on which the method is called (self). These methods are much faster than using a for loop in Lua.

    The example in the docs initializes a 2D array based on its linear index i (its position in memory). Below is an extended example for 3 dimensions, followed by one for N-D tensors. Using the apply method is much, much faster on my machine:

    require 'torch'
    
    A = torch.Tensor(100, 100, 1000)
    B = torch.Tensor(100, 100, 1000)
    
    function func(i,j,k) 
        return i*j*k    
    end
    
    t = os.clock()
    for i=1,A:size(1) do
        for j=1,A:size(2) do
            for k=1,A:size(3) do
                A[{i, j, k}] = i * j * k
            end
        end
    end
    print("Original time:", os.clock() - t)
    
    t = os.clock()
    function forindices(A, func)
      local i = 1
      local j = 1
      local k = 0
      local d3 = A:size(3)
      local d2 = A:size(2) 
      return function()
        k = k + 1
        if k > d3 then
          k = 1
          j = j + 1
          if j > d2 then
            j = 1
            i = i + 1
          end
        end
        return func(i, j, k)
      end
    end
    
    B:apply(forindices(A, func))
    print("Apply method:", os.clock() - t)
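
    For the particular function i*j*k, which is separable, the Lua-level loop can be avoided entirely with tensor operations. Below is a hedged sketch using standard torch7 calls (`torch.range`, `torch.ger`, `view`, `expand`, `torch.cmul`); the sizes are smaller than in the benchmark above, purely for illustration:

    ```lua
    require 'torch'

    -- Sketch: build f(i,j,k) = i*j*k from an outer product plus broadcasting,
    -- with no Lua-level loop over elements.
    local ri = torch.range(1, 20)    -- [1, 2, ..., 20]
    local rj = torch.range(1, 20)
    local rk = torch.range(1, 30)

    local ij = torch.ger(ri, rj)     -- 20x20 outer product: ij[i][j] = i*j
    -- Replicate the i*j plane along the third dimension (expand adds no
    -- memory; it uses a zero stride) and multiply elementwise by k.
    local C = torch.cmul(ij:view(20, 20, 1):expand(20, 20, 30),
                         rk:view(1, 1, 30):expand(20, 20, 30))
    -- C[{i, j, k}] now equals i*j*k.
    ```

    This only works because i*j*k factorizes per dimension; for a general index-dependent function, the apply approach above is the way to go.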
    

    EDIT

    This will work for any Tensor object:

    function tabulate(A, f)
      local idx = {}
      local ndims = A:dim()
      local dim = A:size()
      idx[ndims] = 0
      for i=1, (ndims - 1) do
        idx[i] = 1
      end
      return A:apply(function()
        for i=ndims, 1, -1 do
          idx[i] = idx[i] + 1
          if idx[i] <= dim[i] then
            break
          end
          idx[i] = 1
        end
        return f(unpack(idx))
      end)
    end
    
    -- usage for 3D case.
    tabulate(A, function(i, j, k) return i * j * k end)