
Contour plot of a mean squared error function in matplotlib

To visualize the gradient descent of my linear regression model, I'm trying to make a contour plot of the following MSE function:

import jax.numpy as jnp
import numpy as np

def make_mse(x, t):
  def mse(w, b):
    # half the sum of squared residuals of the linear model x*w + b
    return np.sum(jnp.power(jnp.dot(x, w) + b - t, 2))/2
  return mse

where the x and y axes of the plot correspond to the w and b parameters.

The data x and t are not relevant to the plot's axes, since every value of x is simply multiplied by the same scalar w.
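Assuming the linear model is w·x_n + b, the quantity the `mse` function above computes for one (w, b) pair is half the sum of squared residuals over the N data points (it is not divided by N, so strictly it is not a mean):

$$
E(w, b) = \frac{1}{2} \sum_{n=1}^{N} \left( w\,x_n + b - t_n \right)^2
$$

Each contour line is a set of (w, b) pairs with equal E, so every point of the (w, b) grid needs one scalar obtained by summing over all N data points.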

I was trying to do the following:

x = np.linspace(-1.0,1.0,500)
t = 5*x + 1

xcoord = np.linspace(-10.0,10.0,50)
ycoord = np.linspace(-10.0,10.0,50)
w1,w2 = np.meshgrid(xcoord,ycoord)

Z = make_mse(x, t)(w1,w2)

However, I get an obvious error from the dot product:

/usr/local/lib/python3.7/dist-packages/jax/_src/lax/ in dot(lhs, rhs, precision, preferred_element_type)
    634   else:
    635     raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
--> 636         lhs.shape, rhs.shape))

TypeError: Incompatible shapes for dot: got (1000, 1) and (50, 50).
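Because `mse` is written for a single scalar `w` and `b`, one simple (but slow, since it is a Python-level loop) workaround is to wrap it with `np.vectorize`, which calls it once per grid point. This is a NumPy-only sketch that assumes the model term is `np.dot(x, w)`, which for a scalar `w` is just `x * w`:

```python
import numpy as np

def make_mse(x, t):
    def mse(w, b):
        # w and b are scalars here; np.dot(x, w) with scalar w equals x * w.
        # Returns half the sum of squared residuals over all data points.
        return np.sum(np.power(np.dot(x, w) + b - t, 2)) / 2
    return mse

x = np.linspace(-1.0, 1.0, 500)
t = 5 * x + 1

xcoord = np.linspace(-10.0, 10.0, 50)
ycoord = np.linspace(-10.0, 10.0, 50)
w1, w2 = np.meshgrid(xcoord, ycoord)

# np.vectorize loops over the (50, 50) grid, calling mse(w, b) once per point
Z = np.vectorize(make_mse(x, t))(w1, w2)
```

This is fine for a 50 × 50 grid (2,500 calls) but scales poorly as the grid gets finer.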

Is there an efficient, Pythonic way to make a contour plot of this function?
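A vectorized sketch, assuming NumPy only and the model `w * x + b`: give the grid arrays a trailing axis so that broadcasting pairs every (w, b) grid point with every data point, then sum over that axis. Each grid point gets one scalar, and the grid size does not need to match `len(x)`. The name `make_mse_grid` is hypothetical.

```python
import numpy as np

def make_mse_grid(x, t):
    """Loss that accepts whole (w, b) grids, e.g. from np.meshgrid.

    Like the question's mse, it returns half the sum of squared residuals
    over all data points, but it broadcasts over grids of any shape.
    """
    def mse(w, b):
        W = np.asarray(w)[..., np.newaxis]  # grid shape + (1,)
        B = np.asarray(b)[..., np.newaxis]  # grid shape + (1,)
        residuals = W * x + B - t           # grid shape + (len(x),)
        return np.sum(residuals ** 2, axis=-1) / 2  # one scalar per grid point
    return mse

x = np.linspace(-1.0, 1.0, 500)
t = 5 * x + 1

xcoord = np.linspace(-10.0, 10.0, 50)
ycoord = np.linspace(-10.0, 10.0, 50)
w1, w2 = np.meshgrid(xcoord, ycoord)

Z = make_mse_grid(x, t)(w1, w2)  # shape (50, 50)
```

Passing `w1`, `w2` and `Z` to matplotlib's `plt.contour` (or `plt.contourf`) then draws the loss surface. The intermediate `residuals` array has shape (50, 50, 500), so memory grows with grid size times the number of data points.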


  • You don't need np.sum(), since you want the MSE for each grid point individually, not their sum. Also, the dimensions of x and the grid must match. The following works:

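    import numpy as np
    import matplotlib.pyplot as plt

    def make_mse(x, t):
      def mse(w, b):
        # squared residuals, without summing over the data points
        return np.power(np.dot(x, w) + b - t, 2)
      return mse

    x = np.linspace(-1.0, 1.0, 500)
    t = 5*x + 1
    # the grid has as many points per axis as x has samples
    xcoord = np.linspace(-10.0, 10.0, 500)
    ycoord = np.linspace(-10.0, 10.0, 500)
    w1, w2 = np.meshgrid(xcoord, ycoord)
    Z = make_mse(x, t)(w1, w2)

    plt.contour(w1, w2, Z)
    plt.show()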

    The resulting contour plot:

    [Figure: contour plot of Z over the (w, b) grid]
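If the grid or the data set is large, the (grid × N) intermediate can be avoided altogether. Expanding the square gives E(w, b) = ½(w²Σx² + 2wbΣx − 2wΣxt + Nb² − 2bΣt + Σt²), so a handful of precomputed sums over the data are enough to evaluate the loss on any grid. This is a sketch of that algebraic shortcut (a swapped-in technique, not taken from the answer above), using a hypothetical helper `sse_grid`; like the question's `mse`, it returns half the sum of squared errors for each grid point.

```python
import numpy as np

def sse_grid(x, t, w, b):
    """Half the sum of squared residuals of w * x + b - t, for every (w, b).

    w and b may be scalars or arrays of matching shape (e.g. from np.meshgrid).
    The data are reduced to five sums once, independent of the grid size.
    """
    n = x.size
    sx, sxx = x.sum(), (x * x).sum()
    st, stt, sxt = t.sum(), (t * t).sum(), (x * t).sum()
    return 0.5 * (w**2 * sxx + 2 * w * b * sx - 2 * w * sxt
                  + n * b**2 - 2 * b * st + stt)

x = np.linspace(-1.0, 1.0, 500)
t = 5 * x + 1
w1, w2 = np.meshgrid(np.linspace(-10.0, 10.0, 200),
                     np.linspace(-10.0, 10.0, 200))

# shape (200, 200), with no (200, 200, 500) temporary array
Z = sse_grid(x, t, w1, w2)
```

The trade-off is numerical: near the minimum the expanded form subtracts large, nearly equal terms, so it can lose a few digits of precision compared with summing squared residuals directly, which rarely matters for a contour plot.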