
np.concatenate: What does it do with a single ndarray as input?


I am trying to understand what exactly happens in the following snippet:

>>> a = np.random.randn(3,3,3)
>>> a
array([[[ 1.17565688,  0.58223235, -0.41813242],
        [-0.16933573,  0.24205104,  1.37286476],
        [-0.58120365,  0.32970027, -0.1521039 ]],

       [[ 1.18393152,  0.00254526,  1.67234901],
        [ 0.72494527,  0.72414755,  0.32974478],
        [-1.2290148 ,  1.18013258, -0.61498214]],

       [[-0.38574517, -0.46385622,  0.06616913],
        [ 0.26560153,  0.61720524,  0.03528806],
        [ 0.66292143, -0.57724826, -0.33810831]]])
>>> np.concatenate(a, 1)
array([[ 1.17565688,  0.58223235, -0.41813242,  1.18393152,  0.00254526,
         1.67234901, -0.38574517, -0.46385622,  0.06616913],
       [-0.16933573,  0.24205104,  1.37286476,  0.72494527,  0.72414755,
         0.32974478,  0.26560153,  0.61720524,  0.03528806],
       [-0.58120365,  0.32970027, -0.1521039 , -1.2290148 ,  1.18013258,
        -0.61498214,  0.66292143, -0.57724826, -0.33810831]])

What operations exactly are happening? Is it splitting the array along the given dimension and then concatenating?

Thank you!


Solution

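  • The signature

    numpy.concatenate((a1, a2, ...), axis=0, out=None)

    shows that the first argument is expected to be a sequence of arrays. When you pass a single ndarray instead, NumPy iterates over it along its first axis, so the array is effectively 'unpacked' on axis 0 into its sub-arrays, and those sub-arrays are then joined along the axis you give. In other words, it does not split on the given dimension: it always splits on axis 0 and concatenates along the requested axis. For the (3, 3, 3) array in the question, np.concatenate(a, 1) is the same as np.concatenate((a[0], a[1], a[2]), 1), which produces a (3, 9) array.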

    To illustrate:

    In [40]: arr = np.arange(24).reshape(2,3,4)
    In [41]: a1,a2 = arr
    In [42]: a1
    Out[42]:
    array([[ 0,  1,  2,  3],
           [ 4,  5,  6,  7],
           [ 8,  9, 10, 11]])
    In [43]: a2
    Out[43]:
    array([[12, 13, 14, 15],
           [16, 17, 18, 19],
           [20, 21, 22, 23]])
    In [44]: np.concatenate(arr, axis=1)
    Out[44]:
    array([[ 0,  1,  2,  3, 12, 13, 14, 15],
           [ 4,  5,  6,  7, 16, 17, 18, 19],
           [ 8,  9, 10, 11, 20, 21, 22, 23]])
    In [45]: np.concatenate((a1,a2), axis=1)
    Out[45]:
    array([[ 0,  1,  2,  3, 12, 13, 14, 15],
           [ 4,  5,  6,  7, 16, 17, 18, 19],
           [ 8,  9, 10, 11, 20, 21, 22, 23]])
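A minimal script to check the same behavior on an array shaped like the one in the question. It uses np.arange instead of random values purely so the resulting layout is easy to read; that choice is only for illustration. It confirms that passing a single array gives the same result as passing its axis-0 slices explicitly, and shows an equivalent transpose/reshape for comparison.

```python
import numpy as np

# Same shape as the question's array; integers make the layout easy to follow.
a = np.arange(27).reshape(3, 3, 3)

# Passing a single ndarray: NumPy iterates over it along axis 0,
# so this behaves like passing the sequence (a[0], a[1], a[2]).
result = np.concatenate(a, axis=1)
explicit = np.concatenate((a[0], a[1], a[2]), axis=1)
assert np.array_equal(result, explicit)

# Three (3, 3) blocks placed side by side give a (3, 9) array.
print(result.shape)  # (3, 9)
print(result)
# [[ 0  1  2  9 10 11 18 19 20]
#  [ 3  4  5 12 13 14 21 22 23]
#  [ 6  7  8 15 16 17 24 25 26]]

# Another view of the same layout: swap the unpacked axis (0) with the
# joining axis (1), then merge the last two axes.
assert np.array_equal(result, a.transpose(1, 0, 2).reshape(3, 9))

# With axis=0 the blocks are stacked vertically instead,
# which gives the same values as a plain reshape.
assert np.array_equal(np.concatenate(a, axis=0), a.reshape(9, 3))
```

Put differently, element a[j, i, k] ends up at row i, column j*3 + k of np.concatenate(a, axis=1).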