I have a pandas MultiIndex object where the first level is a regular increasing index of ints, and the second level contains other integers that may or may not repeat for different 'frst' index values:
import pandas as pd
lst = list(filter(lambda x: x[1]%5 == x[0] or x[1]%4 == x[0],[(i,j) for i in range(5) for j in range(0, 20, 2)]))
mi = pd.MultiIndex.from_tuples(lst).rename(['frst', 'scnd'])
# mi = MultiIndex([(0, 0),(0, 4),(0, 8),(0, 10),(0, 12),(0, 16),(1, 6),(1, 16),(2, 2),(2, 6),(2, 10),(2, 12),(2, 14),(2, 18),(3, 8),(3, 18),(4, 4),(4, 14)], names=['frst', 'scnd'])
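To check the structure described above, the scnd values can be grouped under each frst value. This sketch rebuilds the same index with an equivalent list comprehension in place of filter, and passes names directly to from_tuples. by_frst is only an illustrative name.

```python
import pandas as pd

# Same pairs as the filter(...) version: keep (i, j) when j % 5 == i or j % 4 == i
lst = [(i, j) for i in range(5) for j in range(0, 20, 2)
       if j % 5 == i or j % 4 == i]
mi = pd.MultiIndex.from_tuples(lst, names=['frst', 'scnd'])

# One row per (frst, scnd) pair; collect the scnd values under each frst
by_frst = mi.to_frame(index=False).groupby('frst')['scnd'].apply(list)
print(by_frst)
```

Each frst value maps to its own list of scnd values, e.g. frst 3 maps to [8, 18], which makes the overlaps between frst groups easy to see.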
For a given frst value (e.g. frst_idx = 0) and some shift, I need to find all indices where frst is frst_idx+shift and scnd is shared between frst_idx and frst_idx+shift.
So for example:
frst_idx = 0, shift = 3 should output [8], because the MultiIndex above contains both (0, 8) and (3, 8).
frst_idx = 1, shift = 1 should output [6], because (1, 6) and (2, 6) are both in the index.
So I'm hoping for a function that can take these args and return a pd.Series of all matching scnd values:
my_func(multi_index=mi, frst_idx=0, shift=3) ==> pd.Series([8])
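To pin down the intended semantics before worrying about speed, here is a minimal plain-Python sketch using set intersection. my_func_reference is a hypothetical name, and sorting the result is an assumption (the question doesn't specify an order). It still scans the whole index on every call, so it serves as a specification rather than the fast version being asked for.

```python
import pandas as pd

# The MultiIndex from the question, written out explicitly
mi = pd.MultiIndex.from_tuples(
    [(0, 0), (0, 4), (0, 8), (0, 10), (0, 12), (0, 16), (1, 6), (1, 16),
     (2, 2), (2, 6), (2, 10), (2, 12), (2, 14), (2, 18), (3, 8), (3, 18),
     (4, 4), (4, 14)],
    names=['frst', 'scnd'],
)


def my_func_reference(multi_index, frst_idx, shift):
    # Iterating a MultiIndex yields (frst, scnd) tuples
    first = {s for f, s in multi_index if f == frst_idx}
    second = {s for f, s in multi_index if f == frst_idx + shift}
    # Shared scnd values, sorted so the output order is deterministic
    return pd.Series(sorted(first & second), dtype='int64')


print(my_func_reference(mi, 0, 3).tolist())  # [8]
print(my_func_reference(mi, 1, 1).tolist())  # [6]
```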
Doing this iteratively is very expensive (O(n^2)), so I'm hoping there's some pandas magic to do this faster.
I found the following solution:
# reminder: $mi is a MultiIndex, mi.names = ['frst', 'scnd']
# assume some integer values for $frst_idx1, $shift
# select rows whose 'frst' equals frst_idx1, keeping only their 'scnd' values
scnd_indices1 = mi[mi.get_level_values('frst') == frst_idx1].droplevel('frst')
frst_idx2 = frst_idx1 + shift
scnd_indices2 = mi[mi.get_level_values('frst') == frst_idx2].droplevel('frst')
result = scnd_indices1.intersection(scnd_indices2).to_series().reset_index(drop=True)
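The same steps can be wrapped in the my_func signature asked for above. If many frst_idx values need checking for one shift, a self-merge on a DataFrame view of the index answers all of them in one vectorized merge instead of one call per frst value. That merge-based variant is a swapped-in technique, not part of the solution above, and shared_for_all is an illustrative name.

```python
import pandas as pd

lst = [(i, j) for i in range(5) for j in range(0, 20, 2)
       if j % 5 == i or j % 4 == i]
mi = pd.MultiIndex.from_tuples(lst, names=['frst', 'scnd'])


def my_func(multi_index, frst_idx, shift):
    """Shared 'scnd' values between frst_idx and frst_idx + shift."""
    frst = multi_index.get_level_values('frst')
    # Boolean masks select each 'frst' group; droplevel keeps only 'scnd'
    scnd1 = multi_index[frst == frst_idx].droplevel('frst')
    scnd2 = multi_index[frst == frst_idx + shift].droplevel('frst')
    return scnd1.intersection(scnd2).to_series().reset_index(drop=True)


def shared_for_all(multi_index, shift):
    """For every frst value f, the scnd values also present under f + shift."""
    df = multi_index.to_frame(index=False)
    # Relabel each row as if it belonged to frst - shift, so an inner join on
    # (frst, scnd) matches (f, s) with (f + shift, s) from the original index
    shifted = df.assign(frst=df['frst'] - shift)
    matched = df.merge(shifted, on=['frst', 'scnd'])
    return matched.groupby('frst')['scnd'].apply(list)


print(my_func(mi, 0, 3).tolist())  # [8]
print(my_func(mi, 1, 1).tolist())  # [6]
print(shared_for_all(mi, 1))
```

For shift = 1, shared_for_all should report 16 under frst 0, 6 under frst 1 and 18 under frst 2; frst values with no overlap simply don't appear in its output.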