pyspark

Modify values of array columns


I have two columns in a DataFrame and I want to transform some of their values based on certain conditions. These columns are the outcome of collect_set aggregations, so I have the option to apply some transformation beforehand without needing to explode again.

More specifically, let's say the DataFrame looks like this:

df = [
 (1,["0","1","2"],["10"]),
 (2,["0"],["20"]),
 (3,["3"],[null])
]

I want to transform it so:

  1. If the arr1 column has more than one element, remove the "0" if it exists. If "0" is the only element, keep it.
  2. If arr2 has no (non-null) elements, assign a default value to it.

I tried to apply size(col1) to these columns but I am getting a "Column is not iterable" error. Exploding the arrays again and re-collecting their sets is really not an option for me for performance reasons, so I have to either be creative beforehand or apply the transformations directly to the arrays in a cost-efficient way.

That being said, the desired outcome would be:

df = [
 (1,["1","2"],["10"]),
 (2,["0"],["20"]),
 (3,["3"],["default_value"])
]

Ordering is not important.


Solution

  • Use size together with the higher-order array function filter to remove the "0" from the array elements without exploding.

    Example:

    from pyspark.sql.functions import col, expr, size, when

    df = spark.createDataFrame([(1,["0","1","2"],["10"]),(2,["0"],["20"]),(3,["3"],[None])],['id','arr','arr1'])

    # drop "0" only when the array has more than one element; otherwise keep it as-is
    df.withColumn('arr',when(size(col("arr"))>1,expr("filter(arr, x -> x != '0')")).otherwise(col("arr"))).show(10,False)
    #+---+------+------+
    #|id |arr   |arr1  |
    #+---+------+------+
    #|1  |[1, 2]|[10]  |
    #|2  |[0]   |[20]  |
    #|3  |[3]   |[null]|
    #+---+------+------+
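
  • For the second requirement (the default value for the second array column, called arr2 in the question and arr1 in the example DataFrame above), one option is to count the non-null elements with filter and size and substitute a default when none remain. A minimal sketch, assuming the default should be the literal string "default_value" from the desired output:

    from pyspark.sql.functions import array, col, expr, lit, size, when

    # replace arr1 with ["default_value"] when it is empty or contains only nulls
    df.withColumn('arr1',
        when(size(expr("filter(arr1, x -> x is not null)")) == 0, array(lit("default_value")))
        .otherwise(col('arr1'))).show(10,False)

    On the sample data this leaves rows 1 and 2 untouched and turns the [null] in row 3 into [default_value]. Both transformations can be chained in the same select/withColumn pipeline, so the arrays are never exploded.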