
Expand array Column of PySpark DataFrame


I am having trouble transforming a DataFrame into a GraphFrame using the data below. Consider a column of Authors in a DataFrame containing an array of strings, like the one below:

+-----------+------------------------------------+
|ArticlePMID|               Authors              |
+-----------+------------------------------------+
|    PMID1  |['Author 1', 'Author 2', 'Author 3']|
|    PMID2  |['Author 4', 'Author 5']            |
+-----------+------------------------------------+

Each row lists the authors who collaborated on the same paper. I want to expand the second column into a new DataFrame with the following structure:

+---------------+---------------+ 
| Collaborator1 | Collaborator2 |
+---------------+---------------+ 
| 'Author 1'    | 'Author 2'    |
| 'Author 1'    | 'Author 3'    |
| 'Author 2'    | 'Author 3'    |
| 'Author 4'    | 'Author 5'    |
+---------------+---------------+
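Put differently, the target table lists, for each article, every unordered pair of authors taken in array order. That is exactly what Python's `itertools.combinations(authors, 2)` yields. A minimal plain-Python sketch of the specification (not Spark code), assuming the two sample rows are held in a hypothetical `articles` dict:

```python
from itertools import combinations

# Hypothetical in-memory copy of the sample input: ArticlePMID -> Authors array
articles = {
    "PMID1": ["Author 1", "Author 2", "Author 3"],
    "PMID2": ["Author 4", "Author 5"],
}

# Every unordered pair per article, keeping array order within each pair
pairs = [
    (a, b)
    for authors in articles.values()
    for a, b in combinations(authors, 2)
]

for collaborator1, collaborator2 in pairs:
    print(collaborator1, "|", collaborator2)
```

An article with n authors contributes n(n-1)/2 rows, so the three-author paper gives three pairs and the two-author paper gives one.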

I tried to use the explode function, but that only expands the array into a single column of authors and I lose the collaboration network.

Can someone please tell me how to get around this?


Solution

  • As long as you're using PySpark 2.1 or above, you can use posexplode followed by a self-join:

    First explode with the position in the array:

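    from pyspark.sql import SparkSession
    from pyspark.sql.functions import posexplode

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [("PMID1", ["Author 1", "Author 2", "Author 3"]),
         ("PMID2", ["Author 4", "Author 5"])],
        ["ArticlePMID", "Authors"],
    )

    # One output row per array element, with its 0-based index in "pos"
    exploded = df.select("*", posexplode("Authors").alias("pos", "Author"))
    exploded.show()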
    #+-----------+--------------------+---+--------+
    #|ArticlePMID|             Authors|pos|  Author|
    #+-----------+--------------------+---+--------+
    #|      PMID1|[Author 1, Author...|  0|Author 1|
    #|      PMID1|[Author 1, Author...|  1|Author 2|
    #|      PMID1|[Author 1, Author...|  2|Author 3|
    #|      PMID2|[Author 4, Author 5]|  0|Author 4|
    #|      PMID2|[Author 4, Author 5]|  1|Author 5|
    #+-----------+--------------------+---+--------+
    

    Now join the exploded DataFrame to itself on the ArticlePMID column and keep only the rows where the left table's pos is less than the right table's pos.

    # Self-join on ArticlePMID, then keep each pair once (left pos < right pos)
    exploded.alias("l").join(exploded.alias("r"), on="ArticlePMID", how="inner")\
        .where("l.pos < r.pos")\
        .selectExpr("l.Author AS Collaborator1", "r.Author AS Collaborator2")\
        .show()
    #+-------------+-------------+
    #|Collaborator1|Collaborator2|
    #+-------------+-------------+
    #|     Author 1|     Author 2|
    #|     Author 1|     Author 3|
    #|     Author 2|     Author 3|
    #|     Author 4|     Author 5|
    #+-------------+-------------+
    

    Filtering on pos ensures each pair of authors is listed only once rather than in both orders, and it also drops rows that pair an author with themselves.
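To see why the pos filter yields exactly the wanted pairs, here is a plain-Python model of the two steps (not Spark code): `enumerate` stands in for `posexplode`, and a nested loop over rows sharing the same ArticlePMID stands in for the inner self-join. The names `rows`, `exploded` and `self_join` are illustrative assumptions.

```python
from itertools import combinations

# Hypothetical copy of the sample input rows: (ArticlePMID, Authors)
rows = [
    ("PMID1", ["Author 1", "Author 2", "Author 3"]),
    ("PMID2", ["Author 4", "Author 5"]),
]

# Model of posexplode: one (ArticlePMID, pos, Author) tuple per array element
exploded = [
    (pmid, pos, author)
    for pmid, authors in rows
    for pos, author in enumerate(authors)
]

def self_join(keep):
    """Model of an inner self-join on ArticlePMID, filtered by keep(l_pos, r_pos)."""
    return [
        (l_author, r_author)
        for l_pmid, l_pos, l_author in exploded
        for r_pmid, r_pos, r_author in exploded
        if l_pmid == r_pmid and keep(l_pos, r_pos)
    ]

unfiltered = self_join(lambda lp, rp: True)   # self-pairs and both orders
filtered = self_join(lambda lp, rp: lp < rp)  # the "l.pos < r.pos" condition

print(len(unfiltered), len(filtered))
# The filtered join matches the unordered pairs from itertools.combinations
print(filtered == [p for _, authors in rows for p in combinations(authors, 2)])
```

Without the filter, the three-author article produces 9 rows (3 self-pairs plus each real pair in both orders); `lp < rp` keeps 3. In Spark itself the row order of a join result is not guaranteed, so `show()` may list the pairs in a different order.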