
matplotlib: same legend for two data sets


I am plotting two pandas DataFrames on the same axes with matplotlib, drawing each dataset with a different line style. Here is the code:

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

df1 = pd.DataFrame(np.random.randn(10, 16))
df2 = pd.DataFrame(np.random.randn(10, 16))

# plt.subplots creates its own figure, so a separate plt.figure() call
# would only leave an extra empty figure behind
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

# Both datasets go on the same Axes: solid lines for df1, dashed for df2
df1.plot(ax=axes[0], style='-', legend=True)
df2.plot(ax=axes[0], style='--', legend=True)

axes[0].set_xlabel('x')
axes[0].set_ylabel('y')
axes[0].set_title('ttl')

plt.show()


However, the color sequence differs between the two line styles: for instance, column 0 drawn as a solid line and column 0 drawn as a dashed line get different colors. How can I get the same color sequence for both line styles?
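A likely explanation, assuming pandas leaves color choice to Matplotlib when the `style` string contains no color and no `color` or `colormap` argument is passed: every line drawn on the Axes takes the next entry of that Axes' color cycle. The 16 solid lines use positions 0 to 15, so the first dashed line continues at position 16, which wraps to position 6 of the default 10-color cycle. The sketch below checks this on a single Axes; the off-screen `Agg` backend is only there so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the check runs without a display
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

df1 = pd.DataFrame(np.random.randn(10, 16))
df2 = pd.DataFrame(np.random.randn(10, 16))

fig, ax = plt.subplots()
df1.plot(ax=ax, style='-', legend=False)   # 16 solid lines
df2.plot(ax=ax, style='--', legend=False)  # 16 dashed lines on the same Axes

solid = [line.get_color() for line in ax.lines[:16]]
dashed = [line.get_color() for line in ax.lines[16:]]
# Matplotlib's default color cycle (10 colors unless rcParams were changed)
cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

# The first dashed line is the 17th line drawn, i.e. cycle position 16 % len(cycle)
print(dashed[0] == cycle[16 % len(cycle)])
print(solid[0] == dashed[0])
```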

EDIT: Changing the input to

df1 = pd.DataFrame(np.random.randn(501, 16))
df2 = pd.DataFrame(np.random.randn(5001, 16))

changes the legend so that all entries are blue.
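The all-blue legend is consistent with the colormap approach in the answer below when the color list is sized with `len(df)`, which counts rows rather than columns. Assuming pandas gives the i-th entry of a `color` list to the i-th column, a 16-column frame with 501 rows only uses the first 16 of 501 `jet` samples, all taken from the very start of the colormap. The sketch compares that with a list sized by column count; `n_rows` and `n_cols` are just illustrative names for the EDIT's shape.

```python
from matplotlib import pyplot as plt
import numpy as np

n_rows, n_cols = 501, 16  # shape of df1 in the EDIT

# Colors sized by row count: the 16 columns only reach the first 16 entries
by_rows = plt.cm.jet(np.linspace(0, 1, n_rows))[:n_cols]
# Colors sized by column count: 16 samples spread over the whole colormap
by_cols = plt.cm.jet(np.linspace(0, 1, n_cols))

# Each row is RGBA; channels 0 and 1 are red and green
print(by_rows[:, :2].max())  # no red or green at all: shades of dark blue
print(by_cols[-1])           # jet's last color, a dark red
```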


Solution

  • This is a somewhat hacky solution: create a list of colors with one entry per column, sized for whichever dataframe has more columns, then pass the same list to both plot calls so that column i gets the same color in each.

    from matplotlib import pyplot as plt
    import numpy as np
    import pandas as pd

    df1 = pd.DataFrame(np.random.randn(10, 6))
    df2 = pd.DataFrame(np.random.randn(10, 10))

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

    # One color per column, sized for the frame with more columns.
    # len(df) would count rows instead: with many rows, the columns only get the
    # first few entries of a long jet list, which are all nearly the same dark blue.
    n_colors = max(len(df1.columns), len(df2.columns))
    colors = plt.cm.jet(np.linspace(0, 1, n_colors))

    # Passing the same list to both calls gives column i the same color in each
    df1.plot(ax=axes[0], style='-', color=colors, legend=True)
    df2.plot(ax=axes[0], style='--', color=colors, legend=True)

    axes[0].set_xlabel('x')
    axes[0].set_ylabel('y')
    axes[0].set_title('ttl')

    plt.show()
    

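Another option, sketched here as an alternative rather than part of the answer above: assuming pandas leaves color choice to the Axes' color cycle when `style` has no color and no `color` is passed, reset that cycle between the two calls with Matplotlib's `Axes.set_prop_cycle(None)`, so the second dataframe starts again from the first color. This avoids building a color list, but the default cycle has only 10 colors, so with 16 columns colors repeat within each frame (column 10 matches column 0); the colormap list above avoids that. The `Agg` backend line is only there so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs headless; drop it to get a window
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

df1 = pd.DataFrame(np.random.randn(10, 16))
df2 = pd.DataFrame(np.random.randn(10, 16))

fig, ax = plt.subplots(figsize=(12, 8))

df1.plot(ax=ax, style='-', legend=True)
# Restart this Axes' color cycle at its first entry,
# so df2's column 0 gets the same color as df1's column 0
ax.set_prop_cycle(None)
df2.plot(ax=ax, style='--', legend=True)

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('ttl')
```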