Tags: python, python-3.x, matplotlib, mplcursors

Animate scatter plot according to year in a XLS file


I'm building a simple scatter plot (life expectancy vs. GDP per capita) that reads its data from an XLS file. Here's the code:

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# read a sheet of the spreadsheet (sheet_name=0 selects the first sheet)
data = pd.read_excel('sample.xls', sheet_name=0)
data.head()

plt.scatter(x = data['LifeExpec'],
        y = data['GDPperCapita'],
        s = data['PopX1000'],
        c = data['PopX1000'],
        cmap=cm.viridis,
        edgecolors = 'none',
        alpha = 0.7)

for state in range(len(data['State'])):
    plt.text(x = data['LifeExpec'][state],
         y = data['GDPperCapita'][state],
         s = data['State'][state],
         fontsize = 14)

plt.colorbar()
plt.show()

The XLS file: [screenshot not included]

The plot: [image not included]

Now I want to add data from other years to this XLS file and animate the bubbles so that they move and change size according to the GDP and population numbers of each year. In a silly attempt to do so, I've changed the code to this:

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import mplcursors
from matplotlib.animation import FuncAnimation

data = pd.read_excel('sample.xls', sheet_name=0)
data.head()
uniqueYears = data['Year'].unique()

fig, ax = plt.subplots()

def animate(i):
    for i in uniqueYears:
        ax.scatter(x = data['lifeExpec'],
            y = data['GDPperCapita'],
            s = data['PopX1000']/4,
            c = data['Region'].astype('category').cat.codes,
            cmap=cm.viridis,
            edgecolors = 'none',
            alpha = 0.7)

anim = FuncAnimation(fig, animate)

for state in range(len(data['State'])):
    plt.text(x = data['lifeExpec'][state],
             y = data['GDPperCapita'][state],
             s = data['State'][state],
             fontsize = 10,
             ha = 'center',
             va = 'center')

mplcursors.cursor(hover=True)
plt.draw()
plt.show()

I thought the way to do this might be to use the animate function to build the chart multiple times, one iteration per year. But I couldn't figure out how to filter the rows that belong to a specific year.

Am I way off? Is it even possible to achieve this using matplotlib?


Solution

  • Using what Stylianos Nikas and ImportanceOfBeingErnest said as a starting point, I made a list of the unique years in the dataframe and passed its length as the frames parameter of FuncAnimation, like this:

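    def animate(frames):
        # Remove the previous frame's bubbles before drawing the next year
        ax.clear()
        # df is the full DataFrame; 'Ano' (Portuguese for 'Year') is its year column
        data = df[df['Ano'] == uniqueYears[frames]]
        ax.scatter(y = data['lifeExpec'],
                   x = data['GDPperCapita'],
                   s = data['PopX1000']/40000,
                   c = data['Region'].astype('category').cat.codes,
                   cmap = cm.viridis,
                   edgecolors = 'none',
                   alpha = 0.5)

    # One frame per unique year, 200 ms apart, played once
    anim = FuncAnimation(fig, animate, frames = len(uniqueYears), interval = 200, repeat = False)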
    

    To avoid the overlap of frames I simply added ax.clear() to the beginning of the animate function.
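
For reference, here is a self-contained sketch of the same approach. The DataFrame is made-up data standing in for the spreadsheet, and the column names ('Year', 'LifeExpec', 'GDPperCapita', ...) follow the question's first snippet, which is an assumption about the real file. Two things here are additions rather than part of the answer above: the axis limits and title are set again on every frame, because ax.clear() resets them, and regions are mapped to fixed integers so that a region should keep its colour even if it is missing in some year.

```python
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.animation import FuncAnimation

# Made-up stand-in for the spreadsheet (an assumption). With the real file
# this would be: df = pd.read_excel('sample.xls', sheet_name=0)
df = pd.DataFrame({
    "Year":         [2000, 2000, 2000, 2010, 2010, 2010],
    "State":        ["A", "B", "C", "A", "B", "C"],
    "Region":       ["North", "South", "South", "North", "South", "South"],
    "LifeExpec":    [68.0, 71.5, 70.0, 72.0, 75.0, 73.5],
    "GDPperCapita": [9000, 15000, 12000, 14000, 22000, 18000],
    "PopX1000":     [3000, 8000, 5000, 3500, 9000, 5600],
})

unique_years = sorted(df["Year"].unique())
# Fixed region -> integer mapping, computed once over the whole table
region_codes = {r: i for i, r in enumerate(sorted(df["Region"].unique()))}

fig, ax = plt.subplots()

def animate(frame):
    """Redraw the chart using only the rows of one year."""
    ax.clear()  # drop the previous year's bubbles and labels
    year = unique_years[frame]
    year_data = df[df["Year"] == year]  # boolean mask keeps a single year
    ax.scatter(
        x=year_data["LifeExpec"],
        y=year_data["GDPperCapita"],
        s=year_data["PopX1000"] / 10,  # arbitrary scale factor for bubble size
        c=year_data["Region"].map(region_codes),
        cmap="viridis",
        vmin=0, vmax=len(region_codes) - 1,
        edgecolors="none",
        alpha=0.7,
    )
    # Label each bubble with its state name
    for row in year_data.itertuples():
        ax.text(row.LifeExpec, row.GDPperCapita, row.State,
                ha="center", va="center")
    # ax.clear() also resets limits and title, so set them on every frame
    ax.set_xlim(df["LifeExpec"].min() - 2, df["LifeExpec"].max() + 2)
    ax.set_ylim(0, df["GDPperCapita"].max() * 1.1)
    ax.set_title(f"Year {year}")

# One frame per unique year, as in the answer above
anim = FuncAnimation(fig, animate, frames=len(unique_years),
                     interval=200, repeat=False)
```

Calling plt.show() afterwards plays the animation in an interactive window. To write it to a file instead, anim.save('bubbles.gif', writer='pillow') should work, assuming Pillow is installed.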