Tags: python, pandas, apache-spark, pyspark, apache-spark-sql

How to flatten a nested, mixed array of structs in PySpark?


Consider the following schema in a PySpark dataframe df:

root
 |-- mydoc: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- Driver: long (nullable = true)
 |    |    |-- Information: array (nullable = true)
 |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |-- Name: string (nullable = true)
 |    |    |    |    |-- Id: long (nullable = true)
 |    |    |    |    |-- Car: string (nullable = true)
 |    |    |    |    |-- Age: long (nullable = true)

I would like to flatten the Information array of structs so that it appears in my PySpark dataframe as

flatName     flatId  flatCar         flatAge
"john,mike"  "1,2"   "ferrari,polo"  "12,24"

As you can see, I simply want to express each element as a string delimited by ,. I tried various tricks such as

df.select(array_join(df.mydoc.Information.Name,','))

With no success. Any ideas?

Thanks!


Solution

  • You can use a combination of explode, ".*" and groupBy; here's the code:

    from pyspark.sql.functions import explode, concat_ws, collect_list

    df = (
        df.withColumn("mydoc", explode("mydoc"))
        .select("mydoc.*")
        .withColumn("Information", explode("Information"))
        .select("Information.*")
        .groupby()
        .agg(
            concat_ws(",", collect_list("Name")).alias("flatName"),
            concat_ws(",", collect_list("Age")).alias("flatAge"),
            concat_ws(",", collect_list("Car")).alias("flatCar"),
            concat_ws(",", collect_list("Id")).alias("flatId"),
        )
    )
    df.show(truncate=False)
    

    Result:

    +---------+-------+------------+------+
    |flatName |flatAge|flatCar     |flatId|
    +---------+-------+------------+------+
    |john,mike|12,24  |ferrari,polo|1,2   |
    +---------+-------+------------+------+
    

    UPDATE:

    If you want to apply this to all columns, you can build the aggregations with a list comprehension:

    from pyspark.sql.functions import explode, concat_ws, collect_list

    df = (
        df.withColumn("mydoc", explode("mydoc"))
        .select("mydoc.*")
        .withColumn("Information", explode("Information"))
        .select("Information.*")
    )
    df = df.groupby().agg(
        *[concat_ws(",", collect_list(col)).alias(f"flat{col}") for col in df.columns]
    )