Search code examples
apache-spark-sql, common-table-expression, hierarchy, databricks

Flatten hierarchy table using PySpark


I am trying to flatten a hierarchy table with a fixed number of levels using PySpark 3.0 (Databricks).

Input Data:

# Build the sample hierarchy rows (id, parent, name, attribute) and register
# them as the temporary view 'df' that the SQL query reads from.
sc.parallelize(
    [
        [0, None, 'root', '?'],
        [1, 0, 'a', 'aaaaa'],
        [2, 1, 'b', 'bbbbb'],
        [3, 1, 'c', 'ccccc'],
        [4, 3, 'd', 'ddddd'],
        [5, 4, 'e', 'eeeee'],
    ]
).toDF(
    ("id", "parent", "name", "attribute")
).createOrReplaceTempView('df')

input

Required Output:

output

I have produced the required output using CTEs, but there must be a more concise way of coding this. Any help is appreciated:

-- Flatten the hierarchy stored in `df` (id, parent, name, attribute) into one
-- row per node for levels 0 through 3. Each row carries the ids/names/attributes
-- of the node's ancestors from level 1 downwards; columns for levels deeper
-- than the node itself are NULL. LEAF_LEVEL is the depth of the node the row
-- describes.
WITH lvl0 AS (
    -- Level 0: root nodes, i.e. rows without a parent.
    SELECT
        0 AS LEAF_LEVEL,
        df.id,
        df.parent,
        df.name,
        df.attribute
    FROM df
    WHERE df.parent IS NULL
),
lvl1 AS (
    -- Level 1: direct children of the level-0 nodes.
    SELECT
        1 AS LEAF_LEVEL,
        df.id,
        df.parent,
        df.name,
        df.attribute
    FROM df
    INNER JOIN lvl0 ON lvl0.id = df.parent
),
lvl2 AS (
    -- Level 2: direct children of the level-1 nodes.
    SELECT
        2 AS LEAF_LEVEL,
        df.id,
        df.parent,
        df.name,
        df.attribute
    FROM df
    INNER JOIN lvl1 ON lvl1.id = df.parent
),
lvl3 AS (
    -- Level 3: direct children of the level-2 nodes.
    SELECT
        3 AS LEAF_LEVEL,
        df.id,
        df.parent,
        df.name,
        df.attribute
    FROM df
    INNER JOIN lvl2 ON lvl2.id = df.parent
)
-- Rows describing level-0 nodes: every deeper level is NULL.
SELECT
    lvl0.LEAF_LEVEL,
    lvl0.id AS lvl0_id,
    NULL AS lvl1_id,
    NULL AS lvl1_name,
    NULL AS lvl1_attribute,
    NULL AS lvl2_id,
    NULL AS lvl2_name,
    NULL AS lvl2_attribute,
    NULL AS lvl3_id,
    NULL AS lvl3_name,
    NULL AS lvl3_attribute
FROM lvl0

-- UNION (not UNION ALL) is kept on purpose: the branches differ in LEAF_LEVEL,
-- but UNION also collapses duplicates should `df` itself contain repeated rows.
UNION

-- Rows describing level-1 nodes.
SELECT
    lvl1.LEAF_LEVEL,
    lvl0.id AS lvl0_id,
    lvl1.id AS lvl1_id,
    lvl1.name AS lvl1_name,
    lvl1.attribute AS lvl1_attribute,
    NULL AS lvl2_id,
    NULL AS lvl2_name,
    NULL AS lvl2_attribute,
    NULL AS lvl3_id,
    NULL AS lvl3_name,
    NULL AS lvl3_attribute
FROM lvl0
INNER JOIN lvl1 ON lvl1.parent = lvl0.id

UNION

-- Rows describing level-2 nodes.
SELECT
    lvl2.LEAF_LEVEL,
    lvl0.id AS lvl0_id,
    lvl1.id AS lvl1_id,
    lvl1.name AS lvl1_name,
    lvl1.attribute AS lvl1_attribute,
    lvl2.id AS lvl2_id,
    lvl2.name AS lvl2_name,
    lvl2.attribute AS lvl2_attribute,
    NULL AS lvl3_id,
    NULL AS lvl3_name,
    NULL AS lvl3_attribute
FROM lvl0
INNER JOIN lvl1 ON lvl1.parent = lvl0.id
INNER JOIN lvl2 ON lvl2.parent = lvl1.id

UNION

-- Rows describing level-3 nodes.
SELECT
    lvl3.LEAF_LEVEL,
    lvl0.id AS lvl0_id,
    lvl1.id AS lvl1_id,
    lvl1.name AS lvl1_name,
    lvl1.attribute AS lvl1_attribute,
    lvl2.id AS lvl2_id,
    lvl2.name AS lvl2_name,
    lvl2.attribute AS lvl2_attribute,
    lvl3.id AS lvl3_id,
    lvl3.name AS lvl3_name,
    lvl3.attribute AS lvl3_attribute
FROM lvl0
INNER JOIN lvl1 ON lvl1.parent = lvl0.id
INNER JOIN lvl2 ON lvl2.parent = lvl1.id
INNER JOIN lvl3 ON lvl3.parent = lvl2.id
-- Sorts the combined UNION result; these are the first six output columns,
-- all ascending (the default).
ORDER BY
    LEAF_LEVEL,
    lvl0_id,
    lvl1_id,
    lvl1_name,
    lvl1_attribute,
    lvl2_id

Solution

  • Using PySpark, you can write this in a more generic way, which also makes it more concise.

    For each level, join the data from the next level and union the result with the current level's data, padded with the extra columns. You can do this with the add_next_level function defined below and reduce from functools.

    # Sample hierarchy as a DataFrame with columns (id, parent, name, attribute);
    # the row with parent None is the root.
    data = sc.parallelize(
        [
            [0, None, 'root', '?'],
            [1, 0, 'a', 'aaaaa'],
            [2, 1, 'b', 'bbbbb'],
            [3, 1, 'c', 'ccccc'],
            [4, 3, 'd', 'ddddd'],
            [5, 4, 'e', 'eeeee'],
        ]
    ).toDF(
        ("id", "parent", "name", "attribute")
    )
    
    
    def add_next_level(df, level):
        """Extend the flattened frame `df` by one hierarchy level.

        Rows of `df` are left-joined to the rows of the module-level `data`
        frame whose `parent` equals `lvl{level - 1}_id`, adding the columns
        `lvl{level}_id`, `lvl{level}_name` and `lvl{level}_attribute`. Where a
        child matched, LEAF_LEVEL becomes `level`; otherwise it is left as is.

        The result is unioned with the rows of `df` that sit at
        `level - 1` and have at least one child, padded with null columns for
        the new level, so those parents keep a row of their own.
        """
        prev_level_matches_parent = col(f'lvl{level - 1}_id') == col('parent')

        # Candidate children, tagged with the level they would occupy.
        children = data.select(
            lit(level).alias('next_level'),
            col('parent'),
            col('id').alias(f'lvl{level}_id'),
            col('name').alias(f'lvl{level}_name'),
            col('attribute').alias(f'lvl{level}_attribute'),
        )

        # Left join keeps rows without children; coalesce only bumps
        # LEAF_LEVEL when a child was actually found.
        extended = df.join(children, prev_level_matches_parent, 'left')
        extended = extended.withColumn(
            'LEAF_LEVEL', coalesce(col('next_level'), col('LEAF_LEVEL'))
        )
        extended = extended.drop('parent', 'next_level')

        # Parents at the previous level that do have children: the semi join
        # filters them without adding columns, then the new level is padded.
        parents = df.join(data, prev_level_matches_parent, 'left_semi')
        parents = parents.where(col('LEAF_LEVEL') == lit(level - 1))
        parents = parents.withColumn(f'lvl{level}_id', lit(None))
        parents = parents.withColumn(f'lvl{level}_name', lit(None))
        parents = parents.withColumn(f'lvl{level}_attribute', lit(None))

        return extended.union(parents)
    
    # Start from a single row (LEAF_LEVEL=0, lvl0_id=0), fold levels 1 to 3
    # onto it with add_next_level, then print the result sorted by level ids.
    seed = spark.createDataFrame([Row(0, 0)], ['LEAF_LEVEL', 'lvl0_id'])
    flattened = reduce(add_next_level, [1, 2, 3], seed)
    flattened.sort('lvl1_id', 'lvl2_id', 'lvl3_id').show()