
Append to PySpark array column


I want to check whether the column values are within some boundaries. If they are not, I will append some value to the array column "F". This is the code I have so far:

import pyspark.sql.functions as F
import pyspark.sql.types as types

df = spark.createDataFrame(
    [
        (1, 56), 
        (2, 32),
        (3, 99)
    ],
    ['id', 'some_nr'] 
)

df = df.withColumn( "F", F.lit( None ).cast( types.ArrayType( types.ShortType( ) ) ) )

def boundary_check( val ):
  if (val > 60) | (val < 50):
    return 1

udf  = F.udf( lambda x: boundary_check( x ) ) 

df =  df.withColumn("F", udf(F.col("some_nr")))
display(df)

However, I don't know how to append to the array. Currently, if I perform another boundary check on df it will simply overwrite the previous values in "F"...


Solution

  • Have a look at the array_union function under pyspark.sql.functions here: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=join#pyspark.sql.functions.array_union

    That way you avoid a Python udf, which Spark cannot optimise and which adds serialisation overhead between the JVM and Python. The code would look something like:

    from pyspark.context import SparkContext
    from pyspark.sql import SparkSession
    from pyspark.conf import SparkConf
    from pyspark.sql import Row
    import pyspark.sql.functions as f
    
    
    conf = SparkConf()
    sc = SparkContext(conf=conf)
    spark = SparkSession(sc)
    
    df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="a", c3=10),
                                Row(c1=["b", "a", "c"], c2="d", c3=20)])
    df.show()
    +---------+---+---+
    |       c1| c2| c3|
    +---------+---+---+
    |[b, a, c]|  a| 10|
    |[b, a, c]|  d| 20|
    +---------+---+---+
    
    df.withColumn(
        "output_column", 
        f.when(f.col("c3") > 10, 
               f.array_union(df.c1, f.array(f.lit("1"))))
         .otherwise(f.col("c1"))
    ).show()
    +---------+---+---+-------------+
    |       c1| c2| c3|output_column|
    +---------+---+---+-------------+
    |[b, a, c]|  a| 10|    [b, a, c]|
    |[b, a, c]|  d| 20| [b, a, c, 1]|
    +---------+---+---+-------------+
    

    As a side note, array_union works as a set union: duplicates are removed, so if you want to append a value you need to make sure it is not already in the array, otherwise it will not show up again. If that is a problem, have a look at the other array functions here: https://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=join#pyspark.sql.functions.array (the sketch below contrasts array_union with concat).
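
    For instance, a quick sketch on the same df (the union_col/concat_col names are just for illustration; f.concat also accepts array columns from Spark 2.4 onwards and, unlike array_union, keeps duplicates):

    # "a" is already in c1, so array_union leaves the array unchanged,
    # while concat appends it anyway.
    df.withColumn(
        "union_col", f.array_union(df.c1, f.array(f.lit("a")))
    ).withColumn(
        "concat_col", f.concat(df.c1, f.array(f.lit("a")))
    ).show()
    +---------+---+---+---------+------------+
    |       c1| c2| c3|union_col|  concat_col|
    +---------+---+---+---------+------------+
    |[b, a, c]|  a| 10|[b, a, c]|[b, a, c, a]|
    |[b, a, c]|  d| 20|[b, a, c]|[b, a, c, a]|
    +---------+---+---+---------+------------+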

    NB: Your Spark needs to be version 2.4 or later for most of these array functions.
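
    If you are unsure which version you are running, you can check it from the session created above:

    print(spark.version)  # should print 2.4.0 or higher for array_union and friends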

    EDIT (on request in comments):

    The withColumn method only lets you work on one column at a time, so you need a second withColumn call for the second column; it helps to define the logical condition once up front and reuse it in both withColumn calls.

    logical_gate = (f.col("c3") > 10)
    
    (
        df.withColumn(
            "output_column", 
            f.when(logical_gate, 
                   f.array_union(df.c1, f.array(f.lit("1"))))
             .otherwise(f.col("c1")))
          .withColumn(
            "c3",
            f.when(logical_gate,
                   f.lit(None))
             .otherwise(f.col("c3")))
          .show()
    )
    +---------+---+----+-------------+
    |       c1| c2|  c3|output_column|
    +---------+---+----+-------------+
    |[b, a, c]|  a|  10|    [b, a, c]|
    |[b, a, c]|  d|null| [b, a, c, 1]|
    +---------+---+----+-------------+
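
    Mapping this back to your original DataFrame, a rough sketch (assuming you initialise "F" as an empty string array rather than NULL, so there is always something to union onto, and that each boundary check appends a distinct marker value, since array_union deduplicates):

    df = spark.createDataFrame([(1, 56), (2, 32), (3, 99)], ['id', 'some_nr'])

    # Initialise F as an empty array instead of a NULL literal.
    df = df.withColumn("F", f.array().cast("array<string>"))

    out_of_bounds = (f.col("some_nr") > 60) | (f.col("some_nr") < 50)

    # First boundary check appends "1"; a later check could append "2", etc.
    df = df.withColumn(
        "F",
        f.when(out_of_bounds, f.array_union(f.col("F"), f.array(f.lit("1"))))
         .otherwise(f.col("F")))
    df.show()
    +---+-------+---+
    | id|some_nr|  F|
    +---+-------+---+
    |  1|     56| []|
    |  2|     32|[1]|
    |  3|     99|[1]|
    +---+-------+---+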