To visualize the gradient descent of my linear regression model, I'm trying to make a contour plot of the following MSE function:
```python
import jax.numpy as jnp
import numpy as np

def make_mse(x, t):
    def mse(w, b):
        return np.sum(jnp.power(x.dot(w) + b - t, 2)) / 2
    return mse
```
where the `x` and `y` axes of the plot correspond to the parameters `w` and `b`.
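Written out for scalar `w` and `b` and `N` data points, the function computes the expression below (strictly, half the sum of squared errors, since it divides by 2 rather than by `N`; this does not move the minimum). A contour plot needs this value at every pair $(W_{jk}, B_{jk})$ of the grid, so each grid point requires its own sum over the data:

$$
\mathrm{mse}(w, b) = \frac{1}{2}\sum_{i=1}^{N} \left(w\,x_i + b - t_i\right)^2,
\qquad
Z_{jk} = \mathrm{mse}\!\left(W_{jk}, B_{jk}\right)
$$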
The data `x` and targets `t` do not appear on the plot's axes, since each evaluation just multiplies the values of `x` by a single value of `w`.
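Because the function is written for scalar `w` and `b`, a straightforward (if slow) sketch is to call it once per grid point and store the results in an array shaped like the grid. This assumes a NumPy-only copy of the function (swapping `jnp.power` for `np.power`) and uses the same `x`, `t` and 50×50 grid as the attempt below; the names `w_grid` and `b_grid` are illustrative.

```python
import numpy as np

def make_mse(x, t):
    # NumPy-only version of the function above; expects scalar w and b
    def mse(w, b):
        return np.sum(np.power(x.dot(w) + b - t, 2)) / 2
    return mse

x = np.linspace(-1.0, 1.0, 500)
t = 5 * x + 1
w_grid, b_grid = np.meshgrid(np.linspace(-10.0, 10.0, 50),
                             np.linspace(-10.0, 10.0, 50))

mse = make_mse(x, t)
# Evaluate the scalar function at each (w, b) pair; Z has the grid's shape
Z = np.empty_like(w_grid)
for idx in np.ndindex(w_grid.shape):
    Z[idx] = mse(w_grid[idx], b_grid[idx])
```

This is correct but runs a Python loop over all 2,500 grid points, which is what a vectorized version avoids.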
I was trying to do the following:
```python
x = np.linspace(-1.0, 1.0, 500)
t = 5*x + 1
xcoord = np.linspace(-10.0, 10.0, 50)
ycoord = np.linspace(-10.0, 10.0, 50)
w1, w2 = np.meshgrid(xcoord, ycoord)
Z = make_mse(x, t)(w1, w2)
```
However, I get the expected error from the dot product:

```
/usr/local/lib/python3.7/dist-packages/jax/_src/lax/lax.py in dot(lhs, rhs, precision, preferred_element_type)
    634   else:
    635     raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
--> 636         lhs.shape, rhs.shape))
    637 
    638 
TypeError: Incompatible shapes for dot: got (1000, 1) and (50, 50).
```
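The error comes from the shapes: `x.dot(w)` tries to contract `x` against the whole grid array instead of multiplying each sample by one scalar `w`. (The traceback shows an `x` of shape `(1000, 1)` rather than the `(500,)` array defined above, but the mismatch is the same.) A small NumPy sketch of what shape is needed instead: give the grid a trailing axis for the data, so broadcasting produces one row of residuals per grid point.

```python
import numpy as np

x = np.linspace(-1.0, 1.0, 500)   # data, shape (500,)
t = 5 * x + 1                     # targets, shape (500,)
w_grid, b_grid = np.meshgrid(np.linspace(-10.0, 10.0, 50),
                             np.linspace(-10.0, 10.0, 50))  # each (50, 50)

# A (500,) vector cannot be dotted with a (50, 50) grid
try:
    x.dot(w_grid)
except ValueError as err:
    print("dot fails:", err)

# With a trailing axis, (50, 50, 1) broadcasts against (500,)
residuals = w_grid[..., None] * x + b_grid[..., None] - t
print(residuals.shape)  # (50, 50, 500): one row of residuals per grid point
```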
Is there an efficient, Pythonic way to make a contour plot of this function?
You don't want `np.sum()` here: it collapses everything into a single number, whereas you want the MSE for each grid point individually. Also, the length of `x` and the size of the grid must match. The following works:
```python
import numpy as np
import matplotlib.pyplot as plt

def make_mse(x, t):
    def mse(w, b):
        return np.power(x.dot(w) + b - t, 2)
    return mse

x = np.linspace(-1.0, 1.0, 500)
t = 5*x + 1
xcoord = np.linspace(-10.0, 10.0, 500)
ycoord = np.linspace(-10.0, 10.0, 500)
w1, w2 = np.meshgrid(xcoord, ycoord)
Z = make_mse(x, t)(w1, w2)
plt.contourf(w1, w2, Z)
plt.show()
```
This produces a filled contour plot of `Z` over the `w1`, `w2` grid.
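One caveat with the answer's snippet: because `w1` is a 500×500 array, `x.dot(w1)` sums `x` against each column of `w1` instead of multiplying each sample by one `w`. Since `x` is symmetric around zero that term is close to zero, so `Z[i, j]` is approximately `(w2[i, j] - t[j])**2` and does not vary with `w`. A sketch that keeps the question's loss (half the sum of squared errors over all samples) and evaluates it at every `(w, b)` grid point by broadcasting, summing only over the data axis, is below. The 50×50 grid, the 30 contour levels and the marker at the true parameters `(5, 1)` are illustrative choices, not from the original.

```python
import numpy as np
import matplotlib.pyplot as plt

def make_mse(x, t):
    def mse(w, b):
        # w and b may be scalars or arrays of matching shape (e.g. a meshgrid);
        # the trailing axis added here lines up with the data in x and t
        w = np.asarray(w)[..., None]
        b = np.asarray(b)[..., None]
        # Sum over the data axis only, giving one loss value per (w, b)
        return np.sum((w * x + b - t) ** 2, axis=-1) / 2
    return mse

x = np.linspace(-1.0, 1.0, 500)
t = 5 * x + 1
w1, w2 = np.meshgrid(np.linspace(-10.0, 10.0, 50),
                     np.linspace(-10.0, 10.0, 50))

Z = make_mse(x, t)(w1, w2)  # shape (50, 50), same as the grid

fig, ax = plt.subplots()
cs = ax.contourf(w1, w2, Z, levels=30)
fig.colorbar(cs, ax=ax)
ax.plot(5, 1, "r+")  # true parameters used to generate t
ax.set_xlabel("w")
ax.set_ylabel("b")
plt.show()
```

With this version the grid size is independent of the number of samples, and the same `mse` still accepts scalar `w` and `b`, so it can also be used to evaluate points along the gradient-descent path. The intermediate array has shape `(50, 50, 500)`; for much larger grids or datasets, the loop sketched earlier or a chunked evaluation uses less memory.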