Search code examples

Create two side-by-side subplots, each with two y axes

This matplotlib tutorial shows how to create a plot with two y axes (two different scales):

import numpy as np
import matplotlib.pyplot as plt

def two_scales(ax1, time, data1, data2, c1, c2):
    """Plot two series over a shared x axis with independent y scales.

    ``data1`` is drawn on ``ax1`` in colour ``c1``.  A twin axes is created
    with ``ax1.twinx()`` (it shares the x axis and has its own y axis on the
    right), and ``data2`` is drawn on it in colour ``c2``.

    Returns
    -------
    tuple
        ``(ax1, twin)``: the axes passed in and the newly created twin.
    """
    twin = ax1.twinx()
    for axes, values, colour in ((ax1, data1, c1), (twin, data2, c2)):
        axes.plot(time, values, color=colour)
    ax1.set_xlabel('time (s)')
    return ax1, twin

# Mock data: an exponential (drawn in red on the left y axis) and a
# sine wave (drawn in blue on the twin, right-hand y axis)
t = np.arange(0.01, 10.0, 0.01)
s1, s2 = np.exp(t), np.sin(2 * np.pi * t)

# One figure with a single base axes; two_scales adds the twin itself
fig, ax = plt.subplots()
ax1, ax2 = two_scales(ax, t, s1, s2, 'r', 'b')

# Change color of each axis
def color_y_axis(ax, color):
    """Color the y-axis tick labels of *ax*.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes (anything providing ``get_yticklabels()``) whose y tick labels
        are recolored.
    color : str
        A matplotlib color spec, e.g. ``'r'`` or ``'limegreen'``.

    Returns
    -------
    None
    """
    # The loop body was missing (an IndentationError); each label is a Text
    # object and is recolored in place.
    for label in ax.get_yticklabels():
        label.set_color(color)

color_y_axis(ax=ax1, color='r')  # left axis: same colour as the exponential curve
color_y_axis(ax=ax2, color='b')  # right (twin) axis: same colour as the sine wave

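The result is this: [image not available: a single plot with the red exponential curve on the left y axis and the blue sine wave on the right y axis]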

My question: how would you modify the code to create two subplots like this one, placed side by side (horizontally aligned)? I would do something like:

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 8))
# plot something on ax[0] here
# plot something on ax[1] here

But then how do you make sure that the `fig, ax = plt.subplots()` call used to create the axes does not clash with the `fig, ax = plt.subplots(1, 2, figsize=(15, 8))` call that creates the side-by-side subplots?


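  • You don't need the single-axes `plt.subplots()` call at all: create both subplots at once with `fig, (ax1, ax2) = plt.subplots(1, 2)` and pass each of them to `two_scales`, which creates the twin axis for whichever axes it is given.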

    import numpy as np
    import matplotlib.pyplot as plt
    def two_scales(ax1, time, data1, data2, c1, c2):
        """Plot two series over a shared x axis with independent y scales.

        ``data1`` is drawn on ``ax1`` in colour ``c1``.  A twin axes is
        created with ``ax1.twinx()`` (it shares the x axis and has its own y
        axis on the right), and ``data2`` is drawn on it in colour ``c2``.

        Returns
        -------
        tuple
            ``(ax1, twin)``: the axes passed in and the newly created twin.
        """
        twin = ax1.twinx()
        for axes, values, colour in ((ax1, data1, c1), (twin, data2, c2)):
            axes.plot(time, values, color=colour)
        ax1.set_xlabel('time (s)')
        return ax1, twin
    # Mock data shared by both panels
    t = np.arange(0.01, 10.0, 0.01)
    s1, s2 = np.exp(t), np.sin(2 * np.pi * t)
    # Two side-by-side base axes from ONE subplots call; two_scales then
    # attaches a twin y axis to each, so no second figure is needed
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    ax1, ax1a = two_scales(ax1, t, s1, s2, 'r', 'b')
    ax2, ax2a = two_scales(ax2, t, s1, s2, 'gold', 'limegreen')
    # Change color of each axis
    def color_y_axis(ax, color):
        """Color the y-axis tick labels of *ax*.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes (anything providing ``get_yticklabels()``) whose y tick
            labels are recolored.
        color : str
            A matplotlib color spec, e.g. ``'r'`` or ``'limegreen'``.

        Returns
        -------
        None
        """
        # The loop body was missing (an IndentationError); each label is a
        # Text object and is recolored in place.
        for label in ax.get_yticklabels():
            label.set_color(color)
    # Match each y axis's tick-label colour to the line drawn on it
    color_y_axis(ax=ax1, color='r')
    color_y_axis(ax=ax1a, color='b')
    color_y_axis(ax=ax2, color='gold')
    color_y_axis(ax=ax2a, color='limegreen')

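    [image not available: two side-by-side subplots, each with its own left and right y axes, with red/blue curves in the left panel and gold/limegreen curves in the right panel]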