Tags: pandas, numpy, count, binary-matrix

Counting in a pandas DataFrame: for each pair a_i, a_j of A, count how many times a_i and a_j are listed with the same b


I have a big DataFrame (1M rows) with two columns, A and B.

For each pair a_i, a_j of values in A, I want to know the number of values b in B for which both rows (a_i, b) and (a_j, b) exist.

For example :

A B
a1 b1
a2 b1
a3 b2
a1 b3
a3 b3
a4 b3
a2 b4
a1 b5
a3 b5
a3 b6
a4 b6

Here, among others, the pair a1, a3 shares b3 and b5.

The result would be the following matrix (which, by definition, is symmetric):

   a1 a2 a3 a4
a1 xx  1  2  1
a2  1 xx  0  0
a3  2  0 xx  2
a4  1  0  2 xx

I think that the following would work:

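import numpy as np
import pandas as pd

df = pd.DataFrame({'A': ['a1','a2','a3','a1','a3','a4','a2','a1','a3','a3','a4'],
                   'B': ['b1','b1','b2','b3','b3','b3','b4','b5','b5','b6','b6']})

# One-hot encode B, then sum per A to get an (A x B) indicator matrix
df_dum = df.set_index('A')['B'].str.get_dummies().reset_index()
df_dum = df_dum.groupby('A').sum()
np_cnt = df_dum.to_numpy()
# (A x B) @ (B x A): entry (i, j) is the number of b shared by a_i and a_j
np_mul = np.matmul(np_cnt, np_cnt.T)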

But it takes far too much time and memory and does not run on my 1M rows. In addition, the diagonal is computed even though I don't need it, and I suspect that going through the dummy matrix is not a good idea, especially because the resulting binary matrix is very sparse.

I am out of ideas, though.

What would you propose?

EDIT:

For a bit more context, say that A are students and B are courses. In the end, for any two students, I want to know how many courses they have in common, and I only need this for pairs of students that share at least one course. I hope that makes more sense :)


Solution

  • Try with itertools.permutations:

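    import itertools
    import pandas as pd
    
    # For each b, list every ordered pair of distinct A values listed with it;
    # a b with a single A value yields no pairs (NaN after explode), so drop those
    sets = df.groupby('B')['A'].apply(lambda x : list(itertools.permutations(x, 2))).explode().dropna().tolist()
    sets = pd.DataFrame(sets)
    
    # Count the pairs, then restore the original order of A on both axes
    index = df["A"].unique()
    output = pd.crosstab(sets[0],sets[1],rownames=[None],colnames=[None]).reindex(index).reindex(index, axis=1)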
    
    >>> output
        a1  a2  a3  a4
    a1   0   1   2   1
    a2   1   0   0   0
    a3   2   0   0   2
    a4   1   0   2   0
    

    If you want to mask the cells where index and columns are the same with "xx":

    output = output.mask(output.index.values[:,None] == output.columns.values[None,:]).fillna("xx")
    
    >>> output
         a1   a2   a3   a4
    a1   xx  1.0  2.0  1.0
    a2  1.0   xx  0.0  0.0
    a3  2.0  0.0   xx  2.0
    a4  1.0  0.0  2.0   xx
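
Since the matrix is symmetric and the question only asks about pairs that share at least one b, it may be enough to store each unordered pair once instead of building the full matrix. The following is a minimal sketch of that idea using only the standard library, not a tested solution for the 1M-row case. The helper name `shared_counts` is hypothetical, and it assumes that duplicate (a, b) rows should be counted once, which differs from the crosstab approach if the data contains duplicates.

```python
from collections import Counter, defaultdict
from itertools import combinations

# Example (A, B) rows from the question: A = student, B = course.
rows = [
    ("a1", "b1"), ("a2", "b1"), ("a3", "b2"), ("a1", "b3"),
    ("a3", "b3"), ("a4", "b3"), ("a2", "b4"), ("a1", "b5"),
    ("a3", "b5"), ("a3", "b6"), ("a4", "b6"),
]

def shared_counts(rows):
    """Count, for each unordered pair of A values, how many B values they share.

    Hypothetical helper; assumes duplicate (a, b) rows count only once.
    """
    members = defaultdict(set)      # b -> set of A values listed with it
    for a, b in rows:
        members[b].add(a)
    counts = Counter()
    for group in members.values():
        # sorted() gives each pair a canonical order, so (a1, a3) and
        # (a3, a1) end up under the same key; the diagonal never appears
        counts.update(combinations(sorted(group), 2))
    return counts

pair_counts = shared_counts(rows)
print(pair_counts[("a1", "a3")])    # a1 and a3 share b3 and b5
```

With the DataFrame from the question, the same helper could be called as `shared_counts(zip(df["A"], df["B"]))`. Pairs that share nothing are simply absent from the result (a `Counter` returns 0 for them). Note that a b listed with k different A values still produces k(k-1)/2 pairs, so very large groups remain expensive.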