Tags: arrays, apache-spark, pyspark, apache-spark-sql, nonetype

Adding None to PySpark array


I want to create an array that is conditionally populated based on an existing column, and sometimes I want it to contain None. Here's some example code:

from pyspark.sql import Row
from pyspark.sql import SparkSession
from pyspark.sql.functions import when, array, lit
 
spark = SparkSession.builder.getOrCreate()
 
df = spark.createDataFrame([
    Row(ID=1),
    Row(ID=2),
    Row(ID=2),
    Row(ID=1)
])

value_lit = 0.45
size = 10

df = df.withColumn("TEST",when(df["ID"] == 2,array([None for i in range(size)])).otherwise(array([lit(value_lit) for i in range(size)])))

df.show(truncate=False)

And here's the error I'm getting:

TypeError: Invalid argument, not a string or column: None of type <type 'NoneType'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' function.

I know it isn't a string or a column; I don't see why it has to be.

  • lit: doesn't work (but see the sketch after this list).
  • array: I'm not sure how to use array in this context.
  • struct: probably the way to go, but I'm not sure how to use it here. Perhaps I have to set an option to allow the new column to contain None values?
  • create_map: I'm not creating a key:value map, so I'm sure this is not the correct one to use.
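
For reference, the error comes from array() accepting only column names or Column objects, so a bare Python None is rejected. Wrapping it in lit(), as the message hints, does produce a NULL literal column, although that alone doesn't resolve the when/otherwise expression above. A minimal sketch of what the message is pointing at (null_array is just an illustrative name):

from pyspark.sql.functions import array, lit

size = 10

# A bare Python None is not a Column, hence the TypeError above:
# array([None for i in range(size)])

# Wrapping each None in lit() gives NULL literal columns, which array() accepts:
null_array = array([lit(None) for i in range(size)])  # Column of type array<null>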

Solution

  • The condition must be flipped: F.when(F.col('ID') != 2, value_lit)

    If you do this, you don't need otherwise at all: whenever the when condition is not satisfied, the result is null.

    Also, just one list comprehension is enough.

    from pyspark.sql import SparkSession, functions as F
    spark = SparkSession.builder.getOrCreate()
     
    df = spark.createDataFrame([(1,), (2,), (2,), (1,)], ['ID'])
    
    value_lit = 0.45
    size = 10
    
    df = df.withColumn("TEST", F.array([F.when(F.col('ID') != 2, value_lit) for i in range(size)]))
    
    df.show(truncate=False)
    # +---+------------------------------------------------------------+
    # |ID |TEST                                                        |
    # +---+------------------------------------------------------------+
    # |1  |[0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45]|
    # |2  |[,,,,,,,,,]                                                 |
    # |2  |[,,,,,,,,,]                                                 |
    # |1  |[0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45, 0.45]|
    # +---+------------------------------------------------------------+
    

    I've run this code on Spark 2.4.3.
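
  • As a side note (not part of the solution above, just a hedged sketch of an alternative): the original when/otherwise structure should also work if the null branch is built from lit(None) with an explicit cast, so that both branches have the same array<double> type:

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1,), (2,), (2,), (1,)], ['ID'])

    value_lit = 0.45
    size = 10

    # Cast the null literals to double so this branch matches the otherwise branch's element type
    null_array = F.array([F.lit(None).cast('double') for i in range(size)])
    value_array = F.array([F.lit(value_lit) for i in range(size)])

    df = df.withColumn("TEST", F.when(F.col('ID') == 2, null_array).otherwise(value_array))
    df.show(truncate=False)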