python matplotlib visualization legend-properties

Optimising x-axis according to data range - plt.yticks between the horizontal bars


I am trying to generate a population pyramid plot using the code below. I reused some code I found. However, I don't know how to optimise the tick labels according to the data range I have. I would like a nice plot, but mine is squished because of the wrong x-axis limits.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

age = np.array(["0-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70-79", "80-89", '90-99',
                "100-109", "110-119", "120-129", "130-139", "140-150", ">150"])
m = np.array([811, 34598, 356160, 381160, 243330, 206113, 128549, 60722, 8757, 1029, 1033, 891, 1803, 62, 92, 764])
f = np.array(
    [612, 101187, 904717, 841066, 503661, 421678, 248888, 95928, 10289, 1444, 1360, 1377, 1699, 119, 173, 1655])
x = np.arange(age.size)
tick_lab = ['3M', '2M', '1M', '1M', '2M', '3M']
tick_val = [-3000000, -2000000, -1000000, 1000000, 2000000, 3000000]


plt.figure(figsize=(16, 8), dpi=80)
def plot_pyramid():
    plt.barh(x, -m, alpha=.75, height=.75, left=-shift, align='center' , color="deepskyblue")
    plt.barh(x, f, alpha=.75, height=.75, left = shift, align='center', color="pink")
    plt.yticks([])
    plt.xticks(tick_val, tick_lab)
    plt.grid(b=False)
    plt.title("Population Pyramid")
    for i, j in enumerate(age):
        if i == 0 or i==1:
            plt.text(-150000, x[i] - 0.2, j, fontsize=14)
        else:    
            plt.text(-230000, x[i] - 0.2, j, fontsize=14)


if __name__ == '__main__':
    plot_pyramid()

Any help will be appreciated

Thanks in advance


Solution

  • Here are a few ideas to tackle the issues mentioned:

    • Instead of putting the xticks at fixed positions, let matplotlib choose automatically where to put the ticks.
    • A custom tick formatter could display the numbers either with M or with K depending on their size.
    • The labels for the age ranges could be placed centered instead of left aligned.
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.ticker import FuncFormatter
    
    age = np.array(["0-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70-79", "80-89", '90-99',
                    "100-109", "110-119", "120-129", "130-139", "140-150", ">150"])
    m = np.array([811, 34598, 356160, 381160, 243330, 206113, 128549, 60722, 8757, 1029, 1033, 891, 1803, 62, 92, 764])
    f = np.array([612, 101187, 904717, 841066, 503661, 421678, 248888, 95928, 10289, 1444, 1360, 1377, 1699, 119, 173, 1655])
    x = np.arange(age.size)
    
    def k_and_m_formatter(x, pos):
        if x == 0:
            return ''
        x = abs(x)
        if x > 900000:
            return f'{x / 1000000: .0f} M'
        elif x > 9000:
            return f'{x / 1000: .0f} K'
        else:
            return f'{x : .0f}'
    
    def plot_pyramid():
        fig, ax = plt.subplots(figsize=(16, 8), dpi=80)
        shift = 0
        ax.barh(x, -m, alpha=.75, height=.75, left=-shift, align='center' , color="deepskyblue")
        ax.barh(x, f, alpha=.75, height=.75, left = shift, align='center', color="pink")
        ax.set_yticks([])
        ax.xaxis.set_major_formatter(FuncFormatter(k_and_m_formatter))
        ax.grid(False)
        ax.set_title("Population Pyramid")
        for i, age_span in enumerate(age):
            ax.text(0, x[i], age_span, fontsize=14, ha='center', va='center')
    
    plot_pyramid()
    

    resulting plot

    Optionally, the x-axis could be log scaled (ax.set_xscale('symlog')) to make the small values more visible.
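
    A minimal sketch of the symlog option, reusing the m and f arrays from the question. It relies on Matplotlib's default linear threshold around zero, which may need tuning for other data; the figure size is simply copied from the code above.

```python
import numpy as np
import matplotlib.pyplot as plt

m = np.array([811, 34598, 356160, 381160, 243330, 206113, 128549, 60722, 8757, 1029, 1033, 891, 1803, 62, 92, 764])
f = np.array([612, 101187, 904717, 841066, 503661, 421678, 248888, 95928, 10289, 1444, 1360, 1377, 1699, 119, 173, 1655])
x = np.arange(m.size)

fig, ax = plt.subplots(figsize=(16, 8), dpi=80)
ax.barh(x, -m, height=.75, color="deepskyblue")  # males extend to the left
ax.barh(x, f, height=.75, color="pink")          # females extend to the right
# 'symlog' is logarithmic on both sides of zero and linear in a small band
# around it, so negative and positive bar lengths are both supported.
ax.set_xscale('symlog')
ax.set_yticks([])
```

    With this scale the bars for the oldest age groups, which are only a few hundred people long, become visible next to the bars close to one million.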

    The code in the question uses a variable shift, but doesn't give it a value. Giving it a nonzero value opens a gap between the two sides (presumably to place the labels), but the ticks would then sit at the wrong positions, because every bar is offset by shift while the ticks are not.
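
    If a gap is wanted anyway, one possible workaround is to offset the tick positions by the same amount as the bars. The sketch below is only an illustration: the value of shift and the hand-picked counts in tick_counts are assumptions, not values from the question.

```python
import numpy as np
import matplotlib.pyplot as plt

m = np.array([811, 34598, 356160, 381160, 243330, 206113, 128549, 60722, 8757, 1029, 1033, 891, 1803, 62, 92, 764])
f = np.array([612, 101187, 904717, 841066, 503661, 421678, 248888, 95928, 10289, 1444, 1360, 1377, 1699, 119, 173, 1655])
x = np.arange(m.size)

shift = 100_000                                      # assumed half-width of the central gap
tick_counts = np.array([250_000, 500_000, 750_000])  # assumed counts to label on each side

# Offset each tick by the same shift as the bars, mirrored around zero,
# so that a label such as "500 K" lines up with a bar of that length.
tick_pos = np.concatenate([-(tick_counts[::-1] + shift), tick_counts + shift])
tick_lab = [f'{c // 1000} K' for c in np.concatenate([tick_counts[::-1], tick_counts])]

fig, ax = plt.subplots(figsize=(16, 8), dpi=80)
ax.barh(x, -m, left=-shift, height=.75, color="deepskyblue")
ax.barh(x, f, left=shift, height=.75, color="pink")
ax.set_xticks(tick_pos)
ax.set_xticklabels(tick_lab)
ax.set_yticks([])
```

    The drawback is that the tick values are fixed again, so they have to be chosen by hand to suit the data range.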

    For an alternative, see e.g. How to build a population pyramid? or Using Python libraries to plot two horizontal bar charts sharing same y axis for an example with two separate subplots.