Tags: python, matplotlib, subplot, scatter3d

Simplifying a Matplotlib scatter3D, created in four subplots


Is there a way to simplify the following Python code? The only difference between the four subplots is the viewing angle, which is set with the view_init() function.

import matplotlib.pyplot as plt

def plot_example(data):
    fig = plt.figure(figsize=[15,15])

    #Create subplots

    ax1 = fig.add_subplot(2,2,1, projection='3d')
    ax1.scatter3D(data.x, data.y, data.z, c=data.z, cmap='Blues')
    ax1.view_init(0,90)
    ax1.set_xlabel('x', color ='red')
    ax1.set_ylabel('y', color ='red')
    ax1.set_zlabel('z', color ='red')
    ax1.set_xlim(0, 14)
    ax1.set_ylim(-6, 6)
    ax1.set_zlim(0, 8.5)

    ax2 = fig.add_subplot(2,2,2, projection='3d')
    ax2.scatter3D(data.x, data.y, data.z, c=data.z, cmap='Blues')
    ax2.view_init(45,0)
    ax2.set_xlabel('x', color ='red')
    ax2.set_ylabel('y', color ='red')
    ax2.set_zlabel('z', color ='red')
    ax2.set_xlim(0, 14)
    ax2.set_ylim(-6, 6)
    ax2.set_zlim(0, 8.5)

    ax3 = fig.add_subplot(2,2,3, projection='3d')
    ax3.scatter3D(data.x, data.y, data.z, c=data.z, cmap='Blues')
    ax3.view_init(35,45)
    ax3.set_xlabel('x', color ='red')
    ax3.set_ylabel('y', color ='red')
    ax3.set_zlabel('z', color ='red')
    ax3.set_xlim(0, 14)
    ax3.set_ylim(-6, 6)
    ax3.set_zlim(0, 8.5)
    
    ax4 = fig.add_subplot(2,2,4, projection='3d')
    ax4.scatter3D(data.x, data.y, data.z, c=data.z, cmap='Blues')
    ax4.view_init(20,40)
    ax4.set_xlabel('x', color ='red')
    ax4.set_ylabel('y', color ='red')
    ax4.set_zlabel('z', color ='red')
    ax4.set_xlim(0, 14)
    ax4.set_ylim(-6, 6)
    ax4.set_zlim(0, 8.5)

I tried this:

import matplotlib.pyplot as plt

def plot_example(data_df):
    fig = plt.figure(figsize=[15, 15])

    # Create subplots and plot data
    for i, ax in enumerate([fig.add_subplot(2, 2, i + 1, projection='3d') for i in range(4)]):
        ax.scatter3D(data_df.x, data_df.y, data_df.z, c=data_df.z, cmap='Blues')
        ax.view_init(*[30 * i, 30 * (i + 1)])  # Set different viewing angles
        ax.set_xlabel('x', color='red')
        ax.set_ylabel('y', color='red')
        ax.set_zlabel('z', color='red')

    # Set axis limits for all subplots
    plt.axis([0, 14, -6, 6, 0, 8])

    plt.show()

But I got the error "TypeError: the first argument to axis() must be an iterable of the form [xmin, xmax, ymin, ymax]".


Solution

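  • First of all, as the error message says, plt.axis() expects exactly four values, [xmin, xmax, ymin, ymax], not the six in your code. It also only changes the current Axes (here, the last subplot created), so it would not apply the limits to all four subplots anyway. Since the x, y and z limits are the same for every subplot, set them inside the for loop with ax.set_xlim(), ax.set_ylim() and ax.set_zlim().

    In short, plt.axis() (at least in the Matplotlib version that raised this error) only handles the x and y limits; the z limits of a 3D subplot need set_zlim().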

    Your code should look something like this:

    import matplotlib.pyplot as plt
    
    def plot3Ddata(data_df):
        fig = plt.figure(figsize=[15, 15])

        # (elevation, azimuth) for each subplot, taken from your original code
        views = [(0, 90), (45, 0), (35, 45), (20, 40)]

        # Create one 3D subplot per view and apply the shared settings
        for i, (elev, azim) in enumerate(views, start=1):
            ax = fig.add_subplot(2, 2, i, projection='3d')
            ax.scatter3D(data_df.x, data_df.y, data_df.z, c=data_df.z, cmap='Blues')
            ax.view_init(elev, azim)  # the only per-subplot difference
            ax.set_xlabel('x', color='red')
            ax.set_ylabel('y', color='red')
            ax.set_zlabel('z', color='red')
            # Axis limits are identical for every subplot
            ax.set_xlim(0, 14)
            ax.set_ylim(-6, 6)
            ax.set_zlim(0, 8.5)

        plt.show()
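An alternative sketch, not part of the original answer: plt.subplots() can create all four 3D Axes in one call through subplot_kw={'projection': '3d'}, and zip() pairs each Axes with its (elev, azim) view, so the viewing angles live in a single list. It assumes Matplotlib 3.2 or later (where projection='3d' works without importing Axes3D) and pandas for the sample data. plot_views, VIEWS and the sample DataFrame are hypothetical names. The Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove it to show a window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# (elevation, azimuth) pairs copied from the four view_init() calls in the question
VIEWS = [(0, 90), (45, 0), (35, 45), (20, 40)]


def plot_views(data_df, views=VIEWS):
    """Draw the same 3D scatter once per (elev, azim) pair in a 2x2 grid."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 15),
                             subplot_kw={'projection': '3d'})
    # axes.flat walks the grid row by row, like add_subplot(2, 2, 1) ... (2, 2, 4)
    for ax, (elev, azim) in zip(axes.flat, views):
        ax.scatter3D(data_df.x, data_df.y, data_df.z, c=data_df.z, cmap='Blues')
        ax.view_init(elev, azim)  # the only per-subplot difference
        # Shared limits in one call; Axes.set() forwards them to set_xlim/set_ylim/set_zlim
        ax.set(xlim=(0, 14), ylim=(-6, 6), zlim=(0, 8.5))
        ax.set_xlabel('x', color='red')
        ax.set_ylabel('y', color='red')
        ax.set_zlabel('z', color='red')
    return fig, axes


# Hypothetical sample data standing in for the real DataFrame
sample = pd.DataFrame({'x': np.linspace(0, 14, 50),
                       'y': np.linspace(-6, 6, 50),
                       'z': np.linspace(0, 8.5, 50)})
fig, axes = plot_views(sample)
# With an interactive backend, call plt.show() here
```

Axes.set() is used only to keep the limits on one line; the three explicit set_xlim(), set_ylim() and set_zlim() calls from the answer behave the same way.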