Tags: python, pandas, matplotlib, seaborn, scatter-plot

How to produce a scatter plot with markers and colors determined by categories in different columns


I want to plot a dataset with different clusters.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.cluster

rng = np.random.default_rng(seed=5)

df_1_3 = pd.DataFrame(rng.normal(loc=(1, 3), size=(30, 2), scale=0.50), columns=["x", "y"])
df_5_1 = pd.DataFrame(rng.normal(loc=(5, 1), size=(30, 2), scale=0.25), columns=["x", "y"])
df_5_5 = pd.DataFrame(rng.normal(loc=(5, 5), size=(30, 2), scale=0.25), columns=["x", "y"])

df = pd.concat([df_1_3, df_5_1, df_5_5], keys=["df_1_3", "df_5_1", "df_5_5"])
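The keys= argument of pd.concat is what keeps the original frames recoverable: it adds an outer index level that holds each row's source name, and that level is what groupby(level=0) and reset_index(level=0) operate on later. A small self-contained check, rebuilding the data above:

```python
import numpy as np
import pandas as pd

# Rebuild the question's data so this snippet runs on its own
rng = np.random.default_rng(seed=5)
df_1_3 = pd.DataFrame(rng.normal(loc=(1, 3), size=(30, 2), scale=0.50), columns=["x", "y"])
df_5_1 = pd.DataFrame(rng.normal(loc=(5, 1), size=(30, 2), scale=0.25), columns=["x", "y"])
df_5_5 = pd.DataFrame(rng.normal(loc=(5, 5), size=(30, 2), scale=0.25), columns=["x", "y"])
df = pd.concat([df_1_3, df_5_1, df_5_5], keys=["df_1_3", "df_5_1", "df_5_5"])

# keys= adds an outer index level holding the name of the source frame
print(df.index.nlevels)                                # 2
print(df.index.get_level_values(0).unique().tolist())  # ['df_1_3', 'df_5_1', 'df_5_5']

# Grouping on that level gives back one sub-frame per source (30 rows each)
print(df.groupby(level=0).size().tolist())             # [30, 30, 30]
```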

A cluster algorithm will calculate the cluster labels:

model = sklearn.cluster.AgglomerativeClustering(...)

df["cluster"] = model.fit_predict(df[["x", "y"]]) # [0, 0, 0, ... 1, 1, 1 ... 2, 2, 2] 
df["cluster"] = df["cluster"].astype("category")

I want to visualize the data in one plot. Each original dataset should be distinguishable by its own marker, and the cluster label should be shown by color.

To clarify: if the centers of all three datasets were set close to each other, the algorithm would create just one cluster (i.e. one category/color), but the markers should still depend on the original keys, 'df_1_3', 'df_5_1', and 'df_5_5'.

Actually I nearly got the result with:

fig, ax = plt.subplots()
for marker, (name, sdf) in zip(["o", "s", "^", "d"], df.groupby(level=0)):
    sdf.plot.scatter(x="x", y="y", c="cluster", marker=marker, cmap="viridis", ax=ax)

but with the caveat that the colorbar is displayed three times, once for each group.

How do I get rid of the redundant colorbars?


Solution

  • With seaborn, you can do this without a for loop and get a cleaner-looking plot:

    import seaborn as sns
    
    sns.scatterplot(data=df, x='x', y='y', hue='cluster', style='cluster', markers=["o", "^", "d"], palette="viridis")
    


    The plot above maps both hue and style to 'cluster', so the markers follow the clusters rather than the original frames. To keep color and marker separate, reset the DataFrame index and use the keys from level 0 of the index for the markers.

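    # move the keys from index level 0 into a 'key' column
    # (the names= argument requires pandas >= 1.5)
    df = df.reset_index(level=0, names=['key'])
    
    # color by cluster, marker by original key
    ax = sns.scatterplot(data=df, x='x', y='y', hue='cluster', style='key', markers=["o", "^", "d"], palette="viridis")
    # place the legend outside the axes, to the right
    sns.move_legend(ax, bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)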
    


    df.head() after df.reset_index(level=0, names=['key'])

          key         x         y cluster
    0  df_1_3  0.599034  2.337821       0
    1  df_1_3  0.875819  3.210223       0
    2  df_1_3  1.568023  3.054853       0
    3  df_1_3  0.723676  2.607610       0
    4  df_1_3  1.374373  3.817392       0
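If you would rather keep the original pandas loop from the question, one possible fix (a sketch, not checked across pandas versions) is to draw the colorbar only once: DataFrame.plot documents a colorbar keyword for scatter plots, so every call after the first can pass colorbar=False. The cluster labels below are a hypothetical stand-in (one cluster per source frame) so the snippet runs without the clustering step; in the real code they come from model.fit_predict.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Rebuild the question's data
rng = np.random.default_rng(seed=5)
df_1_3 = pd.DataFrame(rng.normal(loc=(1, 3), size=(30, 2), scale=0.50), columns=["x", "y"])
df_5_1 = pd.DataFrame(rng.normal(loc=(5, 1), size=(30, 2), scale=0.25), columns=["x", "y"])
df_5_5 = pd.DataFrame(rng.normal(loc=(5, 5), size=(30, 2), scale=0.25), columns=["x", "y"])
df = pd.concat([df_1_3, df_5_1, df_5_5], keys=["df_1_3", "df_5_1", "df_5_5"])

# Stand-in labels (hypothetical): one cluster per source frame.
# In the real code these come from model.fit_predict(df[["x", "y"]]).
df["cluster"] = df.groupby(level=0).ngroup().astype("category")

fig, ax = plt.subplots()
for i, (marker, (name, sdf)) in enumerate(zip(["o", "s", "^", "d"], df.groupby(level=0))):
    # Only the first call draws a colorbar; later calls reuse the same axes
    sdf.plot.scatter(x="x", y="y", c="cluster", marker=marker, cmap="viridis",
                     ax=ax, colorbar=(i == 0))
```

Because the cluster column is categorical, every group still carries all of its categories, so recent pandas versions should map a given cluster to the same color in each call, and the single colorbar then applies to all three marker groups.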