Tags: dataframe, scala, apache-spark, apache-spark-sql, aws-glue

Extract Nested Array of Struct with If Else Logic


I need to derive two new columns, beaver_id and llama_id, from the schema below. Some if/else logic has to be applied to an array of structs. The desired end result is a CSV output. What is the best approach for this?

Schema:

root
 |-- Animal: struct (nullable = true)
 |    |-- Species: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- name: string (nullable = true)
 |    |    |    |-- color: string (nullable = true)
 |    |    |    |-- unique_id: string (nullable = true)

Pseudo Code:

If name == "Beaver" 
   then get unique_id and put in dataframe column "beaver_id"
     else
   null in column "beaver_id"

If name == "Llama"
   then get unique_id and put in dataframe column "llama_id"
     else
   null in column "llama_id"

If array of names does not contain "Llama" or "Beaver"
   then null for both "beaver_id" and "llama_id"

Currently, I use the DataFrame select function to pick elements out of the input (Parquet) and build the CSV output. I extract many other elements this way besides the ones in this question.

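import org.apache.spark.sql.functions.col

// A field selected from an array of structs yields an array of that field's values
val select_df = raw_df.select(
  col("Animal.Species").getField("name"),
  col("Animal.Species").getField("color")
)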

Example input (shown as JSON; the actual input is Parquet):

{
  "Animal": {
    "Species": [
      {
        "name": "Beaver",
        "color": "red",
        "unique_id": "1001"
      },
      {
        "name": "Llama",
        "color": "blue",
        "unique_id": "2222"
      }
    ]
  }
}
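To reproduce the problem locally without the Parquet file, the sample record can be loaded from its JSON form. This is a minimal sketch, assuming a local SparkSession (the app name is arbitrary). JSON schema inference may order the struct fields differently from the Parquet schema, but the field names and types match.

```scala
import org.apache.spark.sql.SparkSession

// Local session for experimenting; spark-shell and Glue jobs already provide one
val spark = SparkSession.builder().master("local[*]").appName("species-demo").getOrCreate()
import spark.implicits._

// The sample record from the question, as a single-line JSON string
val sampleJson =
  """{"Animal":{"Species":[{"name":"Beaver","color":"red","unique_id":"1001"},{"name":"Llama","color":"blue","unique_id":"2222"}]}}"""

// spark.read.json accepts a Dataset[String] and infers the nested struct/array schema
val raw_df = spark.read.json(Seq(sampleJson).toDS())
raw_df.printSchema()
```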

Expected CSV output:

beaver_id, llama_id
1001, 2222
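Writing the derived columns out as CSV is a separate step from the select. A minimal sketch, assuming a local SparkSession and a temporary directory as a hypothetical output location (in a Glue job this would typically be an S3 path); the one-row DataFrame stands in for the real select_df.

```scala
import java.nio.file.Files
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("csv-demo").getOrCreate()
import spark.implicits._

// Stand-in for select_df: the two derived columns from the expected output
val select_df = Seq(("1001", "2222")).toDF("beaver_id", "llama_id")

// Hypothetical output directory; Spark writes part files inside it
val outDir = Files.createTempDirectory("animals").resolve("csv").toString

select_df
  .coalesce(1)              // single part file; only sensible for small outputs
  .write
  .option("header", "true") // first line holds the column names
  .mode("overwrite")
  .csv(outDir)
```

Spark writes the header as `beaver_id,llama_id` without the space shown above, and null ids become empty fields by default (controlled by the `nullValue` option).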

Solution

  • You can use the filter function on the Animal.Species array column, like this:

    import org.apache.spark.sql.functions.{element_at, expr}

    val select_df = raw_df.select(
      element_at(expr("filter(Animal.Species, x -> x.name = 'Beaver')"), 1)
        .getField("unique_id")
        .as("beaver_id"),
      element_at(expr("filter(Animal.Species, x -> x.name = 'Llama')"), 1)
        .getField("unique_id")
        .as("llama_id")
    )
    
    select_df.show
    //+---------+--------+
    //|beaver_id|llama_id|
    //+---------+--------+
    //|     1001|    2222|
    //+---------+--------+
    

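    The logic is simple: filter the array down to the inner struct whose name is Beaver (or Llama), take the first match with element_at(..., 1), and read its unique_id. If no element matches, the filtered array is empty and element_at returns null, which also covers the case where neither Beaver nor Llama is present. This null fallback assumes spark.sql.ansi.enabled is false (the default before Spark 4.0); with ANSI mode on, element_at throws on an out-of-range index, whereas the SQL function try_element_at (Spark 3.3+) returns null.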

    Note that since Spark 3.0 you can also call filter directly from the DataFrame API (org.apache.spark.sql.functions.filter) instead of writing the lambda as a SQL string inside expr.
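The filter/element_at expressions from the answer can be checked against the "missing animal" cases of the pseudocode. This is a sketch under stated assumptions: the `id` column and the `Otter` record are hypothetical test data, and ANSI mode is switched off explicitly so that element_at returns null rather than throwing.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, element_at, expr}

val spark = SparkSession.builder().master("local[*]").appName("null-demo").getOrCreate()
import spark.implicits._

// element_at returns null for an out-of-range index only when ANSI mode is off
spark.conf.set("spark.sql.ansi.enabled", "false")

// Hypothetical records: one with only a Beaver, one with neither Beaver nor Llama
val raw_df = spark.read.json(Seq(
  """{"id":1,"Animal":{"Species":[{"name":"Beaver","color":"red","unique_id":"1001"}]}}""",
  """{"id":2,"Animal":{"Species":[{"name":"Otter","color":"brown","unique_id":"3003"}]}}"""
).toDS())

val result = raw_df.select(
  col("id"),
  element_at(expr("filter(Animal.Species, x -> x.name = 'Beaver')"), 1)
    .getField("unique_id").as("beaver_id"),
  element_at(expr("filter(Animal.Species, x -> x.name = 'Llama')"), 1)
    .getField("unique_id").as("llama_id")
).orderBy("id")

// Row 1 keeps beaver_id = 1001 with a null llama_id; row 2 has nulls in both columns
result.show()
```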
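With Spark 3.0+, the same filter can be expressed through the Scala `functions.filter(column, Column => Column)` overload instead of a SQL string, which makes it easy to factor into a reusable helper. A minimal sketch, assuming a local SparkSession; `idFor` is a hypothetical helper name, not part of Spark.

```scala
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.{col, element_at, filter, lit}

val spark = SparkSession.builder().master("local[*]").appName("filter-api-demo").getOrCreate()
import spark.implicits._

// Same sample record as in the question
val raw_df = spark.read.json(Seq(
  """{"Animal":{"Species":[{"name":"Beaver","color":"red","unique_id":"1001"},{"name":"Llama","color":"blue","unique_id":"2222"}]}}"""
).toDS())

// Hypothetical helper: unique_id of the first Species element with the given name.
// With spark.sql.ansi.enabled=false it yields null when no element matches.
def idFor(animal: String): Column =
  element_at(filter(col("Animal.Species"), x => x.getField("name") === lit(animal)), 1)
    .getField("unique_id")

val select_df = raw_df.select(idFor("Beaver").as("beaver_id"), idFor("Llama").as("llama_id"))
select_df.show()
```

Because the helper returns a plain Column, it can sit alongside the other extracted fields in the existing select.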