
How to avoid a chained apply when selecting from multiindex?


I have a dataframe like this:

import numpy as np
import pandas as pd


df = pd.DataFrame({
    'ind1': list('AAABBBCCC'),
    'ind2': [1, 2, 3] * 3,
    'val1': [0, 1, 2, -1, -4, -5, 10, 11, 4],
    'val2': [0.1, 0.2, -0.2, 0.1, 0.2, 0.2, -0.1, 2, 0.1]
})

df = df.set_index(['ind1', 'ind2'])

           val1  val2
ind1 ind2            
A    1        0   0.1
     2        1   0.2
     3        2  -0.2
B    1       -1   0.1
     2       -4   0.2
     3       -5   0.2
C    1       10  -0.1
     2       11   2.0
     3        4   0.1

I want to select all rows of the ind1 groups in which the absolute values of val1 are strictly increasing, i.e. every difference between consecutive absolute values within the group is positive.

I currently do it as follows:

m_incr = (
    df.groupby('ind1')['val1']
      .apply(lambda x: np.diff(abs(x)))
      .apply(lambda x: all(eli > 0 for eli in x))
)

df_incr = df[m_incr[df.index.get_level_values('ind1')].values]

which gives me the desired outcome:

           val1  val2
ind1 ind2            
A    1        0   0.1
     2        1   0.2
     3        2  -0.2
B    1       -1   0.1
     2       -4   0.2
     3       -5   0.2

My question is whether there is a more straightforward or efficient way that avoids the chained apply calls.


Solution

  • Use GroupBy.transform to return a Series with the same index as the original DataFrame (each group's scalar result is broadcast to all rows of that group):

    mask = df.groupby('ind1')['val1'].transform(lambda x: (np.diff(abs(x)) > 0).all())
    

    And then filter by mask with boolean indexing:

    print(df[mask])
    

    All together:

    print(df[df.groupby('ind1')['val1'].transform(lambda x: (np.diff(abs(x)) > 0).all())])
    

               val1  val2
    ind1 ind2            
    A    1        0   0.1
         2        1   0.2
         3        2  -0.2
    B    1       -1   0.1
         2       -4   0.2
         3       -5   0.2
    

    Detail:

    print(mask)
    ind1  ind2
    A     1        True
          2        True
          3        True
    B     1        True
          2        True
          3        True
    C     1       False
          2       False
          3       False
    Name: val1, dtype: bool
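Two other ways to get the same rows are sketched below. They are not part of the answer above, just alternatives under the same sample data. DataFrameGroupBy.filter keeps or drops whole groups directly, so no separate mask is needed. The second variant avoids a per-group Python lambda entirely by using the grouped diff() of the absolute values. It assumes val1 has no missing values, since NaN differences are treated as passing. Whether it is actually faster depends on the data; it is most likely to help when there are many small groups.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'ind1': list('AAABBBCCC'),
    'ind2': [1, 2, 3] * 3,
    'val1': [0, 1, 2, -1, -4, -5, 10, 11, 4],
    'val2': [0.1, 0.2, -0.2, 0.1, 0.2, 0.2, -0.1, 2, 0.1],
}).set_index(['ind1', 'ind2'])

# Option 1: filter() keeps every row of the groups for which the function returns True
via_filter = df.groupby('ind1').filter(
    lambda g: (np.diff(g['val1'].abs()) > 0).all()
)

# Option 2: no Python lambda. Difference of absolute values within each group;
# the first row of each group has no predecessor, so its diff is NaN.
# Assumption: val1 itself contains no NaN, so NaN here only marks a group's first row.
d = df['val1'].abs().groupby(level='ind1').diff()
step_ok = d.gt(0) | d.isna()

# A group passes only if every step passes; transform('all') broadcasts that per row
mask = step_ok.groupby(level='ind1').transform('all')
via_vectorized = df[mask]

print(via_filter)
print(via_vectorized)
```

Both variants keep the original MultiIndex, so the result can be used exactly like df[mask] above.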