apache-spark pyspark apache-spark-sql user-defined-functions

Iterate over an array in a pyspark dataframe, and create a new column based on columns of the same name as the values in the array


I have a table in this format:

name   fruits                        apple  banana  orange
Alice  ["apple","banana","orange"]   5      8       3
Bob    ["apple"]                     2      9       1

I want to make a new column that contains a JSON package in this format, where each key is an element of the array and each value is the value of the column with that name:

name   fruits                        apple  banana  orange  new_col
Alice  ["apple","banana","orange"]   5      8       3       {"apple":5, "banana":8, "orange":3}
Bob    ["apple"]                     2      9       1       {"apple":2}

Any thoughts on how to proceed? I'm assuming this needs a UDF, but I can't get the syntax right.

This is as far as I've got with the code:

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col
from pyspark.sql.types import MapType, StringType

# Create a Spark session
spark = SparkSession.builder.appName("example").getOrCreate()

# Sample data
data = [("Alice", ["apple", "banana", "orange"], 5, 8, 3),
        ("Bob", ["apple"], 2, 9, 1)]

# Define the schema
schema = ["name", "fruits", "apple", "banana", "orange"]

# Create a DataFrame
df = spark.createDataFrame(data, schema=schema)

# Show the initial DataFrame
print("Initial DataFrame:")
display(df)

# Define a UDF to create a dictionary
@udf(MapType(StringType(), StringType()))
def json_map(fruits):
    result = {}
    for i in fruits:
        result[i] = col(i)
    return result

# Apply the UDF to the 'fruits' column
new_df = df.withColumn('test', json_map(col('fruits')))

# Display the updated DataFrame
display(new_df)
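The attempt above cannot work as written: inside a UDF only the plain Python values of the columns passed in as arguments are available, and `col(i)` merely builds a column expression instead of reading the row's value for column `i`. A minimal sketch of the per-row logic a working UDF would need, assuming the fruit columns' values are passed in alongside `fruits`, is shown below in plain Python so it runs without Spark. The function name `fruit_map` is hypothetical.

```python
# Plain-Python model of a corrected UDF body (hypothetical name: fruit_map).
# Assumption: besides the `fruits` array, the UDF also receives the fruit
# columns' values for the same row, here as a dict {column_name: value}.

def fruit_map(fruits, values):
    """Map each fruit in `fruits` to the value of the column with that name.

    Fruits without a matching column are skipped instead of raising KeyError.
    """
    return {f: values[f] for f in fruits if f in values}


# The question's rows, with the fruit columns collected into a dict per row.
rows = [
    ("Alice", ["apple", "banana", "orange"], {"apple": 5, "banana": 8, "orange": 3}),
    ("Bob", ["apple"], {"apple": 2, "banana": 9, "orange": 1}),
]

results = {name: fruit_map(fruits, values) for name, fruits, values in rows}
print(results)  # {'Alice': {'apple': 5, 'banana': 8, 'orange': 3}, 'Bob': {'apple': 2}}
```

In PySpark this function would still have to be wrapped with `udf` and fed the fruit columns, for example by passing `struct('apple', 'banana', 'orange')` as a second argument and calling `.asDict()` on the resulting `Row` inside the UDF; that wiring is an untested suggestion, not something shown in the original post.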

Solution

  • You could use the arrays_zip method shared by Abdennacer in his answer, but it requires the array elements to line up with your columns, which might not always be the case.

    Another approach is to create an array of maps, one per fruit column, and then filter that array to keep only the key-value pairs whose key appears in the fruits array.

    Here's an example:

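    from pyspark.sql import functions as func

    # data_sdf is the input DataFrame (func.aggregate with a Python lambda needs Spark 3.1+).
    # I've changed the input slightly to rearrange the fruits array: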
    # +-----+-----------------------+-----+------+------+
    # |name |fruits                 |apple|banana|orange|
    # +-----+-----------------------+-----+------+------+
    # |Alice|[orange, banana, apple]|5    |8     |3     |
    # |Bob  |[apple]                |2    |9     |1     |
    # +-----+-----------------------+-----+------+------+
    
    data_sdf. \
        withColumn('fruitcols_arr', 
                   func.array(*[func.create_map([func.lit(c), func.col(c)]) for c in data_sdf.drop('name', 'fruits').columns])
                   ). \
        withColumn('fruitcols_arr', 
                   func.expr('filter(fruitcols_arr, x -> array_contains(fruits, map_keys(x)[0]))')
                   ). \
        withColumn('new_col',
                   func.aggregate(func.expr('slice(fruitcols_arr, 2, size(fruitcols_arr))'),
                                  func.col('fruitcols_arr')[0],
                                  lambda x, y: func.map_concat(x, y)
                                  )
                   ). \
        drop('fruitcols_arr'). \
        show(truncate=False)
    
    # +-----+-----------------------+-----+------+------+--------------------------------------+
    # |name |fruits                 |apple|banana|orange|new_col                               |
    # +-----+-----------------------+-----+------+------+--------------------------------------+
    # |Alice|[orange, banana, apple]|5    |8     |3     |{apple -> 5, banana -> 8, orange -> 3}|
    # |Bob  |[apple]                |2    |9     |1     |{apple -> 2}                          |
    # +-----+-----------------------+-----+------+------+--------------------------------------+
    

    The first fruitcols_arr step creates an array of maps (column_name -> column_value), one for each of the individual fruit columns. The second step filters that array, keeping only the maps whose key is an element of the fruits array; this gives the final array of maps for each row. new_col is then created with the aggregate higher-order function, which uses map_concat to merge all the maps in that final array into a single map.
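To make the three steps concrete, here is a plain-Python model of what they compute for a single row: build one single-entry map per fruit column, keep the maps whose key is in `fruits`, then fold the rest into the first map the way `aggregate` with `map_concat` does. It is an illustration of the logic only, not Spark code, and the names `FRUIT_COLS` and `build_new_col` are hypothetical. It also shows why positional pairing breaks when the array order differs from the column order, assuming an `arrays_zip`-style approach pairs elements by position.

```python
import json
from functools import reduce

# Hypothetical: the fruit columns, in DataFrame column order.
FRUIT_COLS = ["apple", "banana", "orange"]


def build_new_col(row):
    # Step 1: one single-entry map per fruit column, like create_map(lit(c), col(c)).
    fruitcols_arr = [{c: row[c]} for c in FRUIT_COLS]
    # Step 2: keep only maps whose single key is in the row's fruits array.
    fruitcols_arr = [m for m in fruitcols_arr if next(iter(m)) in row["fruits"]]
    if not fruitcols_arr:
        # Hypothetical choice for this model: no matching fruit -> None.
        return None
    # Step 3: fold the remaining maps into the first one, like aggregate + map_concat.
    return reduce(lambda acc, m: {**acc, **m}, fruitcols_arr[1:], fruitcols_arr[0])


alice = {"name": "Alice", "fruits": ["orange", "banana", "apple"],
         "apple": 5, "banana": 8, "orange": 3}
bob = {"name": "Bob", "fruits": ["apple"], "apple": 2, "banana": 9, "orange": 1}

alice_map = build_new_col(alice)
bob_map = build_new_col(bob)
print(alice_map)  # {'apple': 5, 'banana': 8, 'orange': 3} (column order, not fruits order)
print(json.dumps(alice_map, separators=(",", ":")))  # {"apple":5,"banana":8,"orange":3}

# Positional pairing goes wrong when `fruits` is not in column order:
positional = dict(zip(alice["fruits"], [alice[c] for c in FRUIT_COLS]))
print(positional)  # {'orange': 5, 'banana': 8, 'apple': 3}
```

Because the filter keeps the maps in column order, the merged map's keys follow the column order regardless of how `fruits` is sorted, which matches the Alice row in the output above. Since the question asks for JSON rather than a map column, `func.to_json(func.col('new_col'))` should turn the map into a JSON string in recent Spark versions; its exact formatting is not shown here.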