Tags: sql, apache-spark, pyspark, spark-graphx, graphframes

How to do this transformation in SQL/Spark/GraphFrames


I have a table containing the following two columns:

Device-Id    Account-Id
d1           a1   
d2           a1
d1           a2
d2           a3
d3           a4
d3           a5 
d4           a6
d1           a4

Device-Id is the unique ID of the device on which my app is installed, and Account-Id is the ID of a user account. A user can have multiple devices and can create multiple accounts on the same device (e.g. device d1 has accounts a1, a2 and a4 set up).

I want to find the unique actual users, each represented by a UUID in a new column of the generated table. The transformation I'm looking for generates the following table:

Unique-User-Id    Devices-Used    Accounts-Used
uuid1             [d1, d2, d3]    [a1, a2, a3, a4, a5]   
uuid2             [d4]            [a6]

The idea behind the generated table is that an actual user, uuid1, has an account a1 set up on their devices d1 and d2. This means that both devices belong to uuid1, and all other accounts set up on d1 and d2 also map to uuid1. Similarly, d1 has an account a4 that is also set up on d3, so d3 is also uuid1's device and every account on it should map to uuid1.

How can I achieve this transformation in SQL/Spark/GraphFrames (by Databricks) when there can be millions of Device-Ids and Account-Ids?


Solution

  • You can try GraphFrame.connectedComponents. Add a prefix to all Device-IDs so that they can be separated from the Account-IDs in the post-processing step:

    from graphframes import GraphFrame
    from pyspark.sql.functions import collect_set, expr
    
    df = spark.createDataFrame([
             ("d1","a1"), ("d2","a1"), ("d1","a2"), ("d1","a4"),
             ("d2","a3"), ("d3","a4"), ("d3","a5"), ("d4","a6")  
    ], ["Device-Id","Account-Id"])
    
    # set a checkpoint directory, which GraphFrames' connectedComponents requires
    spark.sparkContext.setCheckpointDir("/tmp/111")
    
    # for testing purposes, use a small number of shuffle partitions
    spark.conf.set("spark.sql.shuffle.partitions", 2)
    
    # set up edges and vertices, add an underscore as prefix of Device-ID
    edges = df.withColumn('Device-Id', expr('concat("_", `Device-Id`)')).toDF('src', 'dst')
    vertices = edges.selectExpr('src as id').distinct().union(edges.select('dst').distinct())
    
    # set up the graph
    g = GraphFrame(vertices, edges)
    
    # compute the connected components and group resultset by component
    # and collect corresponding ids using collect_set(id)
    df1 = g.connectedComponents().groupby('component').agg(collect_set('id').alias('ids'))
    df1.show(truncate=False)
    +------------+-----------------------------------+
    |component   |ids                                |
    +------------+-----------------------------------+
    |309237645312|[a6, _d4]                          |
    |85899345920 |[_d1, a4, a1, _d3, a3, a5, a2, _d2]|
    +------------+-----------------------------------+
    
    # split the ids based on the prefix we predefined when creating edges.
    df1.selectExpr(
          'transform(filter(ids, x -> left(x,1) = "_"), y -> substr(y,2)) AS `Devices-Used`'
        , 'filter(ids, x -> left(x,1) != "_") AS `Accounts-Used`'
        , 'component AS `Unique-User-Id`'
    ).show()
    +------------+--------------------+--------------+
    |Devices-Used|       Accounts-Used|Unique-User-Id|
    +------------+--------------------+--------------+
    |[d1, d3, d2]|[a4, a1, a3, a5, a2]|   85899345920|
    |        [d4]|                [a6]|  309237645312|
    +------------+--------------------+--------------+
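
To see why connected components gives the grouping the question asks for, here is a minimal pure-Python sketch (no Spark needed) that runs a union-find over the same sample rows. The function name `group_users` and the `("device", ...)` / `("account", ...)` tags are my own illustration; the tags play the same role as the underscore prefix above. This is only meant to show the idea on small data, not to replace GraphFrames at the scale of millions of ids.

```python
from collections import defaultdict

rows = [
    ("d1", "a1"), ("d2", "a1"), ("d1", "a2"), ("d2", "a3"),
    ("d3", "a4"), ("d3", "a5"), ("d4", "a6"), ("d1", "a4"),
]

def group_users(pairs):
    """Group devices and accounts that are linked, directly or transitively."""
    parent = {}

    def find(x):
        # walk up to the root of x's set, then compress the path
        parent.setdefault(x, x)
        root = x
        while parent[root] != root:
            root = parent[root]
        while parent[x] != root:
            parent[x], x = root, parent[x]
        return root

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[rb] = ra

    # every (device, account) row is an edge between two tagged vertices
    for device, account in pairs:
        union(("device", device), ("account", account))

    # collect the members of each set, split by tag
    groups = defaultdict(lambda: {"devices": set(), "accounts": set()})
    for kind, name in list(parent):
        key = "devices" if kind == "device" else "accounts"
        groups[find((kind, name))][key].add(name)

    return sorted(
        (sorted(g["devices"]), sorted(g["accounts"])) for g in groups.values()
    )

result = group_users(rows)
print(result)
```

Each tuple in `result` corresponds to one actual user: the first list holds that user's devices and the second holds their accounts.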
    

    Edit: The method above is inefficient because it builds an unnecessarily large list of edges and vertices. Using a self-join to create the edge list should be a better choice (inspired by this post):

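    # pair up devices that share an account; the ">" filter keeps each
    # pair once and drops a device paired with itself
    edges = df.alias('d1').join(df.alias('d2'), ["Account-Id"]) \
        .filter("d1.`Device-Id` > d2.`Device-Id`") \
        .toDF("account", "src", "dst")
    edges.show()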
    +-------+---+---+
    |account|src|dst|
    +-------+---+---+
    |     a1| d2| d1|
    |     a4| d3| d1|
    +-------+---+---+
    
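    # one vertex row per (Device-Id, Account-Id) pair; the account rides along as a vertex attribute
    vertices = df.selectExpr('`Device-Id` as id', "`Account-Id` as acct_id")
    g = GraphFrame(vertices, edges)
    
    # devices linked through shared accounts fall into the same component
    df1 = g.connectedComponents() \
        .groupby('component') \
        .agg(
           collect_set('id').alias('Device-Ids'),
           collect_set('acct_id').alias('Account-Ids')
         )
    df1.show()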
    +---------+------------+--------------------+
    |component|  Device-Ids|         Account-Ids|
    +---------+------------+--------------------+
    |        0|[d1, d2, d3]|[a4, a1, a3, a5, a2]|
    |        1|        [d4]|                [a6]|
    +---------+------------+--------------------+
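
The question asks for a UUID per user, whereas the results above identify each user by the numeric component id. One possible way to bridge that, sketched here in plain Python, is to derive a name-based UUID (version 5) from the component id. The namespace string and the helper name `component_to_uuid` are hypothetical choices of mine, and the resulting UUID is only as stable as the component id it is derived from. In Spark itself you could wrap such a helper in a UDF, or, if random ids are acceptable, try the built-in SQL function `uuid()` on the grouped result.

```python
import uuid

# Hypothetical namespace for this dataset; any fixed UUID works as long
# as it stays the same between runs.
USER_NAMESPACE = uuid.uuid5(uuid.NAMESPACE_DNS, "users.example.invalid")

def component_to_uuid(component_id):
    """Map a component id to a deterministic version-5 UUID string."""
    return str(uuid.uuid5(USER_NAMESPACE, str(component_id)))

# component ids taken from the first result table above
for component in (85899345920, 309237645312):
    print(component, component_to_uuid(component))
```

The same component id always maps to the same UUID, so re-running the mapping does not reshuffle user ids as long as the component ids themselves do not change.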