python, apache-spark, pyspark, rdd

Creating combination and sum of value lists with existing key - Pyspark


My question is similar to the one given here; however, I have an additional field whose values I would like to sum. My RDD looks like this (shown as a data frame):

+----------+----------------+----------------+
|    c1    |        c2      |      val       |
+----------+----------------+----------------+
|        t1|         [a, b] |        [11, 12]|
|        t2|      [a, b, c] |    [13, 14, 15]|
|        t3|   [a, b, c, d] |[16, 17, 18, 19]|
+----------+----------------+----------------+

I would like to get one row per pair of items in c2, with the sum of the matching values from val:

        +----------+----------------+----------------+
        |    c1    |        c2      |     sum(val)   |
        +----------+----------------+----------------+
        |        t1|         [a, b] |        23      |
        |        t2|         [a, b] |        27      |
        |        t2|         [a, c] |        28      |
        |        t2|         [b, c] |        29      |
        |        t3|         [a, b] |        33      |
        |        t3|         [a, c] |        34      |
        |        t3|         [a, d] |        35      |
        |        t3|         [b, c] |        35      |
        |        t3|         [b, d] |        36      |
        |        t3|         [c, d] |        37      |
        +----------+----------------+----------------+
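Each input row with n items should yield C(n, 2) = n(n-1)/2 output rows, one per unordered pair of positions. A quick plain-Python check (no Spark involved) of the row counts implied by the sample data above:

```python
from math import comb  # Python 3.8+

# c2 lists copied from the sample table above
rows = {"t1": ["a", "b"], "t2": ["a", "b", "c"], "t3": ["a", "b", "c", "d"]}

# One output row per unordered pair of positions in each list
counts = {k: comb(len(v), 2) for k, v in rows.items()}

print(counts)                # {'t1': 1, 't2': 3, 't3': 6}
print(sum(counts.values()))  # 10
```

The total of 10 matches the number of rows in the expected output table.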

With the following code I get the first two columns:

    import itertools

    def combinations(row):
        l = row[1]  # c2: items to pair up
        k = row[0]  # c1: row key
        m = row[2]  # val: not used yet
        return [(k, v) for v in itertools.combinations(l, 2)]

    a.map(combinations).flatMap(lambda x: x).take(5)
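In Spark, rdd.map(f).flatMap(lambda x: x) yields the same elements as rdd.flatMap(f) when f returns a list, so the extra step can be dropped. The sketch below mimics both forms locally with plain Python on two of the sample rows (as tuples, no SparkContext), only to illustrate that equivalence:

```python
import itertools

def combinations(row):
    k, l = row[0], row[1]
    return [(k, v) for v in itertools.combinations(l, 2)]

rows = [
    ("t1", ["a", "b"], [11, 12]),
    ("t2", ["a", "b", "c"], [13, 14, 15]),
]

# Local stand-in for rdd.map(combinations).flatMap(lambda x: x)
mapped_then_flattened = list(itertools.chain.from_iterable(map(combinations, rows)))

# Local stand-in for rdd.flatMap(combinations)
flat_mapped = [out for row in rows for out in combinations(row)]

print(flat_mapped)
# [('t1', ('a', 'b')), ('t2', ('a', 'b')), ('t2', ('a', 'c')), ('t2', ('b', 'c'))]
```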

When I try to add the third column with the code below, I get more rows than expected:

    def combinations(row):
        l = row[1]
        k = row[0]
        m = row[2]
        return [(k, v, x) for v in itertools.combinations(l, 2) for x in map(sum, itertools.combinations(m, 2))]

    a.map(combinations).flatMap(lambda x: x).take(5)

I would appreciate any help, thanks.
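The extra rows in the second attempt come from the two for clauses in the comprehension: they are nested, so every key pair is combined with every value sum (a cross product) instead of with its own sum. For t2 that gives 3 x 3 = 9 rows instead of 3. A plain-Python illustration on the t2 row, assuming c2 and val always have the same length so the two combination sequences line up position by position:

```python
import itertools

l = ["a", "b", "c"]  # c2 for t2
m = [13, 14, 15]     # val for t2

# Nested for clauses: every key pair meets every sum (cross product)
nested = [(v, x)
          for v in itertools.combinations(l, 2)
          for x in map(sum, itertools.combinations(m, 2))]

# zip walks both combination sequences in lockstep, so each pair keeps its own sum
paired = list(zip(itertools.combinations(l, 2), map(sum, itertools.combinations(m, 2))))

print(len(nested))  # 9
print(paired)       # [(('a', 'b'), 27), (('a', 'c'), 28), (('b', 'c'), 29)]
```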


Solution

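Try the following. Zip the two itertools.combinations iterators so that each key pair travels with the sum of its matching value pair. Nesting two for clauses, as in the question, pairs every key combination with every value sum and multiplies the number of rows:

    import itertools

    a = sc.parallelize([
        (1, [1,2,3,4], [11,12,13,14]),
        (2, [3,4,5,6], [15,16,17,18]),
        (3, [-1,2,3,4], [19,20,21,22])
      ])

    def combinations(row):
        l = row[1]  # items to pair up (c2)
        k = row[0]  # row key (c1)
        m = row[2]  # values to sum (val), same length as l
        # combinations() emits pairs in the same positional order for both lists,
        # so zip() lines each key pair up with its own value pair
        return [(k, v, sum(x))
                for v, x in zip(itertools.combinations(l, 2), itertools.combinations(m, 2))]

    a.flatMap(combinations).take(5)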
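A variant that does not rely on the two combination sequences staying in lockstep is to zip each item with its value first and then take combinations of those (item, value) pairs. The sketch below uses a hypothetical helper name, pair_sums, and runs it locally on the question's sample rows as plain tuples; on a real RDD built with sc.parallelize(rows), the same function would be passed to rdd.flatMap. Note that zip silently truncates if c2 and val differ in length.

```python
import itertools

def pair_sums(row):
    """Hypothetical helper: (c1, [item1, item2], val1 + val2) for every pair of positions."""
    k, keys, vals = row
    return [
        (k, [k1, k2], v1 + v2)
        for (k1, v1), (k2, v2) in itertools.combinations(zip(keys, vals), 2)
    ]

rows = [
    ("t1", ["a", "b"], [11, 12]),
    ("t2", ["a", "b", "c"], [13, 14, 15]),
    ("t3", ["a", "b", "c", "d"], [16, 17, 18, 19]),
]

# Local stand-in for sc.parallelize(rows).flatMap(pair_sums).collect()
result = [out for row in rows for out in pair_sums(row)]
for r in result:
    print(r)
```

This produces the ten rows of the expected output table, from ('t1', ['a', 'b'], 23) through ('t3', ['c', 'd'], 37).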