python, pandas, data-analysis

Sorting Pandas Categorical labels after groupby


I'm using pd.cut to discretize a dataset, and everything is working great. My question, however, is about the Categorical object type, which is the data type returned by pd.cut. The docs say that a Categorical is treated like an array of strings, so I'm not surprised to see that the labels are sorted lexically when grouped.

For example, the following code:

import numpy as np
import pandas as pd

df = pd.DataFrame({'value': np.random.randint(0, 10000, 100)})

labels = []
for i in range(0, 10000, 500):
    labels.append("{0} - {1}".format(i, i + 499))

# DataFrame.sort() has been removed from pandas; sort_values() replaces it
df.sort_values(by='value', inplace=True, ascending=True)
df['value_group'] = pd.cut(df.value, range(0, 10500, 500), right=False, labels=labels)

# Count rows per bin and plot them as a bar chart
df.groupby(['value_group'])['value_group'].count().plot(kind='bar')

Produces the following chart:

[Bar chart: count per value_group, bars in lexical label order]

(Notice that 500 - 999 sits in the middle, between 4500 - 4999 and 5000 - 5499.)
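A minimal sketch in plain Python (no pandas involved) of why that bar lands where it does: sorting the question's labels as strings compares them character by character, not by their numeric value.

```python
# The question's labels, in bin order: "0 - 499", "500 - 999", ..., "9500 - 9999".
labels = ["{0} - {1}".format(i, i + 499) for i in range(0, 10000, 500)]

# Sorting strings compares character by character: "4500 - 4999" comes before
# "500 - 999" because '4' < '5', and "500 - 999" comes before "5000 - 5499"
# because ' ' < '0'.
lexical = sorted(labels)
pos = lexical.index("500 - 999")

print(lexical[pos - 1], "|", lexical[pos], "|", lexical[pos + 1])
# prints: 4500 - 4999 | 500 - 999 | 5000 - 5499
```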

Prior to grouping, the structure is in the order I expect:

In [94]: df['value_group']
Out [94]: 
59        0 - 499
58        0 - 499
0       500 - 999
94      500 - 999
76      500 - 999
95     1000 - 1499
17     1000 - 1499
48     1000 - 1499
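The intended order is not lost: pd.cut returns a Categorical whose categories follow the labels list. The sketch below checks that, and checks that grouping on the categorical key returns the groups in that category order. It assumes a reasonably recent pandas and NumPy (the `observed=` keyword of groupby and `np.random.default_rng` are both newer than this question), and the seed value is an arbitrary choice for repeatability.

```python
import numpy as np
import pandas as pd

# Same setup as the question; the seed (0) is arbitrary, only for repeatability.
rng = np.random.default_rng(0)
df = pd.DataFrame({'value': rng.integers(0, 10000, 100)})
labels = ["{0} - {1}".format(i, i + 499) for i in range(0, 10000, 500)]
df['value_group'] = pd.cut(df['value'], range(0, 10500, 500),
                           right=False, labels=labels)

# pd.cut stores the categories in the order of `labels`, i.e. bin order.
categories_in_bin_order = list(df['value_group'].cat.categories) == labels

# On recent pandas, grouping on a categorical key orders the groups by
# category order; observed=False also keeps empty bins as zero counts.
counts = df.groupby('value_group', observed=False).size()
groups_in_bin_order = list(counts.index) == labels

print(categories_in_bin_order, groups_in_bin_order)
```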

I've played around with this for some time, and the only way I've been able to avoid it is to prefix each label with a letter, e.g. ['A) 0 - 499', 'B) 500 - 999', ... ], which makes me cringe. I also looked into providing a custom groupby implementation, which didn't seem possible (or like the right thing to do either). What am I missing?


Solution

  • This has bitten me too. Probably the right fix is to improve native support for Categorical objects, but in the meantime I get around this in practice by doing a final sorting pass:

    In [104]: z = df.groupby('value_group').size()
    
    In [105]: z[sorted(z.index, key=lambda x: float(x.split()[0]))]
    Out[105]: 
    0 - 499        5
    500 - 999      6
    1000 - 1499    4
    1500 - 1999    6
    2000 - 2499    4
    2500 - 2999    6
    3000 - 3499    3
    3500 - 3999    3
    4000 - 4499    2
    4500 - 4999    6
    5000 - 5499    6
    5500 - 5999    5
    6000 - 6499    6
    6500 - 6999    2
    7000 - 7499    9
    7500 - 7999    3
    8000 - 8499    7
    8500 - 8999    6
    9000 - 9499    5
    9500 - 9999    6
    dtype: int64
    
    In [106]: z[sorted(z.index, key=lambda x: float(x.split()[0]))].plot(kind='bar')
    Out[106]: <matplotlib.axes.AxesSubplot at 0xbe87d30>
    

    [Bar chart: the same counts, with bars in numeric order]
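The same final sorting pass can be wrapped in a small helper. `sort_by_lower_bound` is a hypothetical name, and the sketch assumes every label starts with its numeric lower bound followed by whitespace, as in the question. The toy Series reuses three of the counts from the output above.

```python
import pandas as pd

def sort_by_lower_bound(s: pd.Series) -> pd.Series:
    """Reorder a Series indexed by "low - high" labels by the numeric low bound.

    Hypothetical helper wrapping the sorted(..., key=...) pass from the answer;
    assumes each label starts with a number followed by whitespace.
    """
    order = sorted(s.index, key=lambda label: float(label.split()[0]))
    return s.reindex(order)

# Three of the counts from the output above, in lexical (string-sorted) order.
z = pd.Series({"0 - 499": 5, "1000 - 1499": 4, "500 - 999": 6})

print(sort_by_lower_bound(z).index.tolist())
# prints: ['0 - 499', '500 - 999', '1000 - 1499']

# When the full label list is already available, reindexing by it is shorter.
labels = ["0 - 499", "500 - 999", "1000 - 1499"]
print(z.reindex(labels).tolist())
# prints: [5, 6, 4]
```

Reindexing by a known label list avoids parsing the labels at all, but any index entry missing from the list is dropped, and any listed label absent from the Series comes back as NaN.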