I am trying to flatten a hierarchy table with a fixed number of levels using PySpark 3.0 (Databricks).
Input Data:
sc.parallelize([
    [0, None, 'root', '?'],
    [1, 0, 'a', 'aaaaa'],
    [2, 1, 'b', 'bbbbb'],
    [3, 1, 'c', 'ccccc'],
    [4, 3, 'd', 'ddddd'],
    [5, 4, 'e', 'eeeee']
]).toDF(("id", "parent", "name", "attribute")).createOrReplaceTempView('df')
Required Output:
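One row per node (down to level 3), carrying the full root-to-node path; this is exactly what the query below produces:

+----------+-------+-------+---------+--------------+-------+---------+--------------+-------+---------+--------------+
|LEAF_LEVEL|lvl0_id|lvl1_id|lvl1_name|lvl1_attribute|lvl2_id|lvl2_name|lvl2_attribute|lvl3_id|lvl3_name|lvl3_attribute|
+----------+-------+-------+---------+--------------+-------+---------+--------------+-------+---------+--------------+
|         0|      0|   null|     null|          null|   null|     null|          null|   null|     null|          null|
|         1|      0|      1|        a|         aaaaa|   null|     null|          null|   null|     null|          null|
|         2|      0|      1|        a|         aaaaa|      2|        b|         bbbbb|   null|     null|          null|
|         2|      0|      1|        a|         aaaaa|      3|        c|         ccccc|   null|     null|          null|
|         3|      0|      1|        a|         aaaaa|      3|        c|         ccccc|      4|        d|         ddddd|
+----------+-------+-------+---------+--------------+-------+---------+--------------+-------+---------+--------------+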
I have produced the required output using chained CTEs, but there must be a more concise way of coding this; any help appreciated:
WITH lvl0 AS (
    SELECT 0 AS LEAF_LEVEL, df.id, df.parent, df.name, df.attribute AS lvl0_attribute
    FROM df
    WHERE df.parent IS NULL
), lvl1 AS (
    SELECT 1 AS LEAF_LEVEL, df.id, df.parent, df.name, df.attribute
    FROM df
    JOIN lvl0 ON lvl0.id = df.parent
), lvl2 AS (
    SELECT 2 AS LEAF_LEVEL, df.id, df.parent, df.name, df.attribute
    FROM df
    JOIN lvl1 ON lvl1.id = df.parent
), lvl3 AS (
    SELECT 3 AS LEAF_LEVEL, df.id, df.parent, df.name, df.attribute
    FROM df
    JOIN lvl2 ON lvl2.id = df.parent
)
SELECT lvl0.LEAF_LEVEL, lvl0.id AS lvl0_id,
       NULL AS lvl1_id, NULL AS lvl1_name, NULL AS lvl1_attribute,
       NULL AS lvl2_id, NULL AS lvl2_name, NULL AS lvl2_attribute,
       NULL AS lvl3_id, NULL AS lvl3_name, NULL AS lvl3_attribute
FROM lvl0
UNION
SELECT lvl1.LEAF_LEVEL, lvl0.id AS lvl0_id,
       lvl1.id AS lvl1_id, lvl1.name AS lvl1_name, lvl1.attribute AS lvl1_attribute,
       NULL AS lvl2_id, NULL AS lvl2_name, NULL AS lvl2_attribute,
       NULL AS lvl3_id, NULL AS lvl3_name, NULL AS lvl3_attribute
FROM lvl0
INNER JOIN lvl1 ON lvl1.parent = lvl0.id
UNION
SELECT lvl2.LEAF_LEVEL, lvl0.id AS lvl0_id,
       lvl1.id AS lvl1_id, lvl1.name AS lvl1_name, lvl1.attribute AS lvl1_attribute,
       lvl2.id AS lvl2_id, lvl2.name AS lvl2_name, lvl2.attribute AS lvl2_attribute,
       NULL AS lvl3_id, NULL AS lvl3_name, NULL AS lvl3_attribute
FROM lvl0
INNER JOIN lvl1 ON lvl1.parent = lvl0.id
INNER JOIN lvl2 ON lvl2.parent = lvl1.id
UNION
SELECT lvl3.LEAF_LEVEL, lvl0.id AS lvl0_id,
       lvl1.id AS lvl1_id, lvl1.name AS lvl1_name, lvl1.attribute AS lvl1_attribute,
       lvl2.id AS lvl2_id, lvl2.name AS lvl2_name, lvl2.attribute AS lvl2_attribute,
       lvl3.id AS lvl3_id, lvl3.name AS lvl3_name, lvl3.attribute AS lvl3_attribute
FROM lvl0
INNER JOIN lvl1 ON lvl1.parent = lvl0.id
INNER JOIN lvl2 ON lvl2.parent = lvl1.id
INNER JOIN lvl3 ON lvl3.parent = lvl2.id
ORDER BY 1, 2, 3, 4, 5, 6 ASC
Using PySpark you can write this in a more generic way, so it will be more concise. For each level, join the data from the next level and union the result with the current level's rows padded with extra NULL columns. You can do this with the add_next_level function defined below, folded over the levels with reduce from functools.
from functools import reduce

from pyspark.sql import Row
from pyspark.sql.functions import coalesce, col, lit

data = sc.parallelize([
    [0, None, 'root', '?'],
    [1, 0, 'a', 'aaaaa'],
    [2, 1, 'b', 'bbbbb'],
    [3, 1, 'c', 'ccccc'],
    [4, 3, 'd', 'ddddd'],
    [5, 4, 'e', 'eeeee']
]).toDF(("id", "parent", "name", "attribute"))
def add_next_level(df, level):
    # Extend every path by one level: the left join keeps paths with no
    # children at this level (their new columns stay NULL), while paths
    # with children are replaced by one row per child.
    return df.join(
        data.select(
            lit(level).alias('next_level'),
            col('parent'),
            col('id').alias(f'lvl{level}_id'),
            col('name').alias(f'lvl{level}_name'),
            col('attribute').alias(f'lvl{level}_attribute')
        ), col(f'lvl{level - 1}_id') == col('parent'), 'left') \
        .withColumn('LEAF_LEVEL', coalesce(col('next_level'), col('LEAF_LEVEL'))) \
        .drop('parent', 'next_level') \
        .union(  # re-add the just-extended paths, truncated at the current level
            df.join(data, col(f'lvl{level - 1}_id') == col('parent'), 'left_semi')
              .where(col('LEAF_LEVEL') == lit(level - 1))
              .withColumn(f'lvl{level}_id', lit(None))
              .withColumn(f'lvl{level}_name', lit(None))
              .withColumn(f'lvl{level}_attribute', lit(None)))
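The union branch mirrors the separate UNION arms of the CTE version: the left join replaces every path that has children with its extended copies, so the left_semi join re-adds exactly those paths (rows whose last node is a parent in data), and the LEAF_LEVEL filter restricts this to rows produced in the previous step. Shorter paths that already ended earlier survive the left join unchanged, because their previous-level id is NULL and matches nothing.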
reduce(
    add_next_level,
    [1, 2, 3],  # levels to add below the root
    spark.createDataFrame([Row(0, 0)], ['LEAF_LEVEL', 'lvl0_id'])  # seed: the root path
).sort('lvl1_id', 'lvl2_id', 'lvl3_id').show()
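This prints the same five path rows as the required output above. As a minimal sketch of a further generalization (assuming the same data frame and add_next_level as defined above), you can derive the seed row from the actual root and parameterize the depth instead of hardcoding id 0 and [1, 2, 3]; max_level here is a hypothetical name for the fixed number of levels:

max_level = 3  # assumed fixed depth of the hierarchy
# Build the seed path from whichever row has no parent, rather than assuming id 0.
root = data.where(col('parent').isNull()) \
    .select(lit(0).alias('LEAF_LEVEL'), col('id').alias('lvl0_id'))
reduce(add_next_level, range(1, max_level + 1), root) \
    .sort(*[f'lvl{i}_id' for i in range(1, max_level + 1)]).show()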