Creating a 3D Bar Chart in Python


I tried to replicate the following code to create 3D bar charts: https://pythonprogramming.net/3d-bar-chart-matplotlib-tutorial/

The idea behind the code below is to compute the likelihood ratio for flipping a possibly unfair quarter and penny (source: https://towardsdatascience.com/the-likelihood-ratio-test-463455b34de9).

I get the following error: ValueError: shape mismatch: objects cannot be broadcast to a single shape. Mismatch is between arg 0 with shape (5,) and arg 5 with shape (25,).

I don't understand why there is a mismatch, since I should have 25 bins from the two axes.

An explanation of how z3 and dz are used would also be greatly appreciated.

    from mpl_toolkits.mplot3d import axes3d
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib import style
    import itertools

    def Likelihood(d,p):
        L = 1
        for i in d:
            if i == 1:
                L = L * p
            else: 
                L = L * (1-p)
        return round(L,4)
    
    lst_prob =  [0, 0.25, 0.5, 0.75, 1]
    lsts_prob = [lst_prob, lst_prob]
    two_coin_matrix = list(itertools.product(*lsts_prob))
    
    quarter_flips = [0,0,1,0,0]
    penny_flips = [1,0,1,1,0]
    all_flips = quarter_flips + penny_flips
    all_flips
    
    z_coin_toss = []
    for probs in two_coin_matrix:
        z_coin_toss.append(round(10000*(Likelihood(quarter_flips, probs[0]) * Likelihood(penny_flips, probs[1]))))
        
    style.use('ggplot')
    
    fig = plt.figure()
    ax1 = fig.add_subplot(111, projection='3d')
    
    x3 = lst_prob
    y3 = lst_prob
    z3 = np.zeros(5)
    
    dx = np.ones(5)
    dy = np.ones(5)
    dz = z_coin_toss
    
    ax1.bar3d(x3, y3, z3, dx, dy, dz)
        
    ax1.set_xlabel('x axis')
    ax1.set_ylabel('y axis')
    ax1.set_zlabel('z axis')
    
    plt.show()

Solution

  • The error you're encountering is due to a mismatch in the lengths of the arrays you're passing to the bar3d function.

    All six arrays (x3, y3, z3, dx, dy, and dz) should have the same length. This is because each element in these arrays corresponds to a specific bar in the 3D bar chart.

    You want 25 bars arranged in a 5 x 5 grid, but you are passing only 5 values each for x, y, and z, while dz has 25. You can use NumPy's meshgrid to build the 5 x 5 coordinate grids and then flatten them to get 25-element arrays.
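As a quick check of the shapes involved, the sketch below (NumPy only, reusing lst_prob from the question) flattens the grid and compares it with the itertools.product ordering used to build z_coin_toss. It suggests that, with meshgrid's default indexing, x3 lines up with the second element of each pair (the penny) and y3 with the first (the quarter).

```python
import itertools

import numpy as np

lst_prob = [0, 0.25, 0.5, 0.75, 1]

# Same ordering as two_coin_matrix in the question:
# the first element of each pair varies slowest.
pairs = list(itertools.product(lst_prob, lst_prob))

# Default meshgrid indexing ('xy'): after flattening, x varies fastest.
x3, y3 = np.meshgrid(lst_prob, lst_prob)
x3 = x3.flatten()
y3 = y3.flatten()

print(x3.shape, y3.shape)  # both arrays now have 25 entries

# y3 follows the first element of each pair, x3 the second.
y_matches_first = np.array_equal(y3, [p[0] for p in pairs])
x_matches_second = np.array_equal(x3, [p[1] for p in pairs])
print(y_matches_first, x_matches_second)
```

If the quarter should sit on the x axis instead, passing indexing='ij' to meshgrid swaps the roles.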

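    As for z3 and dz: x3, y3, and z3 give the anchor corner of each bar, so z3 = 0 places every bar's base on the floor, while dx, dy, and dz are the bar's extent along each axis, so dz is the bar height. Because your grid spacing is 0.25, dx and dy should be 0.25 rather than 1; otherwise neighbouring bars overlap.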

    Here is the update:

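    # Build the full 5 x 5 grid of bar positions and flatten it to 25 entries.
    # With the default indexing, x3 varies fastest, so given the ordering of
    # z_coin_toss, x is the penny's probability and y is the quarter's.
    x3, y3 = np.meshgrid(lst_prob, lst_prob)
    x3 = x3.flatten()
    y3 = y3.flatten()
    z3 = np.zeros(25)  # every bar starts at z = 0
    
    # Bar footprint equals the grid spacing; dz holds the bar heights.
    dx = np.ones(25) * 0.25
    dy = np.ones(25) * 0.25
    dz = z_coin_toss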
    
    ax1.bar3d(x3, y3, z3, dx, dy, dz)
    
    ax1.set_xlabel('x axis')
    ax1.set_ylabel('y axis')
    ax1.set_zlabel('z axis')
    
    plt.show()
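
For reference, here is a self-contained sketch that puts the pieces together. It is not the original poster's exact code: it assumes indexing='ij' so that x is the quarter's probability and y the penny's, uses a shortened likelihood helper (a hypothetical rewrite of Likelihood), and plots the raw likelihood rather than the rounded value scaled by 10000.

```python
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on older Matplotlib)


def likelihood(flips, p):
    """Likelihood of a 0/1 flip sequence given P(heads) = p."""
    heads = sum(flips)
    tails = len(flips) - heads
    return p ** heads * (1 - p) ** tails


lst_prob = [0, 0.25, 0.5, 0.75, 1]
quarter_flips = [0, 0, 1, 0, 0]
penny_flips = [1, 0, 1, 1, 0]

# Assumption: indexing='ij' so the first argument varies slowest after
# flattening, giving x = quarter probability and y = penny probability.
x3, y3 = np.meshgrid(lst_prob, lst_prob, indexing="ij")
x3, y3 = x3.ravel(), y3.ravel()

# One bar height per (quarter, penny) probability pair.
dz = np.array([
    likelihood(quarter_flips, pq) * likelihood(penny_flips, pp)
    for pq, pp in zip(x3, y3)
])

z3 = np.zeros_like(dz)       # every bar starts at z = 0
dx = np.full_like(dz, 0.25)  # bar width along x (the grid spacing)
dy = np.full_like(dz, 0.25)  # bar depth along y

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.bar3d(x3, y3, z3, dx, dy, dz)
ax.set_xlabel("P(heads), quarter")
ax.set_ylabel("P(heads), penny")
ax.set_zlabel("likelihood")
```

Calling plt.show() afterwards displays the figure. All six arrays passed to bar3d have 25 entries, which is what removes the shape mismatch.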