Tags: python, arrays, numpy, lua

Analogue of numpy indices for lua multidimensional array


In numpy, any element of an ndarray can be accessed by specifying a tuple of integers. So if we build an array with arbitrary dimensions and populate it using a formula based on each element's indices, it is straightforward to access an element by supplying a tuple with that element's indices:

import numpy as np

a = np.zeros((5,4,3,2))
def ivalue(index):
    return (index[0]-index[3])/(index[1]+index[2]+1)

for index, _ in np.ndenumerate(a):
    a[index] = ivalue(index)

index = (0,3,2,1)

print(a[index])  # -1/6

In Lua it is relatively easy to construct the same array, with the formula shifted to Lua's 1-based indices. For example:

a = {}
for i=1, 5 do
    a[i] = {}
    for j=1, 4 do
        a[i][j] = {}
        for k=1, 3 do
            a[i][j][k] = {}
            for l=1, 2 do
                a[i][j][k][l] = (i-l)/(j+k-1)         
            end
        end
    end
end

io.write(a[1][4][3][2]) -- -1/6

But now suppose we have to write a program that returns a value given its indices. In numpy this is straightforward: as long as index matches the array structure, we can write a[index] and we are done. In Lua, however, I cannot think of a way to retrieve the stored value given index={1,4,3,2}. Native Python lists have the same issue, so I was wondering whether there is an extension of Lua that deals with multidimensional array indices.


Solution

  • I don't know Lua, but I can give you a function that works with the equivalent nested Python list.

    Make the array with broadcasted arrays:

    In [303]: res = (np.arange(5)[:,None,None,None] -
        np.arange(2))/(np.arange(4)[:,None,None] +
        np.arange(3)[:,None]+1)
    In [304]: res.shape
    Out[304]: (5, 4, 3, 2)
    

    Make a nested list from that:

    In [305]: lres = res.tolist()    
    In [306]: len(lres), len(lres[0][0])
    Out[306]: (5, 3)
    

    An iterative function to apply successive indices:

    def foo(a, idx):
        # Apply the indices one at a time, descending one nesting level per step.
        # An empty idx returns a unchanged.
        res = a
        for i in idx:
            res = res[i]
        return res
    

    testing with index:

    In [311]: index = (4,3,2,1)
    

    The array version:

    In [312]: res[index]
    Out[312]: 0.5
    

    trying the tuple on the nested list:

    In [313]: lres[index]
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    Cell In[313], line 1
    ----> 1 lres[index]
    
    TypeError: list indices must be integers or slices, not tuple
    

    But the function:

    In [314]: foo(lres, index)
    Out[314]: 0.5
    

    It also works on the array (though not quite as fast as indexing with the tuple directly):

    In [315]: foo(res, index)
    Out[315]: 0.5
    

    A recursive equivalent:

    def foor(a, idx):
        if len(idx)==0:
            return a
        else:
            return foor(a[idx[0]], idx[1:])
    
    In [317]: foor(lres, index)
    Out[317]: 0.5  
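
    Both versions above fold one index at a time into the current sub-list. As a sketch of the same idea using only the standard library, the fold can be written with functools.reduce. The name get_nested and the list-comprehension construction of the nested list are illustrative choices, not part of the original answer:

```python
from functools import reduce
from operator import getitem

def get_nested(a, idx):
    """Index a nested sequence with a tuple, one level per index."""
    # reduce folds getitem over idx: getitem(getitem(a, idx[0]), idx[1]) ...
    return reduce(getitem, idx, a)

# Same 5x4x3x2 values as the array above, built as a plain nested list (0-based).
nested = [[[[(i - l) / (j + k + 1) for l in range(2)]
            for k in range(3)]
           for j in range(4)]
          for i in range(5)]

print(get_nested(nested, (4, 3, 2, 1)))  # 0.5
print(get_nested(nested, ()) is nested)  # True: an empty index returns the whole structure
```

    A loop of this shape should carry over to Lua tables fairly directly, with the indices shifted to start at 1.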
    

    From the little I've seen of Lua, you should be able to translate one or both without too much trouble. It looks like its table is similar to a Python list.

    Searching SO for [lua] array turns up about 300 questions/answers. So people have been working with some sort of arrays in Lua, though in many cases that may be just another name for a table.

    In Python, deeply nested lists are not common, though one level of nesting is common enough. There are even tricks for 'transposing' such a list:

    In [318]: alist = [[1,2,3],[4,5,6]]; alist
    Out[318]: [[1, 2, 3], [4, 5, 6]]
    
    In [319]: list(zip(*alist))
    Out[319]: [(1, 4), (2, 5), (3, 6)]
    

    Under the covers, numpy stores all the values in a flat (1d) array (in 'C' order by default), and uses shape and strides to treat it as a multidimensional array.
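
    To make the role of strides concrete, here is a minimal pure-Python sketch of how row-major (C-order) strides follow from a shape. The helper name c_strides is made up for illustration, and itemsize=8 assumes 8-byte float64 elements:

```python
def c_strides(shape, itemsize=1):
    """Row-major (C-order) strides for `shape`, scaled by itemsize."""
    strides = []
    step = itemsize
    # Walk the axes from last to first; the last axis is the contiguous one.
    for n in reversed(shape):
        strides.append(step)
        step *= n
    return tuple(reversed(strides))

print(c_strides((5, 4, 3, 2)))              # (24, 6, 2, 1), counted in elements
print(c_strides((5, 4, 3, 2), itemsize=8))  # (192, 48, 16, 8), counted in bytes
```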

    An alternative to creating a deeply nested list would be to make a flat one and use a bit of math to convert the index to its flat equivalent.
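
    A sketch of that flat-list approach without numpy, assuming row-major order and 0-based indices. The names flat_index and flat are illustrative only:

```python
def flat_index(index, shape):
    """Row-major flat position of a 0-based multi-index."""
    pos = 0
    for i, n in zip(index, shape):
        if not 0 <= i < n:
            raise IndexError(f"index {i} out of range for axis of size {n}")
        pos = pos * n + i  # Horner-style accumulation over the axes
    return pos

shape = (5, 4, 3, 2)
# Flat list, filled in the same order a nested loop would visit the elements.
flat = [(i - l) / (j + k + 1)
        for i in range(5) for j in range(4)
        for k in range(3) for l in range(2)]

pos = flat_index((4, 3, 2, 1), shape)
print(pos, flat[pos])  # 119 0.5
```

    The same arithmetic works in Lua with a single flat table, as long as the 1-based offset is handled consistently.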

    numpy has a function to calculate a flat index from a multidimensional one:

    In [320]: np.ravel_multi_index(index, res.shape)
    Out[320]: 119    
    In [321]: res.ravel()[119]
    Out[321]: 0.5
    

    In this direction the calculation is easy; it's the reverse that's a bit trickier. The same flat index follows from the strides, which are in bytes (each float64 element takes 8):

    In [324]: res.strides
    Out[324]: (192, 48, 16, 8)
    In [325]: np.multiply(index, res.strides).sum()/8
    Out[325]: 119.0
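
    For the reverse direction, numpy provides np.unravel_index. As a rough pure-Python sketch of the same calculation (the helper name unravel is hypothetical, and row-major order is assumed), the index can be peeled off one axis at a time with divmod:

```python
def unravel(pos, shape):
    """Inverse of row-major flattening: flat position -> multi-index."""
    index = []
    # Peel off the last axis first; divmod returns (remaining position, index on this axis).
    for n in reversed(shape):
        pos, i = divmod(pos, n)
        index.append(i)
    return tuple(reversed(index))

print(unravel(119, (5, 4, 3, 2)))  # (4, 3, 2, 1)
```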