I am trying to color a scatterplot by a categorical column. Here is some sample data; the column I want to color the scatterplot by is `'cat'`.
```python
import pandas as pd

data = {
    'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'y': [2, 3, 5, 7, 11, 13, 17, 19, 23, 29],
    'z': [1, 2, 2, 3, 3, 4, 4, 5, 6, 6],
    'cat': ['A', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A']
}
pandas_df = pd.DataFrame(data)
pyspark_df = spark.createDataFrame(pandas_df)  # `spark` is an existing SparkSession
```
I created the following function to test the output. If I remove `hue` from the parameters, everything works fine, but I cannot get it working correctly with `hue`.
```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

def facet_plot(df, x, y, color, facet_col, bins=None):
    pd_df = df.toPandas()
    if bins is not None:
        # check col type
        if pd_df[facet_col].dtype.name in ['float64', 'int64']:
            # bin the facet column
            pd_df['facet_col_binned'] = pd.cut(pd_df[facet_col], bins=bins)
            # convert intervals to midpoints
            pd_df['facet_col_binned'] = pd_df['facet_col_binned'].apply(lambda interval: round(interval.mid, 1) if pd.notna(interval) else None)
            pd_df['facet_col_binned'] = pd.Categorical(pd_df['facet_col_binned'])
            # facet on the binned column for the remaining code
            facet_col = 'facet_col_binned'
    pd_df[color] = pd_df[color].astype(str)
    g = sns.FacetGrid(pd_df, col=facet_col, col_wrap=4, height=5, aspect=2)
    g.map(sns.scatterplot, x, y, hue=color)
    # if row => then change to row_template = '{row_name}'
    g.set_titles(col_template='{col_name}')
    g.set_axis_labels(x, y)
    plt.show()

facet_plot(pyspark_df, 'x', 'y', color='cat', facet_col='cat', bins=2)
```
First off, creating a more minimal example helps to pinpoint the problem:
```python
import seaborn as sns
import pandas as pd

data = {
    'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'y': [2, 3, 5, 7, 11, 13, 17, 19, 23, 29],
    'z': [1, 2, 2, 3, 3, 1, 1, 2, 3, 3],
    'cat': ['A', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A']
}
pd_df = pd.DataFrame(data)
g = sns.FacetGrid(pd_df, col='z', col_wrap=3, height=3, aspect=2)
g.map(sns.scatterplot, 'x', 'y', hue='cat')
```
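The main problem is that `g.map` doesn't pass the full dataframe to `sns.scatterplot`. For each facet it only replaces `'x'` and `'y'` with the corresponding columns of the dataframe; the `hue='cat'` keyword is passed through unchanged as a plain string. As there is no dataframe to look it up in, `sns.scatterplot` can't resolve ("interpret") the `'cat'` column.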
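A minimal sketch of what each facet effectively receives: the first call below mimics `g.map` by passing `x` and `y` as columns and `hue` as the bare string `'cat'`, without `data=`. Seaborn rejects this with a `ValueError` (the message wording differs between seaborn versions, so it is only printed); the same call with `data=pd_df` resolves `'cat'` as a column name.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

pd_df = pd.DataFrame({
    'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'y': [2, 3, 5, 7, 11, 13, 17, 19, 23, 29],
    'z': [1, 2, 2, 3, 3, 1, 1, 2, 3, 3],
    'cat': ['A', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A'],
})

fig, ax = plt.subplots()

# Roughly what g.map hands to sns.scatterplot: real columns for x and y,
# but only the string 'cat' for hue, with no dataframe to look it up in
try:
    sns.scatterplot(x=pd_df['x'], y=pd_df['y'], hue='cat', ax=ax)
    hue_resolved = True
except ValueError as err:
    print(err)
    hue_resolved = False

# With data= passed, the string 'cat' is resolved as a column name
sns.scatterplot(data=pd_df, x='x', y='y', hue='cat', ax=ax)
legend_labels = [t.get_text() for t in ax.get_legend().get_texts()]
print(legend_labels)
```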
One possibility is to use `g.map_dataframe` instead, which passes each facet's subset of the dataframe as `data=`, so column names can be resolved. As no figure legend is created automatically, you'll also need to call `g.add_legend()`.
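A minimal sketch of the `map_dataframe` route, using the minimal example's data. One thing to watch for with this approach: each facet works out its own hue order from its own subset, so in this data the `z=3` facet (where `'B'` appears first) would get swapped colors. Passing an explicit `hue_order` keeps the colors consistent across facets.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

pd_df = pd.DataFrame({
    'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'y': [2, 3, 5, 7, 11, 13, 17, 19, 23, 29],
    'z': [1, 2, 2, 3, 3, 1, 1, 2, 3, 3],
    'cat': ['A', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A'],
})

g = sns.FacetGrid(pd_df, col='z', col_wrap=3, height=3, aspect=2)
# map_dataframe passes each facet's subset as data=, so 'x', 'y' and 'cat'
# are looked up as column names; a fixed hue_order keeps colors identical per facet
g.map_dataframe(sns.scatterplot, x='x', y='y', hue='cat',
                hue_order=sorted(pd_df['cat'].unique()))
# Collect the per-facet legend entries into a single figure legend
g.add_legend()
plt.show()
```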
A better solution is to add `hue=` to the grid itself, `sns.FacetGrid(..., hue='cat')`, and leave it out of `g.map(sns.scatterplot, 'x', 'y')`.
The recommended solution is to use the "figure-level" version of your function. For `sns.scatterplot` this is `sns.relplot`. It also creates a `FacetGrid`, but takes care of passing the data, mapping `hue` consistently over all facets, and adding the legend.
```python
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

data = {
    'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'y': [2, 3, 5, 7, 11, 13, 17, 19, 23, 29],
    'z': [1, 2, 2, 3, 3, 1, 1, 2, 3, 3],
    'cat': ['A', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A']
}
pd_df = pd.DataFrame(data)
g = sns.relplot(data=pd_df, x='x', y='y', hue='cat', col='z', col_wrap=3, height=3, aspect=2)
plt.show()
```
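For completeness, a sketch of the question's `facet_plot` rewritten around `sns.relplot`. Assumptions (not from the original function): it takes a pandas DataFrame (with Spark, pass `pyspark_df.toPandas()`), it returns the grid instead of calling `plt.show()`, the numeric check uses `pd.api.types.is_numeric_dtype` instead of the dtype-name list, and the example call facets on `'z'`, since binning only applies to numeric columns (with `facet_col='cat'`, as in the question, the binning is skipped).

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def facet_plot(pd_df, x, y, color, facet_col, bins=None):
    """Scatter x vs y, colored by `color`, one facet per value (or bin) of `facet_col`.

    Sketch: expects a pandas DataFrame rather than a Spark one.
    """
    pd_df = pd_df.copy()  # don't modify the caller's dataframe
    if bins is not None and pd.api.types.is_numeric_dtype(pd_df[facet_col]):
        # bin the facet column and label each bin by its rounded midpoint
        binned = pd.cut(pd_df[facet_col], bins=bins)
        pd_df['facet_col_binned'] = binned.apply(
            lambda interval: round(interval.mid, 1) if pd.notna(interval) else None
        )
        facet_col = 'facet_col_binned'
    # string values are always treated as categories for hue
    pd_df[color] = pd_df[color].astype(str)
    # relplot builds the FacetGrid, maps hue over the full data and adds the legend
    g = sns.relplot(data=pd_df, x=x, y=y, hue=color, col=facet_col,
                    col_wrap=4, height=5, aspect=2)
    g.set_titles(col_template='{col_name}')
    g.set_axis_labels(x, y)
    return g


data = {
    'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'y': [2, 3, 5, 7, 11, 13, 17, 19, 23, 29],
    'z': [1, 2, 2, 3, 3, 4, 4, 5, 6, 6],
    'cat': ['A', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A'],
}
g = facet_plot(pd.DataFrame(data), 'x', 'y', color='cat', facet_col='z', bins=2)
plt.show()
```

Returning the grid leaves room for further tweaks (titles, limits) before showing or saving the figure.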