
NumPy broadcasting in the z direction


In a machine learning context, I need to do an element-wise multiplication. To do this efficiently, I need to broadcast the elements of a 3D tensor in a particular way: each 2x2 matrix should be repeated n times in a row along the first axis, as shown by the following example with n=2:

import numpy as np

a = np.linspace(1,12,12)
a = a.reshape(3,2,2)

# what to put here?
# <some statements>

print(a)

# result:
[[[  1.   2.]
  [  3.   4.]]

 [[  1.   2.]
  [  3.   4.]]

 [[  5.   6.]
  [  7.   8.]]

 [[  5.   6.]
  [  7.   8.]]

 [[  9.  10.]
  [ 11.  12.]]

 [[  9.  10.]
  [ 11.  12.]]]

What statement(s) would do the job?

Thanks!


Solution

  • Here's one way, using np.repeat to replicate along the first axis once a is a 3D array:

    N = 2 # replication number
    out = np.repeat(a,N,axis=0)
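As a self-contained check that reuses the question's input, the sketch below confirms that np.repeat produces the a0, a0, a1, a1, a2, a2 layout the question asks for. It also expresses that layout as plain index arithmetic (output block k is input block k // N) and contrasts it with np.tile, which repeats the whole stack instead. The index array idx is only illustrative and is not part of the answer above.

```python
import numpy as np

# Input from the question: three 2x2 matrices stacked along axis 0
a = np.linspace(1, 12, 12).reshape(3, 2, 2)
N = 2  # replication number

# np.repeat copies each 2x2 block N times in a row: a0, a0, a1, a1, a2, a2
out = np.repeat(a, N, axis=0)
print(out.shape)  # (6, 2, 2)

# Same layout via fancy indexing: output block k is input block k // N
idx = np.arange(a.shape[0] * N) // N  # [0, 0, 1, 1, 2, 2]
assert np.array_equal(out, a[idx])

# For contrast, np.tile repeats the whole stack: a0, a1, a2, a0, a1, a2
tiled = np.tile(a, (N, 1, 1))
print(np.array_equal(out, tiled))  # False
```

Both np.repeat and the fancy-indexing form allocate a new (m*N, 2, 2) array. That is fine for small inputs but motivates the view-based alternative below.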
    

    Alternatively, if a read-only 4D output is acceptable, we can create a view with np.broadcast_to. This is very efficient because the view does not allocate any extra memory for the repeated blocks, like so:

    m,n,r = a.shape
    out = np.broadcast_to(a[:,None],(m,N,n,r))
    
    # Confirm it's a view
    In [32]: np.shares_memory(a, out)
    Out[32]: True
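Here is a minimal sketch of what the broadcast view looks like, assuming the same a and N as in the question. It shows that the replicated axis has stride 0 and that the view is read-only. It also shows a caveat: collapsing the view back to the (m*N, n, r) shape from the question forces a copy. For the element-wise multiplication itself, broadcasting against a[:, None] may avoid replication altogether. The weight array w here is hypothetical and stands in for whatever tensor a is multiplied with.

```python
import numpy as np

a = np.linspace(1, 12, 12).reshape(3, 2, 2)
N = 2
m, n, r = a.shape

# a[:, None] has shape (m, 1, n, r); broadcast_to stretches the size-1 axis to N
out = np.broadcast_to(a[:, None], (m, N, n, r))

print(out.shape)            # (3, 2, 2, 2)
print(out.strides[1])       # 0: the replicated axis reuses the same memory
print(out.flags.writeable)  # False: assigning into the view raises ValueError

# Collapsing the first two axes gives the (m*N, n, r) layout from the question,
# but a zero-stride axis cannot be merged with its neighbour, so this copies
flat = out.reshape(m * N, n, r)
print(np.shares_memory(a, flat))  # False
assert np.array_equal(flat, np.repeat(a, N, axis=0))

# Hypothetical weights w of shape (m*N, n, r): view them as (m, N, n, r) and
# let broadcasting against a[:, None] do the repetition, then flatten back
w = np.arange(m * N * n * r, dtype=float).reshape(m * N, n, r)
prod = (w.reshape(m, N, n, r) * a[:, None]).reshape(m * N, n, r)
assert np.array_equal(prod, w * np.repeat(a, N, axis=0))
```

In this sketch, the multiplication never materializes the repeated copy of a. Only the product array is allocated, and the reshape of a contiguous w is itself a view.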