python matplotlib matplotlib-3d

3D surface is messed up for some angles


I'm trying to plot a 3D graph with Matplotlib. When I render it, the surface looks really weird from some viewing angles but fine from others.

The equation for this particular graph is:

$$a(x,y) = -0.25(xy)^2 + 1.25xy^2 + y^2 - 0.25x^2y - 1.75xy + 2.5x^2 - 1.5x$$
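For reference, a(x, y) can be written as a vectorized NumPy function so it evaluates directly on a meshgrid. This is a minimal sketch; the function name `a` just mirrors the equation, and the spot-check values were worked out by hand from the formula above.

```python
import numpy as np

def a(x, y):
    """Evaluate a(x, y) element-wise; works on scalars or NumPy arrays."""
    return (-0.25 * (x * y) ** 2   # -0.25(xy)^2: the product xy is squared
            + 1.25 * x * y ** 2
            + y ** 2
            - 0.25 * x ** 2 * y
            - 1.75 * x * y
            + 2.5 * x ** 2
            - 1.5 * x)

# Spot checks at a few points
print(a(0.0, 0.0))  # 0.0
print(a(1.0, 1.0))  # 1.0
print(a(6.0, 6.0))  # -54.0
```

Because every operation is element-wise, `a(X, Y)` on a 50×50 meshgrid returns a 50×50 array of heights.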

Matplotlib has no problem rendering simpler equations such as x^2 + y^2 or x + y + 2xy. It usually handles them with or without a cmap.

[Image: x^2 + y^2 with cmap]

[Image: x^2 + y^2 without cmap]
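For comparison, here is a minimal sketch of the simple x^2 + y^2 case drawn both ways: with `cmap`, `plot_surface` colours each face by its average Z; without it, the surface gets a single shaded colour. The `'viridis'` colormap and the side-by-side layout are illustrative choices. It uses the non-interactive Agg backend so it runs headless; drop the `matplotlib.use` line and call `plt.show()` to view it interactively.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-6, 6, 50)
y = np.linspace(-6, 6, 50)
X, Y = np.meshgrid(x, y)
Z = X**2 + Y**2

fig = plt.figure(figsize=(10, 4))

ax1 = fig.add_subplot(1, 2, 1, projection='3d')
ax1.plot_surface(X, Y, Z, cmap='viridis')  # faces coloured by their Z value
ax1.set_title('x^2 + y^2 with cmap')

ax2 = fig.add_subplot(1, 2, 2, projection='3d')
ax2.plot_surface(X, Y, Z)  # single colour with shading
ax2.set_title('x^2 + y^2 without cmap')

fig.canvas.draw()  # render once; use plt.show() on an interactive backend
```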

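    import numpy as np
    import matplotlib.pyplot as plt

    def f(x, y):
        # a(x, y) from above; (x*y)**2 squares the product xy
        return (-0.25*(x*y)**2 + 1.25*x*y**2 + y**2
                - 0.25*x**2*y - 1.75*x*y + 2.5*x**2 - 1.5*x)

    x = np.linspace(-6, 6, 50)
    y = np.linspace(-6, 6, 50)
    X, Y = np.meshgrid(x, y)
    Z = f(X, Y)

    fig = plt.figure()
    ax = plt.axes(projection='3d')
    ax.plot_surface(X, Y, Z)
    # input_vals and output are created earlier in my code (not shown)
    ax.scatter3D(input_vals[:, 0], input_vals[:, 1], output[:], c=output[:])
    ax.set(xlim=(-6, 6), ylim=(-6, 6), zlim=(-6, 6))
    ax.set_aspect('equal', adjustable='box')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    plt.show()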

The first lines just define my function and set up the variables. I'm not sure where the problem is.

I tried removing rstride and cstride from my ax.plot_surface(X, Y, Z) call, but that didn't help; the graph still looked messy from some angles. The one thing that did help was removing cmap. With it, the graph was impossible to discern and just became a mess of whatever colormap I passed in.
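For context, `rstride`/`cstride` and `cmap` are keyword arguments of `plot_surface`: the strides downsample the grid (they are mutually exclusive with `rcount`/`ccount`), and `cmap` colours each face by its average Z. A minimal sketch passing them explicitly on a(x, y); the stride of 2, the `'viridis'` colormap and the added colorbar are illustrative choices, not taken from the question. The colorbar maps colours back to Z values, which can make a cmap plot easier to read.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line to view interactively
import matplotlib.pyplot as plt
import numpy as np

def a(x, y):
    return (-0.25 * (x * y) ** 2 + 1.25 * x * y ** 2 + y ** 2
            - 0.25 * x ** 2 * y - 1.75 * x * y + 2.5 * x ** 2 - 1.5 * x)

x = np.linspace(-6, 6, 50)
y = np.linspace(-6, 6, 50)
X, Y = np.meshgrid(x, y)
Z = a(X, Y)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# rstride/cstride=2: sample every 2nd row/column of the grid.
# cmap: colour each face by its average Z value.
surf = ax.plot_surface(X, Y, Z, rstride=2, cstride=2, cmap='viridis',
                       linewidth=0, antialiased=False)
fig.colorbar(surf, ax=ax, shrink=0.6, label='a(x, y)')  # colour -> Z legend
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
fig.canvas.draw()  # render once; use plt.show() on an interactive backend
```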

Most of the things I need to plot will be more complex than this, and they are mostly equations in rectangular (Cartesian) coordinates. Is there a fix for this, or should I try a different 3D plotting program?

[Image: a bad angle]

[Image: a good angle]

[Image: another good angle]

[Image: with cmap, with glitches]

[Image: with cmap, without glitches]

Still hard to discern.


Solution

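  • I think the issue was forcing the aspect ratio and the axis limits, so both are turned off now. The function reaches values far outside the forced zlim of (-6, 6) (for example a(6, 6) = -54), and mplot3d does not clip a surface to the axis limits, so the out-of-range parts are drawn outside the box and the projection looks broken from some angles.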

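    import numpy as np
    import matplotlib.pyplot as plt

    x = np.linspace(-6, 6, 50)
    y = np.linspace(-6, 6, 50)
    X, Y = np.meshgrid(x, y)

    def f(x, y):
        # -0.25*(x*y)**2 matches the -0.25(xy)^2 term in the question
        return (-0.25*(x*y)**2 + 1.25*x*y**2 + y**2
                - 0.25*x**2*y - 1.75*x*y + 2.5*x**2 - 1.5*x)

    Z = f(X, Y)

    fig = plt.figure()
    ax = plt.axes(projection='3d')
    ax.plot_wireframe(X, Y, Z, linewidth=1)
    # Other plot types tried:
    # ax.plot_surface(X, Y, Z)
    # ax.scatter3D(X, Y, Z, c=Z)
    # No set_aspect and no fixed xlim/ylim/zlim: Matplotlib autoscales to the data
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    plt.show()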
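To see why the forced limits matter, the sketch below compares the data range of Z with the zlim=(-6, 6) window from the question, then lets the z limits follow the data. As an assumption, `ax.set_box_aspect((1, 1, 1))` is swapped in for `set_aspect('equal')`: it keeps a cube-shaped box without requiring equal data units on all three axes, which would otherwise squash x and y once Z spans a much larger range than ±6. This is a sketch under those assumptions, not a definitive fix.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line to view interactively
import matplotlib.pyplot as plt
import numpy as np

def f(x, y):
    return (-0.25 * (x * y) ** 2 + 1.25 * x * y ** 2 + y ** 2
            - 0.25 * x ** 2 * y - 1.75 * x * y + 2.5 * x ** 2 - 1.5 * x)

x = np.linspace(-6, 6, 50)
y = np.linspace(-6, 6, 50)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)

# The question forced zlim=(-6, 6); compare that with the actual data range.
print(f"Z spans {Z.min():.1f} to {Z.max():.1f}")

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.plot_surface(X, Y, Z, cmap='viridis')
# Let the z limits follow the data instead of a fixed (-6, 6) window.
ax.set(xlim=(-6, 6), ylim=(-6, 6), zlim=(Z.min(), Z.max()))
# Cube-shaped box without forcing equal data units on x, y and z.
ax.set_box_aspect((1, 1, 1))
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
fig.canvas.draw()  # render once; use plt.show() on an interactive backend
```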