Tags: apache-spark, pyspark, apache-spark-sql, rdd

Co-occurrence matrix on multilabel data


I have a dataset of users, where each user can belong to multiple classes:

user1, A
user1, B
user1, C
user2, A
user2, C
user3, B
user3, C

In this example, user1 belongs to classes A, B and C. I would like to know how many unique users each pair of classes shares (each entry in the table is the number of shared unique users):

  | A | B | C
A | 2 | 1 | 2
B | 1 | 2 | 2
C | 2 | 2 | 3

How can I do this with DataFrames or RDDs in PySpark? I think I may need to reshape/pivot the data, but the solutions I come up with seem overly complicated for the task...

Thank you!
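For reference, the intended computation is just the size of the pairwise intersections of each class's user set. A minimal plain-Python sketch (using the sample data above) makes the definition concrete:

```python
from collections import defaultdict

# Sample (user, class) rows from the question.
rows = [
    ("user1", "A"), ("user1", "B"), ("user1", "C"),
    ("user2", "A"), ("user2", "C"),
    ("user3", "B"), ("user3", "C"),
]

# Collect the set of unique users per class.
users_by_class = defaultdict(set)
for user, cls in rows:
    users_by_class[cls].add(user)

# Each matrix entry is the number of users shared by the two classes;
# the diagonal is simply the class size.
classes = sorted(users_by_class)
matrix = {
    (a, b): len(users_by_class[a] & users_by_class[b])
    for a in classes
    for b in classes
}
```

For this data, `matrix[("A", "B")]` is 1 (only user1 is in both), and the matrix is symmetric by construction.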


Solution

  • Self-join on user, then crosstab the class columns

    # df is the input DataFrame with columns "user" and "class"
    (df.withColumnRenamed("class", "class_a")
        .join(df.withColumnRenamed("class", "class_b"), ["user"])
        .crosstab("class_a", "class_b")
        .orderBy("class_a_class_b")
        .show())
    
    # +---------------+---+---+---+ 
    # |class_a_class_b|  A|  B|  C|
    # +---------------+---+---+---+ 
    # |              A|  2|  1|  2|
    # |              B|  1|  2|  2|
    # |              C|  2|  2|  3|
    # +---------------+---+---+---+
    

    Apply .distinct() first if the input may contain duplicate (user, class) rows; otherwise each duplicate inflates the counts.
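To see why deduplication matters, here is a plain-Python sketch (hypothetical helper, not PySpark) that emulates the self-join-then-crosstab logic: for each user, every class is paired with every other class of that same user, and the pairs are tallied.

```python
from collections import Counter
from itertools import product


def cooccurrence(rows):
    """Emulate the self-join on user followed by crosstab:
    pair every class of a user with every class of the same user."""
    by_user = {}
    for user, cls in rows:
        by_user.setdefault(user, []).append(cls)
    counts = Counter()
    for classes in by_user.values():
        for a, b in product(classes, classes):
            counts[(a, b)] += 1
    return counts


# Input with a duplicate (user, class) row.
rows = [("u1", "A"), ("u1", "A"), ("u1", "B")]

# Without dedup, the duplicate "A" row inflates the counts:
# with_dupes[("A", "A")] == 4 and with_dupes[("A", "B")] == 2.
with_dupes = cooccurrence(rows)

# Deduplicating first gives the intended per-user counts of 1.
deduped = cooccurrence(set(rows))
```

The same effect in Spark is achieved by calling `df.distinct()` before the self-join.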