Tags: python, arrays, numpy, memoryview

NumPy: interweave oddly shaped arrays


Alright, here is the given data. There are three NumPy arrays with shapes (i, 4, 2), (i, 4, 3) and (i, 4, 2); i is shared among them but varies. The dtype is float32 for all of them. The goal is to interweave them in a particular order. Here is the data at index 0 of each array:

[[-208.  -16.]
 [-192.  -16.]
 [-192.    0.]
 [-208.    0.]]

[[ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]]

[[ 0.49609375  0.984375  ]
 [ 0.25390625  0.984375  ]
 [ 0.25390625  0.015625  ]
 [ 0.49609375  0.015625  ]]

In this case, the concatenated target array would look something like this:

[-208, -16, 1, 1, 1, 0.496, 0.984, -192, -16, 1, 1, 1, ...]

After the four rows of index 0, the same pattern continues with index 1.
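In other words, each row of the target holds 2 + 3 + 2 = 7 float32 values, so each index contributes 4 × 7 = 28 values. The arithmetic below is a small sketch of that layout; it assumes only the shapes and the float32 dtype stated above, and the byte offsets are simply where each array's values start inside one 7-value row.

```python
import numpy as np

# Trailing widths of the three arrays: (i, 4, 2), (i, 4, 3), (i, 4, 2)
widths = [2, 3, 2]
rows_per_index = 4
itemsize = np.dtype(np.float32).itemsize  # 4 bytes per float32 value

row_values = sum(widths)                     # values in one interleaved row
row_bytes = row_values * itemsize            # bytes in one interleaved row
index_values = rows_per_index * row_values   # values contributed by one index i
# Byte offset at which each array's values start within a row
offsets = (np.cumsum([0] + widths[:-1]) * itemsize).tolist()

print(row_values, row_bytes, index_values, offsets)  # 7 28 28 [0, 8, 20]
```

If the GPU side reads the buffer row by row, `row_bytes` would be the distance between consecutive rows; that is an assumption about the shader setup, which is not described here.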

I don't know how to achieve this: the concatenate function keeps telling me that the shapes don't match. The shape of the target array does not matter much; what matters is that its memoryview holds the values in the order above, because it is uploaded to a GPU shader.
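A likely cause of the shape error, assuming the arrays were passed to `np.concatenate` without an `axis` argument: `np.concatenate` joins along axis 0 by default, and every other axis must then match exactly. Here the last axis has sizes 2, 3 and 2, so the call fails. The sketch uses zero-filled placeholder arrays with an arbitrary i = 5.

```python
import numpy as np

i = 5  # arbitrary placeholder; in the question i varies
a = np.zeros((i, 4, 2), dtype=np.float32)
b = np.zeros((i, 4, 3), dtype=np.float32)
c = np.zeros((i, 4, 2), dtype=np.float32)

try:
    np.concatenate((a, b, c))  # axis defaults to 0
    failed = False
except ValueError as err:
    # Every axis except the joined one must match; axis 2 is 2 vs 3
    failed = True
    print("ValueError:", err)

# Only the last axis differs, so it is the one axis they can be joined along
print([arr.shape for arr in (a, b, c)])
```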

Edit: I could achieve this with a few Python for loops, but their performance cost would be a problem in this program.
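For reference, a loop-based version might look like the sketch below; `interleave_loops` is a hypothetical name, not something from the question. It is slow for large i because every row is visited in Python, but it is useful as a baseline to check a vectorized answer against. The input values are arbitrary placeholders chosen so the order is easy to read.

```python
import numpy as np

def interleave_loops(a, b, c):
    """Interleave rows of (i, 4, 2), (i, 4, 3), (i, 4, 2) arrays with Python loops."""
    out = []
    for n in range(a.shape[0]):        # each index i
        for r in range(a.shape[1]):    # each of the 4 rows
            out.extend(a[n, r])        # 2 values from a
            out.extend(b[n, r])        # 3 values from b
            out.extend(c[n, r])        # 2 values from c
    return np.array(out, dtype=np.float32)

i = 3  # arbitrary placeholder
a = np.arange(i * 4 * 2, dtype=np.float32).reshape(i, 4, 2)
b = np.ones((i, 4, 3), dtype=np.float32)
c = 100 + np.arange(i * 4 * 2, dtype=np.float32).reshape(i, 4, 2)

result = interleave_loops(a, b, c)
# First two rows of index 0:
print(result[:14].tolist())
# [0.0, 1.0, 1.0, 1.0, 1.0, 100.0, 101.0, 2.0, 3.0, 1.0, 1.0, 1.0, 102.0, 103.0]
```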


Solution

  • Use np.dstack to join the arrays along their last axis, then flatten the result with .ravel() -

    np.dstack((a,b,c)).ravel()
    

    np.dstack stacks along the third axis (axis=2), which is the only axis where the input shapes differ. So, alternatively, we can use np.concatenate along that axis, like so -

    np.concatenate((a,b,c),axis=2).ravel()
    

    Sample run -

    1) Set up input arrays:

    In [613]: np.random.seed(1234)
         ...: n = 3
         ...: m = 2
         ...: a = np.random.randint(0,9,(n,m,2))
         ...: b = np.random.randint(11,99,(n,m,2))
         ...: c = np.random.randint(101,999,(n,m,2))
         ...: 
    

    2) Check input values:

    In [614]: a
    Out[614]: 
    array([[[3, 6],
            [5, 4]],
    
           [[8, 1],
            [7, 6]],
    
           [[8, 0],
            [5, 0]]])
    
    In [615]: b
    Out[615]: 
    array([[[84, 58],
            [61, 87]],
    
           [[48, 45],
            [49, 78]],
    
           [[22, 11],
            [86, 91]]])
    
    In [616]: c
    Out[616]: 
    array([[[104, 359],
            [376, 560]],
    
           [[472, 720],
            [566, 115]],
    
           [[344, 556],
            [929, 591]]])
    

    3) Output:

    In [617]: np.dstack((a,b,c)).ravel()
    Out[617]: 
    array([  3,   6,  84,  58, 104, 359,   5,   4,  61,  87, 376, 560,   8,
             1,  48,  45, 472, 720,   7,   6,  49,  78, 566, 115,   8,   0,
            22,  11, 344, 556,   5,   0,  86,  91, 929, 591])
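The sample run above uses the same trailing size (2) for all three inputs. The sketch below checks the same approach on shapes matching the question, (i, 4, 2), (i, 4, 3) and (i, 4, 2) with float32 data, and exposes the result through `memoryview`. The random data and i = 3 are arbitrary placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
i = 3  # arbitrary placeholder; i varies in the question
a = rng.random((i, 4, 2), dtype=np.float32)
b = rng.random((i, 4, 3), dtype=np.float32)
c = rng.random((i, 4, 2), dtype=np.float32)

# Join along the last axis: each row becomes a[n, r], b[n, r], c[n, r] (7 values)
flat = np.concatenate((a, b, c), axis=2).ravel()

# np.dstack gives the same result for these 3-D inputs
assert np.array_equal(flat, np.dstack((a, b, c)).ravel())

# The first 7 values are row 0 of index 0 from a, b and c in turn
assert np.array_equal(flat[:7], np.concatenate((a[0, 0], b[0, 0], c[0, 0])))

view = memoryview(flat)  # buffer over the contiguous float32 data
print(flat.shape, flat.dtype, view.itemsize, view.nbytes)  # (84,) float32 4 336
```

Because `np.concatenate` already returns a new C-contiguous array, `.ravel()` returns a view of it rather than a second copy. If the upload call wants raw bytes instead of a buffer object, `flat.tobytes()` returns a copy in the same order.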