Search code examples
python pandas matplotlib boxplot custom-function

Write a function in Python to group by and generate a boxplot in python


Python novice here. In the dataframe below, I need assistance writing a function that does the following:

I: select columns year,state,dept, revenue(mil)

II: boxplot of revenue(mil) ~ dept for each unique state within each unique year, something along the lines of groupby(['year','state'])

III: export the graphs as a PDF file with two figures per page

# the dataset
import pandas as pd
import numpy as np
import seaborn as sns

# Build the sample frame in one step: one list per column, 20 rows each.
df1 = pd.DataFrame(
    {
        "xcode": [
            5001, 5001, 5250, 5250, 5425, 5425, 5610, 5610, 5910, 5910,
            5010, 5010, 6110, 6110, 6135, 6135, 6220, 6220, 6550, 6550,
        ],
        "town": [
            "A01", "A01", "A01", "A02", "A01", "A02", "A03", "A03", "A01", "A02",
            "A03", "A04", "A01", "A01", "A01", "A01", "A01", "A01", "A02", "A02",
        ],
        "state": [
            "PA", "PA", "NY", "NY", "DE", "DE", "PA", "PA", "NY", "NY",
            "PA", "PA", "NY", "NY", "DE", "DE", "PA", "PA", "NY", "NY",
        ],
        "dept": [
            "hlth", "edu", "edu", "hlth", "hlth", "edu", "hlth", "edu", "edu", "hlth",
            "edu", "hlth", "hlth", "edu", "hlth", "hlth", "edu", "edu", "hlth", "hlth",
        ],
        "year": [
            2001, 2001, 2001, 2001, 2001, 2002, 2002, 2002, 2002, 2002,
            2003, 2003, 2003, 2003, 2003, 2004, 2004, 2004, 2004, 2004,
        ],
        "revenue(mil)": [
            112.9, 123, 124, 523.5, 112, 334, 55, 449, 221.6, 332,
            235, 239, 235, 223, 235.6, 204, 315.5, 614, 512, 514.2,
        ],
        "tax": [
            112.0, 123, 124, 523, 112, 334.5, 55, 449, 221, 332,
            235.6, 239, 235, 223.7, 235, 204, 315, 614, 512, 514.6,
        ],
    }
)
# Bare expression: displays the frame when run in a notebook/REPL cell.
df1

My attempt:

import pandas as pd
import numpy as np
import matplotlib.backends.backend_pdf

def boxplot2pdf(df):
    """Draft: plot revenue(mil) by dept for each (year, state) and save to a PDF.

    NOTE(review): this is an unfinished attempt and does not run as written;
    the inline review notes below mark the lines that fail.
    """
    # select the columns
    df = df[['year','state','dept', 'revenue(mil)']]
    # set pdf page
    pdf = matplotlib.backends.backend_pdf.PdfPages("boxplot2pdf.pdf")
    # set number of graphs per page
    N_plots_per_page = 2  # NOTE(review): assigned but never used below
########## Here is where I need help- grouping by year, grouping each year by state and plotting revenue(mil) by dept ###################################################
    #  for each unique year and unique state,boxplot revenue per dept
    # NOTE(review): `groupby` is used as a bare name that nothing in this
    # snippet defines; it is not called on `df`.
    for group in groupby(["year","state"]):
        
        # NOTE(review): `dept` and `revenue(mil)` are bare names rather than
        # column-name strings, so they are undefined here.
        g = sns.boxplot(x = dept,y = revenue(mil),data = df[group])
        # the title showing specific year and state
        # NOTE(review): the title is a fixed string with no year/state filled
        # in, and `plt` is not among the imports shown with this attempt.
        plt.title("Year:  State: ")
        # NOTE(review): `fig` is never assigned anywhere in this function.
        pdf.savefig(g,fig)
    pdf.close()

#driver code
boxplot2pdf(df1)

Kindly share your full code with comments explaining your approach :)

@Abuzar, your output looks like this

enter image description here


Solution

    import pandas as pd
    import seaborn as sns
    import matplotlib.backends.backend_pdf
    import matplotlib.pyplot as plt
    
    def boxplot2pdf(df, nFigPerPage, filename="boxplot2pdf.pdf"):
        """Write one revenue-by-department boxplot per (year, state) to a PDF.

        Each (year, state) pair present in ``df`` gets its own axes showing a
        boxplot of ``revenue(mil)`` by ``dept`` with a swarmplot of the raw
        points on top. Axes are laid out side by side, ``nFigPerPage`` per
        PDF page, ordered by year and then by state.

        Args:
            df: DataFrame with 'year', 'state', 'dept' and 'revenue(mil)'
                columns.
            nFigPerPage: Number of plots placed on each page (at least 1).
            filename: Path of the PDF file to write.

        Raises:
            ValueError: If ``nFigPerPage`` is smaller than 1.
        """
        if nFigPerPage < 1:
            raise ValueError("nFigPerPage must be at least 1")

        # groupby yields only the (year, state) pairs that actually occur, in
        # sorted key order, so no axes are wasted on combinations with no rows.
        groups = list(df.groupby(['year', 'state']))

        # The context manager closes the PDF even if plotting raises.
        with matplotlib.backends.backend_pdf.PdfPages(filename) as pdf:
            for start in range(0, len(groups), nFigPerPage):
                page = groups[start:start + nFigPerPage]
                # squeeze=False keeps axs two-dimensional, so indexing also
                # works when nFigPerPage == 1.
                fig, axs = plt.subplots(nrows=1, ncols=nFigPerPage, squeeze=False)
                try:
                    for ax, ((year, state), df_year_state) in zip(axs[0], page):
                        title = "Year: {}, State: {}".format(year, state)
                        sns.boxplot(ax=ax, x="dept", y="revenue(mil)", hue='dept', data=df_year_state).set_title(title)
                        sns.swarmplot(ax=ax, x="dept", y="revenue(mil)", hue='dept', data=df_year_state)

                    # A final, partially filled page is still saved; hide the
                    # axes that received no plot.
                    for ax in axs[0][len(page):]:
                        ax.set_visible(False)

                    fig.tight_layout()
                    pdf.savefig(fig)
                finally:
                    # Release the figure whether or not the page was written.
                    plt.close(fig)
        
    # Sample data as a plain mapping: one list per column, 20 rows each.
    df = {
        "xcode": [
            5001, 5001, 5250, 5250, 5425, 5425, 5610, 5610, 5910, 5910,
            5010, 5010, 6110, 6110, 6135, 6135, 6220, 6220, 6550, 6550,
        ],
        "town": [
            "A01", "A01", "A01", "A02", "A01", "A02", "A03", "A03", "A01", "A02",
            "A03", "A04", "A01", "A01", "A01", "A01", "A01", "A01", "A02", "A02",
        ],
        "state": [
            "PA", "PA", "NY", "NY", "DE", "DE", "PA", "PA", "NY", "NY",
            "PA", "PA", "NY", "NY", "DE", "DE", "PA", "PA", "NY", "NY",
        ],
        "dept": [
            "hlth", "edu", "edu", "hlth", "hlth", "edu", "hlth", "edu", "edu", "hlth",
            "edu", "hlth", "hlth", "edu", "hlth", "hlth", "edu", "edu", "hlth", "hlth",
        ],
        "year": [
            2001, 2001, 2001, 2001, 2001, 2002, 2002, 2002, 2002, 2002,
            2003, 2003, 2003, 2003, 2003, 2004, 2004, 2004, 2004, 2004,
        ],
        "revenue(mil)": [
            112.9, 123, 124, 523.5, 112, 334, 55, 449, 221.6, 332,
            235, 239, 235, 223, 235.6, 204, 315.5, 614, 512, 514.2,
        ],
        "tax": [
            112.0, 123, 124, 523, 112, 334.5, 55, 449, 221, 332,
            235.6, 239, 235, 223.7, 235, 204, 315, 614, 512, 514.6,
        ],
    }
    # Wrap the mapping in a DataFrame and write the two-per-page PDF.
    df1 = pd.DataFrame(data=df)
    boxplot2pdf(df1, nFigPerPage=2)