import numpy as np
x1 = np.arange(9.0).reshape((3, 3))
print("x1\n",x1,"\n")
x2 = np.arange(3.0)
print("x2\n",x2)
print(x2.shape,"\n")
print("Here, the shape of x2 is 3 rows by 1 column ")
print("x1@x2\n",x1@x2)
print("")
print("x2@x1 should not be possible\n",x2@x1,"\n"*3)
This gives:
x1
[[0. 1. 2.]
 [3. 4. 5.]
 [6. 7. 8.]]
x2
[0. 1. 2.]
(3,)
Here, the shape of x2 is 3 rows by 1 column
x1@x2 =
[ 5. 14. 23.]
x2@x1 should not be possible, BUT
[15. 18. 21.]
Python 3 seems to silently convert x2 into a (1, 3) array so that it can be multiplied by x1. Or am I missing some syntax?
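NumPy is doing this deliberately. A 1-D array such as x2 has shape (3,), so it is neither a row nor a column. When the first operand of @ (np.matmul) is 1-D, NumPy promotes it to a (1, 3) row by prepending a 1 to its shape. When the second operand is 1-D, it is promoted to a (3, 1) column by appending a 1. After the multiplication the added axis is removed again, which is why both x1@x2 and x2@x1 return shape (3,). This promotion works alongside NumPy's broadcasting rules, which matmul applies to any leading (stacked) dimensions.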
To quote the broadcasting documentation:
The term broadcasting describes how numpy treats arrays with different shapes during arithmetic operations. Subject to certain constraints, the smaller array is “broadcast” across the larger array so that they have compatible shapes. Broadcasting provides a means of vectorizing array operations so that looping occurs in C instead of Python. It does this without making needless copies of data and usually leads to efficient algorithm implementations. There are, however, cases where broadcasting is a bad idea because it leads to inefficient use of memory that slows computation.
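To see what @ does with the 1-D x2, you can compare each product with the same product written using explicit 2-D shapes: x2 as a (3, 1) column on the right and as a (1, 3) row on the left, with the extra axis dropped from the result. This is a small sketch reusing the question's x1 and x2; the names right, left and the *_explicit variables are only illustrative. The last part contrasts this with broadcasting in the sense of the quote above, an elementwise x1 + x2.

```python
import numpy as np

x1 = np.arange(9.0).reshape((3, 3))
x2 = np.arange(3.0)  # shape (3,): neither a row nor a column

# x1 @ x2: the 1-D second operand acts like a (3, 1) column,
# and the appended axis is dropped from the (3, 1) result.
right = x1 @ x2
right_explicit = (x1 @ x2.reshape(3, 1)).reshape(3)

# x2 @ x1: the 1-D first operand acts like a (1, 3) row,
# and the prepended axis is dropped from the (1, 3) result.
left = x2 @ x1
left_explicit = (x2.reshape(1, 3) @ x1).reshape(3)

print(right, right.shape)  # [ 5. 14. 23.] (3,)
print(left, left.shape)    # [15. 18. 21.] (3,)
print(np.array_equal(right, right_explicit),
      np.array_equal(left, left_explicit))  # True True

# Elementwise broadcasting: x2 (shape (3,)) is stretched across
# each row of x1 (shape (3, 3)), giving a (3, 3) result.
elementwise = x1 + x2
print(elementwise.shape)   # (3, 3)
```

So the 1-D operand is not silently reshaped in place: x2 keeps shape (3,) throughout, and only the computation treats it as a row or a column depending on which side of @ it is on.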
Add the following lines to your code, which explicitly set the shape of x2 to (3, 1), and you will get an error:
import numpy as np
x1 = np.arange(9.0).reshape((3, 3))
print(x1.shape) # new line added
print("x1\n",x1,"\n")
x2 = np.arange(3.0)
x2 = x2.reshape(3, 1) # new line added
print("x2\n",x2)
print(x2.shape,"\n")
print("Here, the shape of x2 is 3 rows by 1 column ")
print("x1@x2\n",x1@x2)
print("")
print("x2@x1 should not be possible\n",x2@x1,"\n"*3)
Output
(3, 3)
x1
[[0. 1. 2.]
 [3. 4. 5.]
 [6. 7. 8.]]
x2
[[0.]
 [1.]
 [2.]]
(3, 1)
Here, the shape of x2 is 3 rows by 1 column
x1@x2
[[ 5.]
 [14.]
 [23.]]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-12-c61849986c5c> in <module>
12 print("x1@x2\n",x1@x2)
13 print("")
---> 14 print("x2@x1 should not be possible\n",x2@x1,"\n"*3)
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 3 is different from 1)
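Once x2 is an explicit (3, 1) column, the shapes have to line up by hand: the inner dimensions of (3, 1) @ (3, 3) are 1 and 3, hence the error. One way to get the row-vector product from the question is to transpose x2 first. This is a minimal sketch reusing the question's arrays; the names col and row are only illustrative.

```python
import numpy as np

x1 = np.arange(9.0).reshape((3, 3))
x2 = np.arange(3.0).reshape(3, 1)  # explicit (3, 1) column vector

# (3, 3) @ (3, 1) -> (3, 1): inner dimensions 3 and 3 match
col = x1 @ x2

# (3, 1) @ (3, 3) is invalid: inner dimensions 1 and 3 differ
try:
    x2 @ x1
except ValueError as err:
    print("ValueError:", err)

# Transposing gives a (1, 3) row vector, so (1, 3) @ (3, 3) -> (1, 3)
row = x2.T @ x1

print(col.shape, row.shape)  # (3, 1) (1, 3)
print(row)                   # [[15. 18. 21.]]
```

The values match the 1-D x2@x1 result from the question; the only difference is that the explicit version keeps the row axis, giving shape (1, 3) instead of (3,).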