
How to plot 3d scatter of population density in many countries over the years?


I have a data set with 160 countries and the population density of each country over a period of 12 years. I want to plot it as a 3D scatter plot, but I get this error:

  • ValueError: too many values to unpack (expected 3)

I have created three lists - for "year", "country name", "population density" but it seems I cannot get it right. This is a sample of the data set:

[image: sample of the data set]

This is my code:

g1 = population_density["year"]
g2 = population_density["country_name"]
g3 = population_density["population_density_(people per sq. km of land area)"]

data = (g1, g2, g3)
colors= list(np.random.choice(range(256), size=160))
groups = ("year", "population density per sq.km", "countries") 

# Create plot
fig = plt.figure(figsize = (10,8))
#ax = Axes3D(fig)
ax = fig.add_subplot(111, projection='3d')
#ax = fig.gca(projection='3d')

for data, color, group in zip(data, colors, groups):
    x, y, z = data
    ax.scatter(x, y, z, alpha=0.8, c=color, edgecolors='none', s=30, label=group)

plt.title('Population Density Over The Years')
plt.legend(loc=2)
plt.show()

In the end, I want to have the scatter plots for all years in this one 3D plot. Please help!


Solution

  • Replace ax.scatter(x, y, z, alpha=0.8, c=color, edgecolors='none', s=30, label=group) with ax.scatter(g1, g2, g3, alpha=0.8, c=color, edgecolors='none', s=30, label=group)

    That is, pass g1, g2 and g3 in place of x, y and z. According to the Matplotlib documentation for 3D scatter plots, the arguments can be array-like, so whole columns can be passed in a single call. The error comes from the for loop: zip(data, colors, groups) hands you one entire column at a time, and x, y, z = data then tries to unpack that column, which holds far more than three values.

    (Edit) After looking at your dataset: the values on your x and y axes are categorical, but a 3D scatter plot needs numeric Cartesian coordinates. What you can do is plot numeric positions instead and then set the xticks and yticks so the axes still show the original labels.

    Something like this should work:

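    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # registers the 3D projection on older Matplotlib versions
    import numpy as np
    
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    
    # population_density is the DataFrame from the question
    g1 = population_density["year"]
    g2 = population_density["country_name"]
    g3 = population_density["population_density_(people per sq. km of land area)"]
    
    # Map each country name to an integer position so it can serve as a y coordinate
    countries, y_codes = np.unique(g2, return_inverse=True)
    
    # One random color value per row, mapped through the default colormap
    colors = np.random.choice(range(256), size=len(g1))
    
    ax.scatter(g1, y_codes, g3, alpha=0.8, c=colors, edgecolors='none', s=30)
    
    # Ticks: one per year (assumes "year" is numeric) and one per country;
    # the z axis is numeric and keeps its default ticks
    ax.set_xticks(sorted(g1.unique()))
    ax.set_yticks(range(len(countries)))
    ax.set_yticklabels(countries, fontsize=5)  # 160 labels will still be crowded
    
    ax.set_xlabel('year')
    ax.set_ylabel('countries')
    ax.set_zlabel('population density per sq.km')
    
    plt.title('Population Density Over The Years')
    plt.show()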
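
To see where the "too many values to unpack" error comes from, here is a minimal reproduction without any plotting. The three short lists are made-up stand-ins for the DataFrame columns in the question, and try_unpack is a hypothetical helper used only for this illustration.

```python
# Made-up stand-ins for the three DataFrame columns in the question
years = [2000, 2001, 2002, 2003]
countries = ["A", "A", "B", "B"]
density = [1.0, 2.0, 3.0, 4.0]

data = (years, countries, density)
colors = [10, 20, 30]
groups = ("year", "population density per sq.km", "countries")


def try_unpack(column):
    """Attempt the question's `x, y, z = data` on a single column."""
    try:
        x, y, z = column
        return "ok"
    except ValueError as exc:
        return str(exc)


# zip() pairs up whole columns, so each `column` holds one value per row,
# not one (x, y, z) triple
results = [try_unpack(column) for column, color, group in zip(data, colors, groups)]
print(results[0])  # a "too many values to unpack" message
```

Each column here has four values, so unpacking it into three names fails in the same way as the real columns with one value per row.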
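
If you want each year to stand out in the same 3D plot, one option is a scatter call per year. The sketch below is self-contained and uses synthetic data (three invented countries, four years, random densities) rather than the real dataset, so all the values and names in it are assumptions; only the encoding and plotting steps carry over.

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic stand-in for the question's data: 3 invented countries x 4 years
rng = np.random.default_rng(0)
years = np.tile(np.arange(2000, 2004), 3)
country_names = np.repeat(["A", "B", "C"], 4)
density = rng.uniform(10, 500, size=years.size)

# Encode each country name as an integer position on the y axis
unique_countries, country_codes = np.unique(country_names, return_inverse=True)

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")

# One scatter call per year, so each year gets its own color and legend entry
for year in np.unique(years):
    mask = years == year
    ax.scatter(years[mask], country_codes[mask], density[mask],
               alpha=0.8, s=30, label=str(year))

# Show years and country names on the axes instead of raw positions
ax.set_xticks(np.unique(years))
ax.set_yticks(np.arange(len(unique_countries)))
ax.set_yticklabels(unique_countries)

ax.set_xlabel("year")
ax.set_ylabel("countries")
ax.set_zlabel("population density per sq.km")
ax.set_title("Population Density Over The Years")
ax.legend(loc=2)

# plt.show()  # uncomment to display the figure interactively
```

With the real data you would replace the three synthetic arrays with the "year", "country_name" and density columns converted to NumPy arrays, for example with .to_numpy().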