python pandas matplotlib axvline

Plot axvline using a dataframe column error: ValueError: The truth value of a DataFrame is ambiguous.


I am trying to add a vertical line to a scatter plot based on a column of the dataframe, but I get the following error: ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().


x_line = datLong.groupby('ctr1').agg({'maxx': ['mean']})

for country in datLong.ctr1.unique():
    temp_df = plt.figure(country)
    temp_df = datLong[datLong.ctr1 == country]
    ax1 = temp_df.plot(kind='scatter', x='x', y='Price', color='#d95f0e', label = 'xx', linewidth =3, alpha = 0.7, figsize=(7,4))    
   
    plt.title(country)
    plt.axvline(x=x_line) ### this is the line that is causing this error
 
    plt.show()
print (ax1)

The problem seems to be related to the dataframe, but I can't figure out what it is. Can anybody help me?


Solution

  • x_line contains the values for all the countries. With x_line.loc[country] you'd get the row for that country. Because that row is a Series (of just one element), and axvline only accepts a single value, you can select its first element: x_line.loc[country][0], or x_line.loc[country].iloc[0] in recent pandas versions, where positional access with [0] on a label-indexed Series is deprecated.

    Note that plt.figure creates a figure, and pandas plot without the ax= parameter also creates a new figure. So, either you should leave out plt.figure(), or explicitly create an ax to draw on.

    from matplotlib import pyplot as plt
    import numpy as np
    import pandas as pd
    
    datLong = pd.DataFrame({'ctr1': np.repeat(['country 1', 'country 2'], 20),
                            'x': np.tile(np.arange(20), 2),
                            'maxx': np.random.randn(40) + 10,
                            'Price': np.random.randn(40) * 10 + 200})
    
    x_line = datLong.groupby('ctr1').agg({'maxx': ['mean']})
    
    for country in datLong.ctr1.unique():
        temp_df = datLong[datLong.ctr1 == country]
        ax1 = temp_df.plot(kind='scatter', x='x', y='Price', color='#d95f0e', label='xx', linewidth=3, alpha=0.7,
                           figsize=(7, 4))
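        # set_window_title lives on the canvas manager in current Matplotlib
        ax1.figure.canvas.manager.set_window_title(country)
        ax1.set_title(country)
        # .iloc[0] picks the single mean value for this country by position
        ax1.axvline(x=x_line.loc[country].iloc[0])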
        plt.show()
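
To see why the original call fails, it can help to inspect x_line directly. The sketch below is only illustrative: the random data is a stand-in for the question's datLong, and mean_maxx and x_value are names introduced here. It shows that the dict/list form of agg returns a DataFrame, that using a DataFrame where a single True/False is needed raises the error from the question (presumably what happens when axvline compares x against the axis limits), and that a plain Series of means gives one scalar per country.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the question's data
rng = np.random.default_rng(0)
datLong = pd.DataFrame({'ctr1': np.repeat(['country 1', 'country 2'], 20),
                        'maxx': rng.normal(10, 1, 40)})

# Same aggregation as in the question: the result is a DataFrame with one
# row per country and a single ('maxx', 'mean') column
x_line = datLong.groupby('ctr1').agg({'maxx': ['mean']})
print(type(x_line).__name__, x_line.shape)

# Using a DataFrame where a single True/False is needed raises the error
try:
    bool(x_line)
    raised = False
except ValueError as err:
    raised = True
    print('ValueError:', err)

# A plain Series of means is easier to index: one label gives one scalar
mean_maxx = datLong.groupby('ctr1')['maxx'].mean()
x_value = mean_maxx.loc['country 1']
print(type(x_value).__name__, x_value)
```

With the Series form, ax1.axvline(x=mean_maxx.loc[country]) would need no further indexing.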
    

    As groupby already creates the dataframes per country, you could rewrite the code making use of groupby (without needing x_line):

    for country, country_df in datLong.groupby('ctr1'):
        ax1 = country_df.plot(kind='scatter', x='x', y='Price', color='#d95f0e', label='xx', linewidth=3, alpha=0.7,
                           figsize=(7, 4))
        # set_window_title lives on the canvas manager in current Matplotlib
        ax1.figure.canvas.manager.set_window_title(country)
        ax1.set_title(country)
        ax1.axvline(x=country_df['maxx'].mean())
        plt.show()
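
The other option mentioned above, explicitly creating an ax to draw on, could look roughly like the following. This is a minimal sketch rather than the only way to do it: plot_country is a hypothetical helper, the random data again stands in for datLong, and plt.show() is left out so the snippet also runs without a display.

```python
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# Hypothetical stand-in for the question's data
rng = np.random.default_rng(0)
datLong = pd.DataFrame({'ctr1': np.repeat(['country 1', 'country 2'], 20),
                        'x': np.tile(np.arange(20), 2),
                        'maxx': rng.normal(10, 1, 40),
                        'Price': rng.normal(200, 10, 40)})


def plot_country(country, country_df):
    """Hypothetical helper: scatter plot plus mean line on one explicit Axes."""
    # Create exactly one figure and Axes per country
    fig, ax = plt.subplots(figsize=(7, 4))
    # Passing ax= stops pandas from opening a figure of its own
    country_df.plot(kind='scatter', x='x', y='Price', color='#d95f0e',
                    label='xx', alpha=0.7, ax=ax)
    ax.set_title(country)
    # A single scalar, so axvline accepts it
    line = ax.axvline(x=country_df['maxx'].mean())
    return fig, ax, line


results = {country: plot_country(country, country_df)
           for country, country_df in datLong.groupby('ctr1')}
```

Calling plt.show() afterwards displays one window per country, and because every drawing call goes through ax, nothing depends on which figure pyplot currently considers active.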