Search code examples
python matplotlib plot

Create a bar chart in python grouping the x axis by 2 variables


I need to create a bar chart using Python where the X axis is grouped by two variables. The data I need to graph is the following:

Country City Number of Universities
Germany Berlin 30
Germany Munich 20
Germany Hamburg 10
France Paris 40
France Marseille 5
France Lyon 10
France Nice 5
Spain Madrid 25
Spain Barcelona 15
Spain Valencia 10
Spain Seville 7
Denmark Copenhagen 10
Denmark Aarhus 5
Italy Rome 20
Italy Milan 15
Italy Naples 8
Italy Florence 7
Austria Vienna 12
Austria Salzburg 4

I share the code to create it below:

import pandas as pd

# One record per city: (country, city, number of universities).
records = [
    ('Germany', 'Berlin', 30),
    ('Germany', 'Munich', 20),
    ('Germany', 'Hamburg', 10),
    ('France', 'Paris', 40),
    ('France', 'Marseille', 5),
    ('France', 'Lyon', 10),
    ('France', 'Nice', 5),
    ('Spain', 'Madrid', 25),
    ('Spain', 'Barcelona', 15),
    ('Spain', 'Valencia', 10),
    ('Spain', 'Seville', 7),
    ('Denmark', 'Copenhagen', 10),
    ('Denmark', 'Aarhus', 5),
    ('Italy', 'Rome', 20),
    ('Italy', 'Milan', 15),
    ('Italy', 'Naples', 8),
    ('Italy', 'Florence', 7),
    ('Austria', 'Vienna', 12),
    ('Austria', 'Salzburg', 4),
]

# Transpose the records into one plain list per column.
countries, cities, universities_count = map(list, zip(*records))

# The initial dictionary holds only the two label columns.
data = {'Country': countries, 'City': cities}

# Build the DataFrame from the label columns.
df = pd.DataFrame(data)

# Attach the hard-coded per-city university numbers as a third column.
df['Number of Universities'] = universities_count

# Display the resulting table.
print(df)

The objective is to generate the following graph: enter image description here

Thank you very much in advance for your help.


Solution

  • Depending on how closely you want to replicate this, you'll need an approach that can calculate the necessary height/placement of your group labels and tick lines.

    I wrote a helper function for this: it calculates the current height (in points) of the x-axis of a given Axes, including the x-axes of any child Axes attached to it.

    After that, we effectively need 3 sets of Axes.xaxis to work with:

    1. For the individual 'City' level ticks
    2. For the grouped 'Country' level ticks
    3. For the lines that span from the Axes down to the bottom of the 'Country' level labels.
    import pandas as pd
    import numpy as np
    
    # One record per city: (country, city, number of universities).
    records = [
        ('Germany', 'Berlin', 30),
        ('Germany', 'Munich', 20),
        ('Germany', 'Hamburg', 10),
        ('France', 'Paris', 40),
        ('France', 'Marseille', 5),
        ('France', 'Lyon', 10),
        ('France', 'Nice', 5),
        ('Spain', 'Madrid', 25),
        ('Spain', 'Barcelona', 15),
        ('Spain', 'Valencia', 10),
        ('Spain', 'Seville', 7),
        ('Denmark', 'Copenhagen', 10),
        ('Denmark', 'Aarhus', 5),
        ('Italy', 'Rome', 20),
        ('Italy', 'Milan', 15),
        ('Italy', 'Naples', 8),
        ('Italy', 'Florence', 7),
        ('Austria', 'Vienna', 12),
        ('Austria', 'Salzburg', 4),
    ]
    column_names = ['Country', 'City', 'Number of Universities']
    df = pd.DataFrame(records, columns=column_names)

    # Order rows by country, then city, so each country's bars are adjacent.
    plot_df = df.sort_values(by=['Country', 'City'])
    
    ###
    
    def get_xaxis_height(ax):
        """Return the combined x-axis height, in points, of ``ax`` and its child Axes.

        For ``ax`` itself and every entry of ``ax.child_axes`` (direct children
        only), this adds the height of the x-axis tight bounding box and the
        x-axis tick ``pad``.

        Parameters
        ----------
        ax : matplotlib Axes
            The parent Axes. Its ``figure.dpi`` is used for the pixel-to-point
            conversion, so no module-level ``fig`` is required.

        Returns
        -------
        float
            Total height in points (1 point = 1/72 inch).

        Raises
        ------
        KeyError
            If ``get_tick_params()`` of one of the x-axes has no ``'pad'`` entry.
            NOTE(review): this presumably happens when ``pad`` was never set
            explicitly through ``tick_params`` -- confirm against the
            matplotlib version in use.
        """
        dpi = ax.figure.dpi
        total_pt = 0
        # A separate loop name keeps the ``ax`` parameter from being rebound.
        for axes in [ax, *ax.child_axes]:
            bbox = axes.xaxis.get_tightbbox()
            # get_tightbbox() may return None (e.g. nothing to enclose);
            # such an axis contributes only its pad.
            if bbox is not None:
                # The bounding box is in display pixels: convert to points.
                total_pt += bbox.height * 72 / dpi
            # The tick pad is already expressed in points, so it must not be
            # rescaled by 72 / dpi.
            total_pt += axes.xaxis.get_tick_params()['pad']
        return total_pt
    
    from matplotlib import pyplot as plt
    from matplotlib.ticker import MultipleLocator
    plt.rc('font', size=12)

    ## Create base chart
    fig, ax = plt.subplots(figsize=(12, 8))
    # Hide every spine except the bottom one, which is restyled further down.
    ax.spines[['left', 'top', 'right']].set_visible(False)

    bc = ax.bar('City', 'Number of Universities', data=plot_df, width=.6)
    ax.bar_label(bc)
    ax.xaxis.set_tick_params(
        rotation=90,
        bottom=False,
        length=0,
        pad=1,   # adjust pad to move individual labels further/closer to bottom spine
    )
    ax.yaxis.set_tick_params(left=False)

    # Bars are drawn at integer positions 0..n-1 in plot_df row order; compute
    # those positions once and reuse them for both the labels and the lines.
    n_bars = len(plot_df)
    tick_df = plot_df.assign(tick_loc=np.arange(n_bars))

    ## Add group labels underneath existing rotated labels
    # Each country label is centred on the mean position of its cities' bars.
    label_locs = tick_df.groupby('Country')['tick_loc'].mean()
    # Measured before the secondary axis exists: height of the city labels only.
    city_axis_height = get_xaxis_height(ax)
    group_label_ax = ax.secondary_xaxis(location='bottom')
    group_label_ax.set_xticks(label_locs, labels=label_locs.index, ha='center')
    group_label_ax.tick_params(
        bottom=False,
        pad=10,  # adjust pad to move your group labels further/closer to the individual labels
        length=city_axis_height,  # invisible ticks push the labels below the city labels
    )

    ## add long tick lines where needed
    # A separator goes half a bar to the left of the first city of each country ...
    starts_new_country = tick_df['Country'] != tick_df['Country'].shift()
    line_locs = (tick_df.loc[starts_new_country, 'tick_loc'] - 0.5).tolist()
    # ... plus one closing separator to the right of the last bar. Use plot_df
    # (the frame actually plotted), not df, so the count always matches the bars.
    line_locs += [n_bars - .5]

    # Measured again now that group_label_ax is a child of ax: the value
    # includes the country labels, so the lines reach down past them.
    stacked_axis_height = get_xaxis_height(ax)
    tickline_ax = ax.secondary_xaxis(location='bottom')
    tickline_ax.set_xticks(line_locs)
    tickline_ax.tick_params(labelbottom=False, length=stacked_axis_height, pad=0)
    # Clip the view to the bars so the outer separators sit on the plot edges.
    ax.set_xlim(-.5, len(ax.containers[0]) - .5)

    ## adjust spine & tick colors
    ax.spines['bottom'].set_color('gainsboro')
    tickline_ax.xaxis.set_tick_params(
        color='gainsboro',
        labelcolor='black',
        width=ax.spines['bottom'].get_linewidth() * 2
    )

    ## adjust y-ticks to be multiples of 5
    ax.yaxis.set_major_locator(MultipleLocator(5))
    ax.margins(y=.2)

    fig.tight_layout()
    plt.show()
    

    enter image description here