Tags: python, arrays, numpy, slice

multiply numpy ndarray with 1d array along a given axis


It seems I am getting lost in something potentially silly. I have an n-dimensional numpy array, and I want to multiply it by a vector (1d array) along some dimension (which can change!). For example, to multiply a 2d array by a 1d array along axis 0 of the first array, I can do something like this:

import numpy as np

a = np.arange(20).reshape((5, 4))
b = np.ones(5)
c = a * b[:, np.newaxis]

Easy, but I would like to extend this idea to n dimensions (for a, while b is always 1d) and to any axis. In other words, I would like to know how to generate a slice with np.newaxis in the right places. Say that a is 3d and I want to multiply along axis=1. Then I would like to generate the slice that correctly gives:

c=a*b[np.newaxis,:,np.newaxis]

I.e. given the number of dimensions of a (say 3), and the axis along which I want to multiply (say axis=1), how do I generate and pass the slice:

np.newaxis,:,np.newaxis

Thanks.


Solution

  • Solution Code -

    import numpy as np
    
    # Given axis along which elementwise multiplication with broadcasting 
    # is to be performed
    given_axis = 1
    
    # Build the target shape for the 1D array b: 1 (a singleton dimension) on 
    # every axis except given_axis, which gets -1 so that reshape infers the 
    # full length of b along that axis
    dim_array = np.ones((1,a.ndim),int).ravel()
    dim_array[given_axis] = -1
    
    # Reshape b with dim_array and perform elementwise multiplication with 
    # broadcasting along the singleton dimensions for the final output
    b_reshaped = b.reshape(dim_array)
    mult_out = a*b_reshaped
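
The steps above can be wrapped in a small function so the axis becomes a parameter. This is only a sketch: the name multiply_along_axis and the length check are assumptions added for illustration, and the target shape is built as a plain Python list instead of dim_array.

```python
import numpy as np

def multiply_along_axis(a, b, axis):
    """Multiply n-d array a by 1-d array b along the given axis.

    Hypothetical helper, not part of NumPy.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Assumed sanity check: b must line up with the chosen axis of a
    if b.ndim != 1 or b.shape[0] != a.shape[axis]:
        raise ValueError("b must be 1-d with length a.shape[axis]")
    # All ones except -1 at `axis`, e.g. [1, -1, 1] for a 3-d a and axis=1
    shape = [1] * a.ndim
    shape[axis] = -1
    # Broadcasting stretches the singleton dimensions of the reshaped b
    return a * b.reshape(shape)

a = np.arange(24).reshape(4, 2, 3)
b = np.array([10, 100])
out = multiply_along_axis(a, b, axis=1)
```

Negative axes such as axis=-1 work the same way, because both the list assignment and a.shape[axis] accept them.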
    

    Sample run for a demo of the steps -

    In [149]: import numpy as np
    
    In [150]: a = np.random.randint(0,9,(4,2,3))
    
    In [151]: b = np.random.randint(0,9,(2,1)).ravel()
    
    In [152]: whos
    Variable   Type       Data/Info
    -------------------------------
    a          ndarray    4x2x3: 24 elems, type `int32`, 96 bytes
    b          ndarray    2: 2 elems, type `int32`, 8 bytes
    
    In [153]: given_axis = 1
    

    Now, we would like to perform elementwise multiplications along given axis = 1. Let's create dim_array:

    In [154]: dim_array = np.ones((1,a.ndim),int).ravel()
         ...: dim_array[given_axis] = -1
         ...: 
    
    In [155]: dim_array
    Out[155]: array([ 1, -1,  1])
    

    Finally, reshape b & perform the elementwise multiplication:

    In [156]: b_reshaped = b.reshape(dim_array)
         ...: mult_out = a*b_reshaped
         ...: 
    

    Check out the whos info again and pay special attention to b_reshaped & mult_out:

    In [157]: whos
    Variable     Type       Data/Info
    ---------------------------------
    a            ndarray    4x2x3: 24 elems, type `int32`, 96 bytes
    b            ndarray    2: 2 elems, type `int32`, 8 bytes
    b_reshaped   ndarray    1x2x1: 2 elems, type `int32`, 8 bytes
    dim_array    ndarray    3: 3 elems, type `int32`, 12 bytes
    given_axis   int        1
    mult_out     ndarray    4x2x3: 24 elems, type `int32`, 96 bytes
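
As an alternative to reshaping, the index the question literally asks for can be built as a tuple, since `:` inside square brackets is equivalent to slice(None). The following is a minimal sketch; the helper name newaxis_index is an assumption made for illustration.

```python
import numpy as np

def newaxis_index(ndim, axis):
    """Return an index tuple with np.newaxis everywhere except `axis`.

    Hypothetical helper, not part of NumPy.
    """
    index = [np.newaxis] * ndim
    index[axis] = slice(None)  # same as ':' in b[np.newaxis, :, np.newaxis]
    return tuple(index)

a = np.arange(24).reshape(4, 2, 3)
b = np.array([10, 100])

idx = newaxis_index(a.ndim, 1)  # (np.newaxis, slice(None), np.newaxis)
c = a * b[idx]                  # b[idx] has shape (1, 2, 1)
```

Both routes produce the same broadcastable view of b, so the choice between indexing and reshape is mostly a matter of taste.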