python, pandas, matplotlib, charts, subplot

Python matplotlib subplots: putting horizontal bar charts side by side and unwanted numerical labels on axes


Consider the following fake data:

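import matplotlib.pyplot as plt
import pandas as pd

fake_data = pd.DataFrame(
    {"index": ["A", "B", "C"], 0: [0.1, 0.2, 0.3], 1: [2, 4, 6], 2: [1, 3, 5]}
)
# set_index returns a new DataFrame, so assign the result back
fake_data = fake_data.set_index("index")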

I want to plot a sequence of horizontal bar charts, laid out as one row by three columns. The y axis of each bar chart should show the categories A, B and C. The x axis of each bar chart should be numerical.

I have tried the following:

fig, ax = plt.subplots(1, 3, sharex = True, sharey = True, figsize = (6, 8), frameon = False)
ax1 = fig.add_subplot(1,3,1)
fake_data[0].plot.barh()
ax2 = fig.add_subplot(1,3,2)
fake_data[1].plot.barh()
ax3 = fig.add_subplot(1,3,3)
fake_data[2].plot.barh()

The output is:

[Image: the output of the above code]

I've tried adjusting the axes labels and ticks by modifying the above code:

fig, ax = plt.subplots(1, 3, sharex = True, sharey = True, figsize = (6, 8), frameon = False)
ax1 = fig.add_subplot(1,3,1)
fake_data[0].plot.barh()
ax1.set_yticks([])
ax2 = fig.add_subplot(1,3,2)
fake_data[1].plot.barh()
ax2.set_xticks([])
ax3 = fig.add_subplot(1,3,3)
fake_data[2].plot.barh()

The output is:

[Image: the output of the preceding block of code]

As you can see, the categorical labels are now gone.

My questions are:

  1. How can I get rid of the numerical ticks / labels on the y axis on the leftmost horizontal bar chart, given that the y axis is supposed to have a categorical variable?

  2. How can I get rid of the labels and ticks and the y-axis name "index" on the second and third bar charts?

  3. Similarly to (1), how can I get rid of the numerical ticks / labels on the x axis that are overlapping with the correct ticks / labels?

  4. How can I put the x axes of all the horizontal bar charts on the same scale, so that the bars in the first (leftmost) chart appear roughly a tenth as long as those in the second and third charts?


Solution

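  • You've already created your axes using plt.subplots. Each fig.add_subplot call then creates a new Axes on top of an existing one, which is why a second set of numerical ticks shows through. Pass the existing axes to each pd.Series.plot call via the ax argument instead, and it will fix your problems: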

    fig, (ax1, ax2, ax3) = plt.subplots(
        1, 3, sharex=True, sharey=True, figsize=(6, 8), frameon=False
    )
    fake_data[0].plot.barh(ax=ax1)
    fake_data[1].plot.barh(ax=ax2)
    fake_data[2].plot.barh(ax=ax3)
    

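For completeness, here is a self-contained sketch of the same idea written as a loop over the axes returned by plt.subplots. The titles and the `ax.set_ylabel("")` call are my additions rather than part of the answer above: whether pandas labels the y axis with the index name ("index") seems to depend on the pandas version, so clearing it explicitly is a harmless precaution. The names `axes` and `column` are just illustrative.

```python
import matplotlib.pyplot as plt
import pandas as pd

fake_data = pd.DataFrame(
    {"index": ["A", "B", "C"], 0: [0.1, 0.2, 0.3], 1: [2, 4, 6], 2: [1, 3, 5]}
).set_index("index")

# plt.subplots already creates the three axes; do not call fig.add_subplot again.
fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(6, 8))

for ax, column in zip(axes, fake_data.columns):
    # Draw each column on its own pre-existing axes.
    fake_data[column].plot.barh(ax=ax)
    ax.set_title(f"column {column}")
    # Assumption: some pandas versions add the index name as a y label; clear it.
    ax.set_ylabel("")

fig.tight_layout()
# Call plt.show() or fig.savefig(...) here to view the result.
```

Because sharex=True is set, all three charts should end up with the same x limits, so the bars for column 0 come out much shorter than those for columns 1 and 2. With sharey=True, matplotlib hides the y tick labels on the second and third charts, so the categories appear only on the leftmost one.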