python, pyspark, apache-spark-sql

Removing columns in a nested struct in a spark dataframe using PySpark (details in text)


I know that I've asked a similar question here, but that was for row filtering. This time I am trying to drop columns instead. I tried to implement higher-order functions such as FILTER for a while but could not get them to work. I think what I need is a SELECT higher-order function, but it doesn't seem to exist.

I am using PySpark and have a dataframe object df. This is what the output of df.printSchema() looks like:

root
 |-- M_MRN: string (nullable = true)
 |-- measurements: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- Observation_ID: string (nullable = true)
 |    |    |-- Observation_Name: string (nullable = true)
 |    |    |-- Observation_Result: string (nullable = true)

I would like to keep only the 'Observation_ID' and 'Observation_Result' fields in 'measurements'. Currently, when I run df.select('measurements').take(2) I get

[Row(measurements=[Row(Observation_ID='5', Observation_Name='ABC', Observation_Result='108/72'),
                   Row(Observation_ID='11', Observation_Name='ABC', Observation_Result='70'),
                   Row(Observation_ID='10', Observation_Name='ABC', Observation_Result='73.029'),
                   Row(Observation_ID='14', Observation_Name='XYZ', Observation_Result='23.1')]),
 Row(measurements=[Row(Observation_ID='2', Observation_Name='ZZZ', Observation_Result='3/4'),
                   Row(Observation_ID='5', Observation_Name='ABC', Observation_Result='7')])]

After the above filtering, I would like df.select('measurements').take(2) to return

[Row(measurements=[Row(Observation_ID='5', Observation_Result='108/72'),
                   Row(Observation_ID='11', Observation_Result='70'),
                   Row(Observation_ID='10', Observation_Result='73.029'),
                   Row(Observation_ID='14', Observation_Result='23.1')]),
 Row(measurements=[Row(Observation_ID='2', Observation_Result='3/4'),
                   Row(Observation_ID='5', Observation_Result='7')])]

Is there a way to do this in pyspark?


Solution

  • You could use the higher-order function transform to select the desired fields and put them in a struct (a Column-API alternative is sketched after the output below).

    from pyspark.sql import functions as F
    df.withColumn("measurements", F.expr("""
        transform(measurements,
                  x -> struct(x.Observation_ID as Observation_ID,
                              x.Observation_Result as Observation_Result))""")).printSchema()
    
    # root
    #  |-- M_MRN: string (nullable = true)
    #  |-- measurements: array (nullable = true)
    #  |    |-- element: struct (containsNull = false)
    #  |    |    |-- Observation_ID: string (nullable = true)
    #  |    |    |-- Observation_Result: string (nullable = true)
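
  • If you are on Spark 3.1 or later, the same thing can also be written with the Column-based transform API instead of a SQL expression string, and dropFields gives a shorter variant when you only need to remove a single field. A minimal sketch, assuming the schema from the question:

    from pyspark.sql import functions as F

    # Rebuild each array element, keeping only the wanted fields
    df2 = df.withColumn(
        "measurements",
        F.transform(
            "measurements",
            lambda x: F.struct(
                x["Observation_ID"].alias("Observation_ID"),
                x["Observation_Result"].alias("Observation_Result"),
            ),
        ),
    )

    # Or drop the unwanted field from each element instead
    df3 = df.withColumn(
        "measurements",
        F.transform("measurements", lambda x: x.dropFields("Observation_Name")),
    )

    df2.select("measurements").take(2)  # should match the desired output above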