Search code examples
python pandas matplotlib seaborn

Change color for displot based on column


I have a displot (image: "Displot for corn and soy"). I want the data for Corn to use one pair of colors (say red and green) and the data for Soy to use a different pair (say orange and blue).

# Sample of the data as a plain dict of equal-length lists (10 rows, one list per column).
# In this sample every row has Crop == 'Corn' and NDVI_Type == 'NDVI_wet'.
df_melted = {'Sink_ID': ['100012_HUC8_09020204', '100012_HUC8_09020204', '100017_HUC8_09020204', '100017_HUC8_09020204', '100029_HUC8_09020204', '100029_HUC8_09020204', '100136_HUC8_09020205', '100136_HUC8_09020205', '100147_HUC8_09020204', '100147_HUC8_09020204'], 'Year': [2005, 2005, 2005, 2005, 2005, 2005, 2005, 2005, 2005, 2005], 'Date_Range': ['07-10', '07-10', '07-10', '07-10', '07-10', '07-10', '07-10', '07-10', '07-10', '07-10'], 'WUVR': [1.1149566576574306, 1.0045231416606328, 0.5370876305214706, 0.7059439317324633, 0.9847125932507388, 1.0293399058565504, 1.0809386044812537, 1.383177947482075, 0.6703937259463105, 0.7220824311880709], 'Temp': [24.088787078857425, 21.557071685791016, 24.08528709411621, 21.73164367675781, 24.037073135375977, 21.661357879638672, 23.882572174072266, 21.39950180053711, 24.08528709411621, 21.73164367675781], 'Precip': [0.0837142914533615, 2.6772143840789795, 0.9112856984138488, 4.665500164031982, 0.7224286198616028, 4.315714359283447, 0.1414999961853027, 2.28857159614563, 0.9112856984138488, 4.665500164031982], 'Crop': ['Corn', 'Corn', 'Corn', 'Corn', 'Corn', 'Corn', 'Corn', 'Corn', 'Corn', 'Corn'], 'NDVI_Type': ['NDVI_wet', 'NDVI_wet', 'NDVI_wet', 'NDVI_wet', 'NDVI_wet', 'NDVI_wet', 'NDVI_wet', 'NDVI_wet', 'NDVI_wet', 'NDVI_wet'], 'NDVI_Value': [0.847915947437286, 0.835569024085999, 0.328437134623528, 0.528863519430161, 0.636810958385468, 0.705004304647446, 0.644810974597931, 0.761542975902557, 0.372550383210182, 0.535710155963898], 'Intervals': ['07-10 to 08-06', '07-10 to 08-06', '07-10 to 08-06', '07-10 to 08-06', '07-10 to 08-06', '07-10 to 08-06', '07-10 to 08-06', '07-10 to 08-06', '07-10 to 08-06', '07-10 to 08-06']}

sns.set_theme(style="white", palette=None)

# Histogram grid: one row per 'Intervals' value, one column per 'Crop',
# bars coloured by 'NDVI_Type'.
# NOTE: df_intervals is not defined in this snippet; it must already exist
# and supplies the order in which the interval rows are drawn.
interval_order = df_intervals['Intervals'].tolist()
g = sns.displot(
    data=df_melted,
    x='NDVI_Value',
    hue='NDVI_Type',
    row='Intervals',
    col='Crop',
    row_order=interval_order,
    binwidth=0.01,
    height=3,
    aspect=2.5,
    facet_kws=dict(margin_titles=False),
)

# Give every facet the same 0-1 x-range and a larger title font
for facet_ax in g.axes.flatten():
    facet_ax.set_xlim(0, 1)
    facet_ax.title.set_fontsize(14)

# Tick positions and the text shown at each of them
tick_positions = [0, 0.25, 0.5, 0.75, 1]
tick_labels = ['0', '0.25', '0.5', '0.75', '1']
plt.xticks(tick_positions, labels=tick_labels)
plt.show()

I tried changing the color palette, but that affected both columns; I want to set the colors of each column separately.


Solution

  • I don't think you can do this with a displot itself. You can instead create a FacetGrid and then manually add all the histplots with the required colour palettes for each column. This is demonstrated below with some mock data:

    import numpy as np
    import seaborn as sns
    import pandas as pd
    
    
    nvalues = 1000  # samples drawn per (crop, interval, type) combination
    mus = [10, 11.5]  # mean of the mock distribution, one per entry in `types`
    crops = ["Corn", "Soy"]
    intervals = ["A", "B"]
    types = ["Type 1", "Type 2"]

    # seeded generator so the mock data (and hence the figure) is reproducible
    rng = np.random.default_rng(0)

    # some mock data in long format: one row per sampled value
    data = {"crop": [], "interval": [], "type": [], "value": []}
    for crop in crops:
        for interval in intervals:
            # named `type_name` so the builtin `type` is not shadowed
            for mu, type_name in zip(mus, types):
                data["crop"].extend(nvalues * [crop])
                data["type"].extend(nvalues * [type_name])
                data["interval"].extend(nvalues * [interval])
                data["value"].extend((mu + rng.standard_normal(nvalues)).tolist())

    df = pd.DataFrame(data)

    nbins = 50  # number of histogram bins

    # create the (still empty) grid; the row/column order is given explicitly so
    # that g.axes[i, j] is always the facet for (intervals[i], crops[j]), no
    # matter in which order the values appear in the data.
    # `hue` is not passed here: nothing is mapped through the grid, the colouring
    # is done by the histplot calls below.
    g = sns.FacetGrid(
        data=df,
        row="interval",
        col="crop",
        row_order=intervals,
        col_order=crops,
    )

    # one palette per column; colours are assigned to `types` in order
    palettes = {
        "Corn": ["red", "green"],
        "Soy": ["orange", "blue"],
    }

    # draw each facet by hand, using the palette that belongs to its column
    for i, inter in enumerate(intervals):
        for j, crop in enumerate(crops):
            ax = g.axes[i, j]
            subset = df.loc[(df["crop"] == crop) & (df["interval"] == inter)]

            sns.histplot(
                data=subset,
                x="value",
                hue="type",
                hue_order=types,  # same type -> colour mapping in every facet
                palette=palettes[crop],
                bins=nbins,
                ax=ax,
                legend=(i == 0),  # one legend per column, on its first row only
            )

            # add a title like the displot titles
            ax.set_title(f"interval = {inter} | crop = {crop}", fontsize=10)
    

    enter image description here