Tags: matplotlib, seaborn, scaling, kernel-density, jointplot

How to scale marginal kdeplot of seaborn jointplot with imbalanced categorical data


How to scale marginal kdeplot of seaborn jointplot?

Suppose we have 1000 data points of kind 'a', 100 of kind 'b', and 100 of kind 'c'.

In this case, the marginal KDE curves are not drawn on the same scale, because the categories have very different sample sizes.

How do I make these identical?

I wrote a toy script to reproduce this:

import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

ax, ay = 1 * np.random.randn(1000) + 2, 1 * np.random.randn(1000) + 2
bx, by = 1 * np.random.randn(100) + 3, 1 * np.random.randn(100) + 3
cx, cy = 1 * np.random.randn(100) + 4, 1 * np.random.randn(100) + 4

a = [{'x': x, 'y': y, 'kind': 'a'} for x, y in zip(ax, ay)]
b = [{'x': x, 'y': y, 'kind': 'b'} for x, y in zip(bx, by)]
c = [{'x': x, 'y': y, 'kind': 'c'} for x, y in zip(cx, cy)]

df = pd.concat([pd.DataFrame.from_dict(a), pd.DataFrame.from_dict(b), pd.DataFrame.from_dict(c)], ignore_index=True)

print(df)
             x         y kind
0     2.500866  2.700925    a
1    -0.386057  3.322318    a
2     1.691078  2.558366    a
3     2.235042 -0.113836    a
4     3.331039  1.138366    a
...        ...       ...  ...
1195  3.703245  2.935332    c
1196  1.806040  2.842754    c
1197  5.431313  5.377297    c
1198  3.873162  6.200356    c
1199  4.111234  3.038126    c

[1200 rows x 3 columns]

sns.jointplot(data=df, x='x', y='y', hue="kind")
plt.show()



Solution

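  • Pass marginal_kws= to forward keyword arguments to the marginal plots. When hue is set, the marginals are drawn with sns.kdeplot, which accepts parameters such as common_norm and multiple. By default common_norm=True scales each group's density by its share of the observations, so the smaller groups look flattened; common_norm=False normalizes each group's density independently.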

    import seaborn as sns
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    
    ax, ay = 1 * np.random.randn(1000) + 2, 1 * np.random.randn(1000) + 2
    bx, by = 1 * np.random.randn(100) + 3, 1 * np.random.randn(100) + 3
    cx, cy = 1 * np.random.randn(100) + 4, 1 * np.random.randn(100) + 4
    
    a = [{'x': x, 'y': y, 'kind': 'a'} for x, y in zip(ax, ay)]
    b = [{'x': x, 'y': y, 'kind': 'b'} for x, y in zip(bx, by)]
    c = [{'x': x, 'y': y, 'kind': 'c'} for x, y in zip(cx, cy)]
    
    df = pd.concat([pd.DataFrame.from_dict(a), pd.DataFrame.from_dict(b), pd.DataFrame.from_dict(c)], ignore_index=True)
    
    sns.jointplot(data=df, x='x', y='y', hue='kind', marginal_kws={'common_norm': False})
    plt.show()
    

    [Figure: sns.jointplot without common norm]
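
To see what common_norm changes numerically, the normalization can be reproduced by hand. The following is only a sketch, not seaborn's implementation: gaussian_kde_1d is a hypothetical helper, the bandwidth of 0.4 is an arbitrary assumption, and the group sizes mirror the toy data above. It compares the area under each group's curve when every density is normalized independently with the area when each density is weighted by its group's share of all observations.

```python
import numpy as np


def gaussian_kde_1d(samples, grid, bandwidth):
    """Evaluate a fixed-bandwidth Gaussian KDE of `samples` on `grid`.

    Hypothetical helper for illustration; seaborn chooses its own bandwidth.
    """
    z = (grid[:, None] - samples[None, :]) / bandwidth
    kernels = np.exp(-0.5 * z**2) / (bandwidth * np.sqrt(2 * np.pi))
    return kernels.mean(axis=1)


rng = np.random.default_rng(0)
groups = {
    "a": rng.normal(2, 1, 1000),
    "b": rng.normal(3, 1, 100),
    "c": rng.normal(4, 1, 100),
}
total = sum(len(samples) for samples in groups.values())

grid = np.linspace(-6, 12, 2001)
dx = grid[1] - grid[0]

areas_independent = {}  # analogous to common_norm=False
areas_common = {}       # analogous to common_norm=True

for kind, samples in groups.items():
    density = gaussian_kde_1d(samples, grid, bandwidth=0.4)
    # Independent normalization: each curve integrates to about 1.
    areas_independent[kind] = density.sum() * dx
    # Common normalization: weight the curve by the group's share of the data.
    areas_common[kind] = (density * len(samples) / total).sum() * dx

print("independent:", areas_independent)
print("common:", areas_common)
```

With common normalization, the areas of groups 'b' and 'c' are each about one tenth of the area of group 'a', which is why their marginal curves look so small in the original plot.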