I am trying to add a vertical line (`plt.axvline`) to a scatter plot, based on a column of the dataframe, but I get the following error:

ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
x_line = datLong.groupby('ctr1').agg({'maxx': ['mean']})
for country in datLong.ctr1.unique():
    temp_df = plt.figure(country)
    temp_df = datLong[datLong.ctr1 == country]
    ax1 = temp_df.plot(kind='scatter', x='x', y='Price', color='#d95f0e', label = 'xx', linewidth =3, alpha = 0.7, figsize=(7,4))
    plt.title(country)
    plt.axvline(x=x_line) ### this is the line that is causing this error
    plt.show()
    print(ax1)
The problem seems to be related to the dataframe, but I can't figure out what it is. Can anybody help me?
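`x_line` contains the values for all the countries. With `x_line.loc[country]` you get the row for that country. Because that row is a Series (of just one element), and `axvline` only accepts a single number, you need to select its first element: `x_line.loc[country].iloc[0]` (use `.iloc[0]` rather than `[0]`, since positional access through plain `[]` is deprecated in recent pandas).

Passing the whole `x_line` DataFrame is what triggers the error: internally, `axvline` compares `x` with the current x-limits and combines the results with `or`, which needs a single `True`/`False`, and pandas refuses to reduce a whole DataFrame to one boolean.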
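To see why a single element has to be picked out, here is a small sketch with made-up numbers (the two-country `datLong` below is hypothetical, not the asker's data). It shows the `('maxx', 'mean')` column that the list form of `agg` creates, the `ValueError` you get when asking a DataFrame for one truth value, and two ways to get the plain number.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for datLong: two countries, two rows each
datLong = pd.DataFrame({
    "ctr1": ["country 1", "country 1", "country 2", "country 2"],
    "maxx": [9.0, 11.0, 4.0, 6.0],
})

# Passing a list of functions makes the columns a MultiIndex: ('maxx', 'mean')
x_line = datLong.groupby("ctr1").agg({"maxx": ["mean"]})
print(x_line.columns.tolist())  # [('maxx', 'mean')]

# A whole DataFrame has no single truth value: same error as in the question
try:
    bool(x_line)
except ValueError as err:
    print("ValueError:", err)

row = x_line.loc["country 1"]                     # a one-element Series, not a number
value = row.iloc[0]                               # positional access to its only element
same = x_line.loc["country 1", ("maxx", "mean")]  # or index row and column in one go
print(type(row).__name__, value, same)            # Series 10.0 10.0
```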
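Note that `plt.figure` creates a figure, and pandas `plot` without the `ax=` parameter also creates a new figure (in the question's code the result of `plt.figure(country)` is overwritten right away, so an empty figure is left behind for every country). So either leave out `plt.figure()`, or explicitly create an `ax` to draw on.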
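A minimal sketch of the second option: create the axes with `plt.subplots` and hand them to pandas via `ax=`, so each country produces exactly one figure. It assumes the non-interactive Agg backend so it runs without a display (remove the `matplotlib.use` line to get windows); the random data only mirrors the example frame used in this answer, and the seed is arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
datLong = pd.DataFrame({
    "ctr1": np.repeat(["country 1", "country 2"], 20),
    "x": np.tile(np.arange(20), 2),
    "maxx": rng.normal(10, 1, 40),
    "Price": rng.normal(200, 10, 40),
})

for country in datLong["ctr1"].unique():
    temp_df = datLong[datLong["ctr1"] == country]
    # Create the figure and axes ourselves...
    fig, ax = plt.subplots(figsize=(7, 4))
    # ...and tell pandas to draw on them instead of opening another figure
    temp_df.plot(kind="scatter", x="x", y="Price", color="#d95f0e", ax=ax)
    ax.set_title(country)
    ax.axvline(x=temp_df["maxx"].mean())  # a plain float, so axvline is happy

print(len(plt.get_fignums()))  # 2: one figure per country, no empty extras
```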
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

datLong = pd.DataFrame({'ctr1': np.repeat(['country 1', 'country 2'], 20),
                        'x': np.tile(np.arange(20), 2),
                        'maxx': np.random.randn(40) + 10,
                        'Price': np.random.randn(40) * 10 + 200})
x_line = datLong.groupby('ctr1').agg({'maxx': ['mean']})
for country in datLong.ctr1.unique():
    temp_df = datLong[datLong.ctr1 == country]
    # without ax=, pandas creates a new figure for each country
    ax1 = temp_df.plot(kind='scatter', x='x', y='Price', color='#d95f0e', label='xx', linewidth=3, alpha=0.7,
                       figsize=(7, 4))
    # canvas.set_window_title was removed in Matplotlib 3.6; the manager has the method now
    ax1.figure.canvas.manager.set_window_title(country)
    ax1.set_title(country)
    # x_line.loc[country] is a one-element Series; .iloc[0] extracts the number
    ax1.axvline(x=x_line.loc[country].iloc[0])
plt.show()
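The one-element Series only appears because `agg` was given a list (`['mean']`), which produces a `('maxx', 'mean')` column. A sketch of an alternative, assuming only the mean is needed: select the column before aggregating, so `x_line` becomes a plain Series indexed by country and `x_line[country]` is already a single number that could go straight into `axvline(x=...)`. The small frame below is made up for illustration.

```python
import pandas as pd

# Hypothetical stand-in for datLong
datLong = pd.DataFrame({
    "ctr1": ["country 1", "country 1", "country 2", "country 2"],
    "maxx": [9.0, 11.0, 4.0, 6.0],
})

# Selecting the column before aggregating gives a 1-D Series, not a DataFrame
x_line = datLong.groupby("ctr1")["maxx"].mean()
print(x_line)

# Each lookup is already a scalar, ready for ax.axvline(x=...)
for country in datLong["ctr1"].unique():
    print(country, x_line[country])
```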
As `groupby` already splits the dataframe per country, you could rewrite the code to loop over the groups directly (without needing `x_line`):
for country, country_df in datLong.groupby('ctr1'):
    ax1 = country_df.plot(kind='scatter', x='x', y='Price', color='#d95f0e', label='xx', linewidth=3, alpha=0.7,
                          figsize=(7, 4))
    ax1.figure.canvas.manager.set_window_title(country)
    ax1.set_title(country)
    # country_df holds only this country's rows, so its mean is a single number
    ax1.axvline(x=country_df['maxx'].mean())
plt.show()
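If you would rather see all countries side by side in one window, the `groupby` loop combines naturally with explicitly created axes. This is a sketch, not part of the answer above: `plt.subplots(..., squeeze=False)` makes one Axes per group, and each group is drawn onto its own Axes via `ax=`. The Agg backend line and the random data are assumptions so it can run headless.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop this line for a window
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
datLong = pd.DataFrame({
    "ctr1": np.repeat(["country 1", "country 2"], 20),
    "x": np.tile(np.arange(20), 2),
    "maxx": rng.normal(10, 1, 40),
    "Price": rng.normal(200, 10, 40),
})

groups = datLong.groupby("ctr1")
# One row of subplots, one Axes per country; squeeze=False always gives a 2-D array
fig, axes = plt.subplots(1, groups.ngroups, figsize=(10, 4), sharey=True, squeeze=False)

for ax, (country, country_df) in zip(axes[0], groups):
    country_df.plot(kind="scatter", x="x", y="Price", color="#d95f0e", ax=ax)
    ax.set_title(country)
    ax.axvline(x=country_df["maxx"].mean(), color="gray", linestyle="--")

fig.tight_layout()
# With an interactive backend, call plt.show() here (or fig.savefig(...) to save)
```

`sharey=True` keeps the Price scales comparable between countries; drop it if their ranges differ a lot.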