python pandas dataframe plot group-by

How to specify x and y in the dataframe groupby plot


I have a dataframe that essentially contains 3 columns (it has many more, but the main idea is presented here): Label, Data1 and Data2.

df = pd.DataFrame({'Label': [1, 1, 1, 2, 2, 3, 3, 3, 4, 4],
                   'Data1': [0.1, 0.01, 0.15, 0.3, 0.35, 0.38, 0.44, 0.45, 0.8, 0.88],
                   'Data2': [1.4, 1.4, 1.6, 2.3, 2.5, 3.6, 3.8, 3.9, 7.3, 7.7]})

Is there a way to plot Data1 vs Data2, grouped by Label, with a different color for each group? I tried something like:

df.groupby("Label")[["Data1","Data2"]].plot(marker='.',subplots=False,ax=plt.gca())

But the resulting figure does NOT show Data1 plotted against Data2.

I would also like to know whether this can be done in Matplotlib.


Solution

  • Are you trying to draw a scatter plot with Data1 on the x-axis and Data2 on the y-axis, with each dot coloured according to its Label? If so, you don't really need groupby and can instead do the following:

    Setup

    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    
    df = pd.DataFrame({'Label': [1, 1, 1, 2, 2, 3, 3, 3, 4, 4],
                       'Data1': [0.1, 0.01, 0.15, 0.3, 0.35, 0.38, 0.44, 0.45, 0.8, 0.88],
                       'Data2': [1.4, 1.4, 1.6, 2.3, 2.5, 3.6, 3.8, 3.9, 7.3, 7.7]})
    

    For string label values

    String values took a bit more work. There might be a cleaner way, but this is what worked for me (both loops below also work for numeric labels):

    Method 1: default colours

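    fig, ax = plt.subplots()
    
    # default colour palette (the colour cycle of the current style)
    prop_cycle = plt.rcParams['axes.prop_cycle']
    # list of colours as hex strings
    colors = prop_cycle.by_key()['color']
    
    # iterate over each label value; enumerate supplies the colour index
    for i, label in enumerate(df['Label'].unique()):
    
        # select the rows for that label once and reuse the mask
        mask = df['Label'] == label
        x = df.loc[mask, 'Data1']
        y = df.loc[mask, 'Data2']
        
        # specify colour and label (for the legend); the modulo wraps
        # around if there are more labels than colours in the cycle
        ax.scatter(x, y, color=colors[i % len(colors)], label=label)
        
    # show legend
    ax.legend()
    plt.show()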
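    Since the question asks specifically about groupby, here is a minimal sketch of the same loop written with it, assuming a pandas version that provides the DataFrame.plot.scatter accessor. groupby hands over each label's rows as a sub-DataFrame, and x= and y= name the columns to put on each axis, so every group becomes one coloured scatter series on a shared Axes:

```python
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame({'Label': [1, 1, 1, 2, 2, 3, 3, 3, 4, 4],
                   'Data1': [0.1, 0.01, 0.15, 0.3, 0.35, 0.38, 0.44, 0.45, 0.8, 0.88],
                   'Data2': [1.4, 1.4, 1.6, 2.3, 2.5, 3.6, 3.8, 3.9, 7.3, 7.7]})

fig, ax = plt.subplots()
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# groupby yields (label, sub-DataFrame) pairs, one per distinct Label value
for i, (label, group) in enumerate(df.groupby('Label')):
    # x and y name the columns to plot; ax=ax draws every group on one Axes
    group.plot.scatter(x='Data1', y='Data2', ax=ax,
                       color=colors[i % len(colors)], label=label)

# rebuild the legend from all scatter series, with a title
ax.legend(title='Label')
```

    Call plt.show() afterwards to display the figure. Passing ax=ax is what keeps all groups on one plot; without it, pandas opens a new figure for each group.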
    

    Method 2: colour palette

    from matplotlib import cm
    
    fig, ax = plt.subplots()
    
    n = df['Label'].nunique()  # number of categories
    
    # n evenly spaced colours from the rainbow palette (other palettes work too)
    colors = iter(cm.rainbow(np.linspace(0, 1, n)))
    
    for label in df['Label'].unique():
        c = next(colors)  # next colour from the palette
        mask = df['Label'] == label
        x = df.loc[mask, 'Data1']
        y = df.loc[mask, 'Data2']
        
        ax.scatter(x, y, color=c, label=label)
        
    ax.legend()
    plt.show()
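    As one possibly cleaner alternative for string labels (a swapped-in technique, not the method above), pd.factorize can turn the labels into integer codes, which c= accepts directly; PathCollection.legend_elements then rebuilds one legend entry per code. The string labels below are hypothetical, and the vmin/vmax setting assumes at most 10 distinct labels:

```python
import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical string labels, made up for illustration
df = pd.DataFrame({'Label': ['a', 'a', 'b', 'b', 'c'],
                   'Data1': [0.1, 0.2, 0.3, 0.4, 0.5],
                   'Data2': [1.0, 1.5, 2.0, 2.5, 3.0]})

# factorize turns each distinct label into an integer code 0..n-1,
# in order of first appearance; uniques holds the original labels
codes, uniques = pd.factorize(df['Label'])

fig, ax = plt.subplots()
# vmin/vmax pin code k to the k-th colour of the 10-colour 'tab10' map
sc = ax.scatter(df['Data1'], df['Data2'], c=codes, cmap='tab10', vmin=0, vmax=9)

# num=None: one legend handle per distinct code, in ascending code order
handles, _ = sc.legend_elements(num=None)
ax.legend(handles, list(uniques), title='Label')
```

    This avoids the explicit loop, at the cost of building the legend by hand; the loop in Method 1 is arguably easier to read for a handful of labels.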
    

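    For numerical label values

    plt.scatter(df['Data1'], df['Data2'], c=df['Label'])
    

    The c= argument is essentially all you need: Matplotlib maps each numeric Label value to a colour from the colour map (viridis by default).

    You can also specify cmap='plasma', for example, to change the colour palette. A list of the available colour maps can be found in the Matplotlib documentation.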
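    A slightly fuller sketch of the numeric case, rebuilding the Setup DataFrame so it runs on its own and assuming a Matplotlib version that provides PathCollection.legend_elements. It adds axis labels and shows two ways of telling the reader which colour means which Label: a colour bar and a legend:

```python
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame({'Label': [1, 1, 1, 2, 2, 3, 3, 3, 4, 4],
                   'Data1': [0.1, 0.01, 0.15, 0.3, 0.35, 0.38, 0.44, 0.45, 0.8, 0.88],
                   'Data2': [1.4, 1.4, 1.6, 2.3, 2.5, 3.6, 3.8, 3.9, 7.3, 7.7]})

fig, ax = plt.subplots()
# c= maps each numeric Label through the colour map chosen by cmap=
sc = ax.scatter(df['Data1'], df['Data2'], c=df['Label'], cmap='plasma')
ax.set_xlabel('Data1')
ax.set_ylabel('Data2')

# Option 1: a colour bar keyed to the Label values (adds a second Axes)
fig.colorbar(sc, ax=ax, label='Label')

# Option 2: a legend with one marker per distinct Label value
handles, labels = sc.legend_elements()
ax.legend(handles, labels, title='Label')
```

    A colour bar suits labels that vary continuously; for a few discrete labels the legend is usually easier to read, so in practice you would likely keep only one of the two.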