In a machine learning context, I need to do a per-element multiplication. To do this efficiently, I need to broadcast elements of a 3D tensor in a particular way such that each 2x2 matrix is repeated n times, as shown by the following example with n=2:
import numpy as np
a = np.linspace(1,12,12)
a = a.reshape(3,2,2)
# what to put here?
<some statements>
print(a)
# result:
[[[ 1.  2.]
  [ 3.  4.]]

 [[ 1.  2.]
  [ 3.  4.]]

 [[ 5.  6.]
  [ 7.  8.]]

 [[ 5.  6.]
  [ 7.  8.]]

 [[ 9. 10.]
  [11. 12.]]

 [[ 9. 10.]
  [11. 12.]]]
What statement(s) would do the job?
Thanks!
Here's one approach using np.repeat to replicate along the first axis, given a as a 3D array:
N = 2 # replication number
out = np.repeat(a,N,axis=0)
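A self-contained sketch that checks np.repeat with axis=0 produces the (6, 2, 2) layout from the question. For contrast it also shows np.tile, which repeats the whole stack rather than each 2x2 matrix, so it is not what the question asks for:

```python
import numpy as np

a = np.linspace(1, 12, 12).reshape(3, 2, 2)
N = 2  # replication number

# np.repeat copies each 2x2 matrix N times in a row: a[0], a[0], a[1], a[1], a[2], a[2]
out = np.repeat(a, N, axis=0)

# np.tile repeats the whole stack instead: a[0], a[1], a[2], a[0], a[1], a[2]
tiled = np.tile(a, (N, 1, 1))

print(out.shape)  # (6, 2, 2)
print(out[1])     # equal to a[0]
print(tiled[1])   # equal to a[1], not a[0]
```

Both calls allocate a new array of size N times the original, which is the cost the broadcast_to view avoids.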
Alternatively, if a read-only 4D output of shape (m, N, n, r) is acceptable, we can create a view with np.broadcast_to. This is very efficient because the replicated axis is expressed as a zero stride, so no extra memory is allocated for the copies:
m,n,r = a.shape
out = np.broadcast_to(a[:,None],(m,N,n,r))
# Confirm it's a view
In [32]: np.shares_memory(a, out)
Out[32]: True
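A minimal sketch of what the broadcast_to view looks like and how it can serve the per-element multiplication from the question. The weight tensor w is a hypothetical stand-in for whatever array a gets multiplied with. Two points are worth checking: reshaping the 4D view back to the question's 3D layout makes a copy, and for an element-wise product, plain broadcasting against a[:, None] already avoids materialising the repeated array.

```python
import numpy as np

a = np.linspace(1, 12, 12).reshape(3, 2, 2)
N = 2  # replication number
m, n, r = a.shape

# 4D read-only view: axis 1 has stride 0, so no data is duplicated
view = np.broadcast_to(a[:, None], (m, N, n, r))
print(view.strides[1])       # 0
print(view.flags.writeable)  # False

# Merging axes 0 and 1 of the view cannot be done with strides alone
# (one of them is 0), so reshape falls back to allocating a copy
flat = view.reshape(m * N, n, r)
print(np.shares_memory(a, flat))  # False

# Hypothetical weight tensor laid out like the repeated (m*N, n, r) array
w = np.arange(m * N * n * r, dtype=float).reshape(m * N, n, r)

# Reshape w to 4D (a view, since w is contiguous) and let broadcasting
# expand a along axis 1 on the fly during the multiplication
prod = (w.reshape(m, N, n, r) * a[:, None]).reshape(m * N, n, r)
```

In this sketch, prod equals w multiplied by the np.repeat result, but the only new allocation is the product itself.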