Tags: python, pandas, matplotlib, plot, multi-index

How to create a plot for a multiindex dataframe


I need to plot total sales against year of release (y = 'total_sales_sum', x = 'year_of_release') for each gaming platform. To get the yearly totals I used a pivot table, which gave me a MultiIndex dataframe.

data_recent_decade=data.query('year_of_release>=2006').pivot_table(index=['platform','year_of_release'],values=['total_sales'], aggfunc=['sum'])
data_recent_decade.columns=['total_sales_sum']
data_recent_decade.info()
for platform in data_recent_decade:
    data_recent_decade.plot(y='total_sales_sum', marker='o',grid=True,figsize=(13,4))
    plt.title(platform)
    plt.show()

This is the final dataframe:

[Image: dataframe with multiindex]

This is the output of data_recent_decade.info():

<class 'pandas.core.frame.DataFrame'>
MultiIndex: 101 entries, (3DS, 2011.0) to (XOne, 2016.0)
Data columns (total 1 columns):
total_sales_sum    101 non-null float64
dtypes: float64(1)
memory usage: 1.4+ KB

My broken plot:

[Image: the broken graph I get]

How can I make a separate plot for each platform?


Solution

  • You can loop over the first level of a pandas.MultiIndex (here, the platform) with:

    for platform, new_df in df.groupby(level = 0):
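
To make the difference concrete, here is a minimal sketch comparing what the loop in the question iterates over with what grouping on index level 0 yields. The four rows are made-up illustration values, not the asker's data, and the names `column_labels` and `groups` are introduced only for this example.

```python
import pandas as pd

# Small made-up frame with the same index layout as in the question.
df = pd.DataFrame(
    {"platform": ["3DS", "3DS", "XOne", "XOne"],
     "year_of_release": [2011, 2012, 2013, 2014],
     "total_sales_sum": [60.53, 51.01, 18.96, 54.07]}
).set_index(["platform", "year_of_release"])

# Iterating over a DataFrame yields its column labels, not its index values.
column_labels = list(df)

# Grouping on index level 0 yields one (platform, sub-frame) pair per platform.
groups = {platform: sub_df for platform, sub_df in df.groupby(level=0)}

print(column_labels)
print(list(groups))
print(groups["3DS"].index.get_level_values("year_of_release").tolist())
```

This would explain the graph in the question: `for platform in data_recent_decade` runs once, with `platform` set to the column name, and each pass plots the whole frame rather than a single platform.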
    

    Complete Code

    import pandas as pd
    import matplotlib.pyplot as plt
    
    df = pd.DataFrame({'platform': ['3DS', '3DS', '3DS', '3DS', '3DS', 'XB', 'XBOne', 'XBOne', 'XBOne', 'XBOne'],
                       'year_of_release': [2011, 2012, 2013, 2014, 2015, 2008, 2013, 2014, 2015, 2016],
                       'total_sales_sum': [60.53, 51.01, 56.32, 43.07, 27.21, 0.18, 18.96, 54.07, 59.92, 25.82]})
    df = df.set_index(['platform', 'year_of_release'])
    
    
    fig, ax = plt.subplots()
    
    # Each group is the sub-frame for one platform (level 0 of the index)
    for platform, new_df in df.groupby(level = 0):
        ax.plot(new_df.index.get_level_values('year_of_release').values,
                new_df['total_sales_sum'],
                label = platform,
                marker = 'o',
                linestyle = '-')
    
    ax.legend(frameon = True)
    
    plt.show()
    

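
The code above draws every platform on shared axes. Since the question asks for a plot per platform, the same groupby loop can also fill one subplot per group. This is only a sketch under assumptions: the helper name `plot_each_platform` is hypothetical, the sample rows are made up, and the figure size simply reuses the `(13, 4)` from the question for each panel.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up sample rows with the same index layout as in the question.
df = pd.DataFrame(
    {"platform": ["3DS", "3DS", "XOne", "XOne"],
     "year_of_release": [2011, 2012, 2013, 2014],
     "total_sales_sum": [60.53, 51.01, 18.96, 54.07]}
).set_index(["platform", "year_of_release"])


def plot_each_platform(frame):
    """Draw one subplot per platform; return the figure and a {platform: axes} dict."""
    groups = list(frame.groupby(level="platform"))
    fig, axes = plt.subplots(nrows=len(groups),
                             figsize=(13, 4 * len(groups)),
                             squeeze=False)  # always a 2-D array of axes
    axes_by_platform = {}
    for ax, (platform, sub_df) in zip(axes[:, 0], groups):
        years = sub_df.index.get_level_values("year_of_release")
        ax.plot(years, sub_df["total_sales_sum"], marker="o")
        ax.set_title(platform)
        ax.set_xlabel("year_of_release")
        ax.set_ylabel("total_sales_sum")
        ax.grid(True)
        axes_by_platform[platform] = ax
    fig.tight_layout()
    return fig, axes_by_platform


fig, axes_by_platform = plot_each_platform(df)
```

Call `plt.show()` afterwards to display the figure. If separate windows are preferred over panels, creating a new figure with `plt.subplots()` inside the loop should work the same way.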


    As an alternative, you can do it without any loop using seaborn.lineplot:

    import seaborn as sns
    
    fig, ax = plt.subplots()
    
    sns.lineplot(ax = ax,
                 data = df,
                 x = df.index.get_level_values('year_of_release'),
                 y = df['total_sales_sum'],
                 hue = df.index.get_level_values('platform'),
                 marker = 'o')
    
    plt.show()
    

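
Another loop-free option, not part of the answer above, is to stay within pandas: moving the platform level into columns with `unstack` gives a wide frame that `DataFrame.plot` draws as one line per column. The sketch below uses made-up sample rows, and the name `wide` is introduced only for illustration.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up sample rows with the same index layout as in the question.
df = pd.DataFrame(
    {"platform": ["3DS", "3DS", "XOne", "XOne"],
     "year_of_release": [2011, 2012, 2013, 2014],
     "total_sales_sum": [60.53, 51.01, 18.96, 54.07]}
).set_index(["platform", "year_of_release"])

# Move the 'platform' index level into columns:
# rows are years, one column per platform.
wide = df["total_sales_sum"].unstack("platform")

# One line per column on shared axes.
ax = wide.plot(marker="o", grid=True, figsize=(13, 4))
ax.set_ylabel("total_sales_sum")
```

Years in which a platform has no sales become NaN in `wide`, so that platform's line simply has no point there. Passing `subplots=True` to `wide.plot` should give one panel per platform instead of shared axes.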