Tags: python, matplotlib, linear-regression, training-data

How to draw a perpendicular line from each point of a training set to the surface plane in matplotlib


I'm using matplotlib. I have the following linear regression model, with a surface plane and a training data set.

I need to draw the orthogonal distance from each data point to the surface plane, so that the result looks similar to this: [example image not available]

Here is the code snippet that I have:

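import numpy as np
import matplotlib.pyplot as plt

# normal_equation(x, y) (the fitted model) and the training arrays X1, X2, Y
# are defined elsewhere and not shown here.
nx, ny = (100, 100)

x1 = np.linspace(-3, 10.0, nx)
x2 = np.linspace(0, 15.0, ny)

x_plane, y_plane = np.meshgrid(x1, x2)

# One (x, y) row per grid node, then evaluate the fitted plane at every node.
XY = np.stack((x_plane.ravel(), y_plane.ravel()), axis=1)

z_plane = np.array([normal_equation(x, y) for x, y in XY]).reshape(x_plane.shape)

fig = plt.figure(figsize=(10, 8))
# Recent Matplotlib versions no longer accept fig.gca(projection='3d').
ax = fig.add_subplot(projection='3d')

ax.scatter(X2, X1, Y, color='r')
ax.plot_surface(x_plane, y_plane, z_plane, color='b', alpha=0.4)
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.set_zlabel('y')
ax.set_zlim(-10, 5)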
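The snippet depends on `normal_equation`, `X1`, `X2` and `Y`, which are not shown. Purely so the code can be tried out, here is a hypothetical stand-in (not the asker's model or data): it draws synthetic points around an arbitrary plane, fits the plane by solving the normal equation (AᵀA)θ = AᵀY, and wraps the fit in a `normal_equation(x, y)` function that returns the fitted height. The column order follows the question's `ax.scatter(X2, X1, Y)` call, so `X2` is plotted on the horizontal x axis. Define these before running the snippet above.

```python
import numpy as np

# Hypothetical stand-in data: NOT the asker's dataset.
rng = np.random.default_rng(0)
n_points = 20
X2 = rng.uniform(-3, 10, n_points)  # plotted on the x axis, as in ax.scatter(X2, X1, Y)
X1 = rng.uniform(0, 15, n_points)   # plotted on the y axis
Y = 1.5 - 0.4 * X2 - 0.3 * X1 + rng.normal(0, 0.5, n_points)

# Design matrix with a bias column; solve the normal equation (A^T A) theta = A^T Y.
A = np.column_stack([np.ones(n_points), X2, X1])
theta = np.linalg.solve(A.T @ A, A.T @ Y)


def normal_equation(x, y):
    """Hypothetical helper: fitted height of the plane at plot coordinates (x, y)."""
    return theta[0] + theta[1] * x + theta[2] * y
```

Any model whose prediction is affine in its two inputs works the same way with the answer below.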

Any help would be much appreciated.


Solution

  • Some simple mathematical facts we can use to solve this problem:

    • The cross product of two non-parallel vectors lying in the plane is a vector perpendicular to the plane.
    • The dot product of a vector with a unit vector gives the signed length of that vector's projection onto the unit vector's direction.
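A quick numeric check of both facts, using the example plane z = 1 + 2x - y, chosen purely for illustration (it is not the question's model):

```python
import numpy as np

# Two non-parallel vectors lying in the example plane z = 1 + 2x - y.
u = np.array([1.0, 0.0, 2.0])   # a step of +1 in x raises z by 2
v = np.array([0.0, 1.0, -1.0])  # a step of +1 in y lowers z by 1
n = np.cross(u, v)              # perpendicular to both u and v
unit_n = n / np.linalg.norm(n)

# Fact 1: the cross product is orthogonal to both in-plane vectors.
print(np.dot(n, u), np.dot(n, v))

# Fact 2: the dot product with a unit vector is a signed projection length,
# here the signed distance from p to the plane.
p = np.array([1.0, 1.0, 5.0])   # a point off the plane
p0 = np.array([0.0, 0.0, 1.0])  # a point on the plane (x = 0, y = 0 gives z = 1)
signed_distance = np.dot(p - p0, unit_n)
foot = p - signed_distance * unit_n  # closest point on the plane
print(signed_distance, foot)
```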

    First, we can find a vector perpendicular to the plane using the following code:

    perpendicular = np.cross(
      (0, 1, normal_equation(0, 1) - normal_equation(0, 0)),
      (1, 0, normal_equation(1, 0) - normal_equation(0, 0))
    )
    normal = perpendicular / np.linalg.norm(perpendicular)
    

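    (Note: this assumes that normal_equation is linear in its inputs, so its values at (0, 0), (1, 0) and (0, 1) pin down the whole plane, and that the plane is not vertical. A fitted linear regression plane is never vertical, because it predicts exactly one y for each input.)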
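Under the assumption that normal_equation returns f(x, y) = a + b x + c y, so that b = f(1, 0) - f(0, 0) and c = f(0, 1) - f(0, 0), the cross product above can be worked out by hand. Here p = (p_x, p_y, p_z) is a data point and q its closest point on the plane:

```latex
\begin{aligned}
\mathbf{u} &= (0,\,1,\,c), \qquad \mathbf{v} = (1,\,0,\,b),\\
\mathbf{u}\times\mathbf{v} &= (b,\; c,\; -1), \qquad
\hat{\mathbf{n}} = \frac{(b,\,c,\,-1)}{\sqrt{1+b^2+c^2}},\\
d(\mathbf{p}) &= (\mathbf{p}-\mathbf{p}_0)\cdot\hat{\mathbf{n}}
  = \frac{a + b\,p_x + c\,p_y - p_z}{\sqrt{1+b^2+c^2}},
  \qquad \mathbf{p}_0 = (0,\,0,\,a),\\
\mathbf{q} &= \mathbf{p} - d(\mathbf{p})\,\hat{\mathbf{n}}.
\end{aligned}
```

The z component of the normal is always -1, which is another way of seeing that such a plane can never be vertical; |d(p)| is the length of the segment drawn for each point.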

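    Second, move each data point along this unit normal until it reaches the plane. The dot product of (point - plane_point) with the normal is the point's signed distance from the plane, so subtracting that many normals gives the closest point on the plane: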

    plane_point = np.array([0, 0, normal_equation(0, 0)])
    dot_prods = [
      np.dot(np.array(u) - plane_point, normal)
      for u in zip(X2, X1, Y)
    ]
    closest_points = [
      np.array([X2[i], X1[i], Y[i]]) - normal * dot_prods[i]
      for i in range(len(Y))
    ]
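The two list comprehensions can also be written as one array operation. A minimal sketch; the helper name `project_onto_plane` is hypothetical, and the demo plane z = 1 + 2x - y (normal (2, -1, -1)) is only an illustration:

```python
import numpy as np


def project_onto_plane(points, plane_point, normal):
    """Hypothetical helper: orthogonal projection of each row of `points` (shape (N, 3))
    onto the plane through `plane_point` with normal vector `normal`.
    Returns the projected points and the signed distance of each original point."""
    unit = normal / np.linalg.norm(normal)       # make sure the normal has length 1
    signed_dist = (points - plane_point) @ unit  # shape (N,), one distance per point
    return points - signed_dist[:, None] * unit, signed_dist


# Quick check on the example plane z = 1 + 2x - y:
demo_points = np.array([[1.0, 1.0, 5.0], [0.0, 0.0, 0.0]])
demo_proj, demo_dist = project_onto_plane(
    demo_points, np.array([0.0, 0.0, 1.0]), np.array([2.0, -1.0, -1.0])
)
print(demo_proj, demo_dist)
```

With the question's arrays, `project_onto_plane(np.column_stack([X2, X1, Y]), plane_point, normal)` should give the same points as `closest_points`, in the scatter call's (X2, X1, Y) column order.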
    

    Finally, draw a black line segment from each data point to its closest point on the plane:

    for i in range(len(Y)):
      ax.plot(
        [closest_points[i][0], X2[i]],
        [closest_points[i][1], X1[i]],
        [closest_points[i][2], Y[i]],
        'k-'
      )
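Putting the pieces together, here is a self-contained sketch that runs without the question's data. Everything about the data is an assumption: synthetic points around an arbitrary plane, fitted with `np.linalg.lstsq` (which gives the same coefficients as solving the normal equation when the design matrix has full column rank). It uses the closed-form normal (b, c, -1) instead of the cross product, which is equivalent for an affine plane, and `ax.set_box_aspect` (Matplotlib 3.3 or later) to give all three axes the same data scale. Without equal scaling, segments that are perpendicular in data space generally do not look perpendicular on screen.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the script also runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-in data (assumption): points scattered around z = 1.5 - 0.4 x - 0.3 y.
rng = np.random.default_rng(1)
n_points = 15
xs = rng.uniform(-3, 10, n_points)
ys = rng.uniform(0, 15, n_points)
zs = 1.5 - 0.4 * xs - 0.3 * ys + rng.normal(0, 1.0, n_points)

# Least-squares fit z ~ a + b x + c y (same solution as the normal equation).
A = np.column_stack([np.ones(n_points), xs, ys])
(a, b, c), *_ = np.linalg.lstsq(A, zs, rcond=None)

# Unit normal (b, c, -1) / |(b, c, -1)| and orthogonal projection of every point.
normal = np.array([b, c, -1.0])
normal /= np.linalg.norm(normal)
points = np.column_stack([xs, ys, zs])
signed_dist = (points - np.array([0.0, 0.0, a])) @ normal
closest = points - signed_dist[:, None] * normal

# Plane surface on a grid, as in the question.
x_plane, y_plane = np.meshgrid(np.linspace(-3, 10, 50), np.linspace(0, 15, 50))
z_plane = a + b * x_plane + c * y_plane

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(projection="3d")
ax.plot_surface(x_plane, y_plane, z_plane, color="b", alpha=0.4)
ax.scatter(xs, ys, zs, color="r")
for p, q in zip(points, closest):
    # One black segment from each data point to its foot on the plane.
    ax.plot([p[0], q[0]], [p[1], q[1]], [p[2], q[2]], "k-")
ax.set_xlabel("x1")
ax.set_ylabel("x2")
ax.set_zlabel("y")
# Equal data scaling on all axes, so right angles in data space also look like right angles.
ax.set_box_aspect([hi - lo for lo, hi in (ax.get_xlim(), ax.get_ylim(), ax.get_zlim())])
fig.savefig("perpendiculars.png")
```

The `matplotlib.use("Agg")` line only makes the script run headless; drop it and replace `fig.savefig(...)` with `plt.show()` to get an interactive window.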
    

    I hope this helps!