Understanding the syntax of numpy.r_[] concatenation


I read the following in the numpy documentation for the function r_:

A string integer specifies which axis to stack multiple comma separated arrays along. A string of two comma-separated integers allows indication of the minimum number of dimensions to force each entry into as the second integer (the axis to concatenate along is still the first integer).

and they give this example:

>>> np.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, dim>=2
array([[1, 2, 3],
       [4, 5, 6]])

I don't follow. What exactly does the string '0,2' instruct numpy to do?

Other than the numpy documentation quoted above, is there anywhere else with more documentation about this function?


Solution

  • 'n,m' tells r_ to concatenate along axis=n after promoting each input to at least m dimensions, so the result has at least m dimensions:

    In [28]: np.r_['0,2', [1,2,3], [4,5,6]]
    Out[28]: 
    array([[1, 2, 3],
           [4, 5, 6]])
    

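    So we are concatenating along axis=0, which for two 1-D inputs would normally give shape (6,). But m=2 tells r_ to promote each input to at least 2 dimensions first. By default the missing axes of length 1 are added at the front, so [1,2,3] becomes shape (1,3), and concatenating two (1,3) arrays along axis 0 gives shape (2,3):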

    In [32]: np.r_['0,2', [1,2,3,], [4,5,6]].shape
    Out[32]: (2, 3)
    

    Look at what happens when we increase m:

    In [36]: np.r_['0,3', [1,2,3,], [4,5,6]].shape
    Out[36]: (2, 1, 3)    # <- 3 dimensions
    
    In [37]: np.r_['0,4', [1,2,3,], [4,5,6]].shape
    Out[37]: (2, 1, 1, 3) # <- 4 dimensions
    

    Anything you can do with r_ can also be done with one of the more readable array-building functions such as np.concatenate, np.vstack (np.row_stack is an alias of it, deprecated in NumPy 2.0), np.column_stack, np.hstack or np.dstack, though it may also require a call to reshape.

    Even with the extra call to reshape, those functions can be faster (exact timings depend on the machine and NumPy version):

    In [38]: %timeit np.r_['0,4', [1,2,3,], [4,5,6]]
    10000 loops, best of 3: 38 us per loop
    In [43]: %timeit np.concatenate(([1,2,3,], [4,5,6])).reshape(2,1,1,3)
    100000 loops, best of 3: 10.2 us per loop
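As a rough sketch of the 'n,m' rule described above, the following reproduces np.r_['0,m', ...] with plain NumPy calls. It assumes list (non-scalar) inputs and the default placement of the new length-1 axes at the front; r_like is a hypothetical helper name, not part of NumPy. np.array(..., ndmin=m) pads each shape with leading 1s, and np.concatenate then joins the padded arrays along the requested axis.

```python
import numpy as np


def r_like(axis, ndmin, *items):
    """Hypothetical helper: mimic np.r_['axis,ndmin', *items] for list inputs.

    Assumes the default r_ behaviour of adding the missing length-1 axes
    at the front of each input's shape.
    """
    # np.array(..., ndmin=ndmin) prepends 1s until the array has ndmin dims.
    upgraded = [np.array(item, ndmin=ndmin) for item in items]
    return np.concatenate(upgraded, axis=axis)


a, b = [1, 2, 3], [4, 5, 6]
for m in (2, 3, 4):
    manual = r_like(0, m, a, b)
    # Print the padded shape of one input, then the concatenated result's shape.
    print(m, np.array(a, ndmin=m).shape, manual.shape)
    assert np.array_equal(manual, np.r_[f"0,{m}", a, b])
```

For m = 2, 3, 4 the padded inputs have shapes (1, 3), (1, 1, 3) and (1, 1, 1, 3), and the results have shapes (2, 3), (2, 1, 3) and (2, 1, 1, 3), matching the transcripts in the answer.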
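In 'n,m', the first integer only chooses the concatenation axis, while the second still controls how each input is padded. A small comparison using the same two lists as the answer; the '1,2' case is an extra illustration that does not appear in the original transcripts:

```python
import numpy as np

a, b = [1, 2, 3], [4, 5, 6]

# No directive string: the 1-D inputs are simply joined end to end.
plain = np.r_[a, b]
# '0,2': each input is padded to shape (1, 3), then stacked along axis 0.
rows = np.r_['0,2', a, b]
# '1,2': same (1, 3) inputs, but joined along axis 1 instead.
wide = np.r_['1,2', a, b]

print(plain.shape)  # (6,)
print(rows.shape)   # (2, 3)
print(wide.shape)   # (1, 6)
print(wide)         # [[1 2 3 4 5 6]]
```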
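To illustrate the point that r_ can be replaced by the more explicit array-building functions, here is a sketch that checks a few equivalences for the lists used in the answer. The pairings chosen here are just examples; other combinations of these functions work as well.

```python
import numpy as np

a, b = [1, 2, 3], [4, 5, 6]

# Plain np.r_[a, b] joins 1-D inputs end to end, like np.hstack.
assert np.array_equal(np.r_[a, b], np.hstack((a, b)))

# '0,2' stacks each 1-D input as a row, like np.vstack.
assert np.array_equal(np.r_['0,2', a, b], np.vstack((a, b)))

# '0,4' matches the concatenate-then-reshape version timed in the answer.
assert np.array_equal(
    np.r_['0,4', a, b],
    np.concatenate((a, b)).reshape(2, 1, 1, 3),
)
print("all equivalences hold")
```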