Tags: apache-spark, pyspark, apache-spark-sql

Get correlation matrix for array in a column


I have a dataframe:

data = [['t1', ['u1', 'u2', 'u3', 'u4', 'u5'], 1],
        ['t2', ['u1', 'u7', 'u8', 'u5'], 1],
        ['t3', ['u1', 'u2', 'u7', 'u11'], 2],
        ['t4', ['u8', 'u9'], 3],
        ['t5', ['u9', 'u22', 'u11'], 3],
        ['t6', ['u5', 'u11', 'u22', 'u4'], 3]]
sdf = spark.createDataFrame(data, schema=['label', 'id', 'day'])
sdf.show()
+-----+--------------------+---+
|label|                  id|day|
+-----+--------------------+---+
|   t1|[u1, u2, u3, u4, u5]|  1|
|   t2|    [u1, u7, u8, u5]|  1|
|   t3|   [u1, u2, u7, u11]|  2|
|   t4|            [u8, u9]|  3|
|   t5|      [u9, u22, u11]|  3|
|   t6|  [u5, u11, u22, u4]|  3|
+-----+--------------------+---+

I want to calculate the correlation matrix (my actual dataframe is much larger).

I would like to cross the id column only across different days. That is, labels from day=1 are not crossed with each other, and such cells get 0; the first day is crossed with the second and third, and so on.

Moreover, when a label is crossed with itself, the cell should be 0 rather than 100 (the diagonal is 0).

In the matrix I would like to record the absolute size of the intersection (how many IDs the two rows have in common). The result should be a dataframe like this:

+---+-----+---+---+---+---+---+---+
|day|label| t1| t2| t3| t4| t5| t6|
+---+-----+---+---+---+---+---+---+
|  1|   t1|  0|  0|  2|  0|  0|  2|
|  1|   t2|  0|  0|  2|  1|  0|  0|
|  2|   t3|  2|  2|  0|  0|  0|  1|
|  3|   t4|  0|  1|  0|  0|  0|  0|
|  3|   t5|  0|  0|  1|  0|  0|  0|
|  3|   t6|  2|  1|  1|  0|  0|  0|
+---+-----+---+---+---+---+---+---+

And since I actually have a large dataset, I would like the solution not to require too much memory, so that the job does not fail.


Solution

  • First of all, you can use explode to flatten the lists of IDs:

    >>> from pyspark.sql.functions import explode
    >>> from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType
    >>> schema = StructType([
    ...     StructField('label', StringType(), nullable=False),
    ...     StructField('ids', ArrayType(StringType(), containsNull=False), nullable=False),
    ...     StructField('day', IntegerType(), nullable=False),
    ... ])
    >>> data = [
    ...     ['t1', ['u1', 'u2', 'u3', 'u4', 'u5'], 1],
    ...     ['t2', ['u1', 'u7', 'u8', 'u5'], 1],
    ...     ['t3', ['u1', 'u2', 'u7', 'u11'], 2],
    ...     ['t4', ['u8', 'u9'], 3],
    ...     ['t5', ['u9', 'u22', 'u11'], 3],
    ...     ['t6', ['u5', 'u11', 'u22', 'u4'], 3]
    ... ]
    >>> id_lists_df = spark.createDataFrame(data, schema=schema)
    >>> df = id_lists_df.select('label', 'day', explode('ids').alias('id'))
    >>> df.show()
    +-----+---+---+                                                                 
    |label|day| id|
    +-----+---+---+
    |   t1|  1| u1|
    |   t1|  1| u2|
    |   t1|  1| u3|
    |   t1|  1| u4|
    |   t1|  1| u5|
    |   t2|  1| u1|
    |   t2|  1| u7|
    |   t2|  1| u8|
    |   t2|  1| u5|
    |   t3|  2| u1|
    |   t3|  2| u2|
    |   t3|  2| u7|
    |   t3|  2|u11|
    |   t4|  3| u8|
    |   t4|  3| u9|
    |   t5|  3| u9|
    |   t5|  3|u22|
    |   t5|  3|u11|
    |   t6|  3| u5|
    |   t6|  3|u11|
    +-----+---+---+
    only showing top 20 rows
    

    Then you can self-join the resulting data frame, filter out the unwanted rows (same day or label) and then proceed to the actual counting.

    I have the impression that your matrix will contain lots of zeros. Do you need a "physical" matrix or is a count per day and pair of labels sufficient?

    If you don't need a "physical" matrix, you can use regular aggregations (group by day and labels and then count):

    >>> df2 = df.withColumnRenamed('label', 'label2').withColumnRenamed('day', 'day2')
    >>> counts = df.join(df2, on='id') \
    ...     .where(df.label != df2.label2) \
    ...     .where(df.day != df2.day2) \
    ...     .groupby(df.day, df.label, df2.label2) \
    ...     .count() \
    ...     .orderBy(df.label, df2.label2)
    >>> 
    >>> counts.show()
    +---+-----+------+-----+                                                        
    |day|label|label2|count|
    +---+-----+------+-----+
    |  1|   t1|    t3|    2|
    |  1|   t1|    t6|    2|
    |  1|   t2|    t3|    2|
    |  1|   t2|    t4|    1|
    |  1|   t2|    t6|    1|
    |  2|   t3|    t1|    2|
    |  2|   t3|    t2|    2|
    |  2|   t3|    t5|    1|
    |  2|   t3|    t6|    1|
    |  3|   t4|    t2|    1|
    |  3|   t5|    t3|    1|
    |  3|   t6|    t1|    2|
    |  3|   t6|    t2|    1|
    |  3|   t6|    t3|    1|
    +---+-----+------+-----+
    
    >>> counts.explain()
    == Physical Plan ==
    AdaptiveSparkPlan isFinalPlan=false
    +- Sort [label#0 ASC NULLS FIRST, label2#493 ASC NULLS FIRST], true, 0
       +- Exchange rangepartitioning(label#0 ASC NULLS FIRST, label2#493 ASC NULLS FIRST, 200), ENSURE_REQUIREMENTS, [plan_id=2010]
          +- HashAggregate(keys=[day#2, label#0, label2#493], functions=[count(1)])
             +- Exchange hashpartitioning(day#2, label#0, label2#493, 200), ENSURE_REQUIREMENTS, [plan_id=2007]
                +- HashAggregate(keys=[day#2, label#0, label2#493], functions=[partial_count(1)])
                   +- Project [label#0, day#2, label2#493]
                      +- SortMergeJoin [id#7], [id#504], Inner, (NOT (label#0 = label2#493) AND NOT (day#2 = day2#497))
                         :- Sort [id#7 ASC NULLS FIRST], false, 0
                         :  +- Exchange hashpartitioning(id#7, 200), ENSURE_REQUIREMENTS, [plan_id=1999]
                         :     +- Generate explode(ids#1), [label#0, day#2], false, [id#7]
                         :        +- Filter (size(ids#1, true) > 0)
                         :           +- Scan ExistingRDD[label#0,ids#1,day#2]
                         +- Sort [id#504 ASC NULLS FIRST], false, 0
                            +- Exchange hashpartitioning(id#504, 200), ENSURE_REQUIREMENTS, [plan_id=2000]
                               +- Project [label#501 AS label2#493, day#503 AS day2#497, id#504]
                                  +- Generate explode(ids#502), [label#501, day#503], false, [id#504]
                                     +- Filter (size(ids#502, true) > 0)
                                        +- Scan ExistingRDD[label#501,ids#502,day#503]
    

    If you need "physical" matrices, you can work with MLlib as suggested in the first answer, or you can use pivot on label2 instead of using it as a grouping column:

    >>> counts_pivoted = df.join(df2, on='id') \
    ...     .where(df.label != df2.label2) \
    ...     .where(df.day != df2.day2) \
    ...     .groupby(df.day, df.label) \
    ...     .pivot('label2') \
    ...     .count() \
    ...     .drop('label2') \
    ...     .orderBy('label') \
    ...     .fillna(0)
    >>> counts_pivoted.show()                                                       
    +---+-----+---+---+---+---+---+---+                                             
    |day|label| t1| t2| t3| t4| t5| t6|
    +---+-----+---+---+---+---+---+---+
    |  1|   t1|  0|  0|  2|  0|  0|  2|
    |  1|   t2|  0|  0|  2|  1|  0|  1|
    |  2|   t3|  2|  2|  0|  0|  1|  1|
    |  3|   t4|  0|  1|  0|  0|  0|  0|
    |  3|   t5|  0|  0|  1|  0|  0|  0|
    |  3|   t6|  2|  1|  1|  0|  0|  0|
    +---+-----+---+---+---+---+---+---+
    
    >>> counts_pivoted.explain()
    == Physical Plan ==
    AdaptiveSparkPlan isFinalPlan=false
    +- Project [day#2, label#0, coalesce(t1#574L, 0) AS t1#616L, coalesce(t2#575L, 0) AS t2#617L, coalesce(t3#576L, 0) AS t3#618L, coalesce(t4#577L, 0) AS t4#619L, coalesce(t5#578L, 0) AS t5#620L, coalesce(t6#579L, 0) AS t6#621L]
       +- Sort [label#0 ASC NULLS FIRST], true, 0
          +- Exchange rangepartitioning(label#0 ASC NULLS FIRST, 200), ENSURE_REQUIREMENTS, [plan_id=2744]
             +- Project [day#2, label#0, __pivot_count(1) AS count AS `count(1) AS count`#573[0] AS t1#574L, __pivot_count(1) AS count AS `count(1) AS count`#573[1] AS t2#575L, __pivot_count(1) AS count AS `count(1) AS count`#573[2] AS t3#576L, __pivot_count(1) AS count AS `count(1) AS count`#573[3] AS t4#577L, __pivot_count(1) AS count AS `count(1) AS count`#573[4] AS t5#578L, __pivot_count(1) AS count AS `count(1) AS count`#573[5] AS t6#579L]
                +- HashAggregate(keys=[day#2, label#0], functions=[pivotfirst(label2#493, count(1) AS count#559L, t1, t2, t3, t4, t5, t6, 0, 0)])
                   +- Exchange hashpartitioning(day#2, label#0, 200), ENSURE_REQUIREMENTS, [plan_id=2740]
                      +- HashAggregate(keys=[day#2, label#0], functions=[partial_pivotfirst(label2#493, count(1) AS count#559L, t1, t2, t3, t4, t5, t6, 0, 0)])
                         +- HashAggregate(keys=[day#2, label#0, label2#493], functions=[count(1)])
                            +- Exchange hashpartitioning(day#2, label#0, label2#493, 200), ENSURE_REQUIREMENTS, [plan_id=2736]
                               +- HashAggregate(keys=[day#2, label#0, label2#493], functions=[partial_count(1)])
                                  +- Project [label#0, day#2, label2#493]
                                     +- SortMergeJoin [id#7], [id#543], Inner, (NOT (label#0 = label2#493) AND NOT (day#2 = day2#497))
                                        :- Sort [id#7 ASC NULLS FIRST], false, 0
                                        :  +- Exchange hashpartitioning(id#7, 200), ENSURE_REQUIREMENTS, [plan_id=2728]
                                        :     +- Generate explode(ids#1), [label#0, day#2], false, [id#7]
                                        :        +- Filter (size(ids#1, true) > 0)
                                        :           +- Scan ExistingRDD[label#0,ids#1,day#2]
                                        +- Sort [id#543 ASC NULLS FIRST], false, 0
                                           +- Exchange hashpartitioning(id#543, 200), ENSURE_REQUIREMENTS, [plan_id=2729]
                                              +- Project [label#540 AS label2#493, day#542 AS day2#497, id#543]
                                                 +- Generate explode(ids#541), [label#540, day#542], false, [id#543]
                                                    +- Filter (size(ids#541, true) > 0)
                                                       +- Scan ExistingRDD[label#540,ids#541,day#542]
    

    The values are not completely identical to your example, but I assume that werner's comment is correct.
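
    Coming back to the MLlib option mentioned above: if you do want an actual matrix object, one possible sketch (not necessarily what the first answer does; labels and label_index are helpers introduced here, and the day column is simply dropped, since each label belongs to a single day) is to feed the pair counts into a CoordinateMatrix:

    >>> from pyspark.mllib.linalg.distributed import CoordinateMatrix, MatrixEntry
    >>> # small driver-side mapping from label to matrix index (helper introduced here)
    >>> labels = sorted(r['label'] for r in id_lists_df.select('label').distinct().collect())
    >>> label_index = {lbl: i for i, lbl in enumerate(labels)}
    >>> # one sparse entry per (label, label2) pair; the day column is ignored
    >>> entries = counts.rdd.map(lambda r: MatrixEntry(
    ...     label_index[r['label']], label_index[r['label2']], r['count']))
    >>> matrix = CoordinateMatrix(entries, len(labels), len(labels))

    The resulting matrix stays distributed and sparse, which should help with the memory concern.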

    The pivot variant is probably less efficient than the plain aggregation. If the list of possible labels is known beforehand, you can save some time by passing it as the second argument of pivot.
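
    For example, a minimal sketch that passes the six labels from the sample data explicitly (substitute your own known label list), so that Spark does not have to run a separate job to collect the distinct values of label2 first:

    >>> known_labels = ['t1', 't2', 't3', 't4', 't5', 't6']
    >>> counts_pivoted = df.join(df2, on='id') \
    ...     .where(df.label != df2.label2) \
    ...     .where(df.day != df2.day2) \
    ...     .groupby(df.day, df.label) \
    ...     .pivot('label2', known_labels) \
    ...     .count() \
    ...     .orderBy('label') \
    ...     .fillna(0)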