Tags: apache-spark, pyspark, apache-spark-sql, conditional-statements, case

Result of a when chain in Spark


I have a chained when condition in a Spark DataFrame which looks something like this:

from pyspark.sql.functions import when, lower

df = df.withColumn("some_column", when((lower(df.transaction_id) == "id1") & (df.some_qty != 0), df.some_qty)
                                 .when((lower(df.transaction_id) == "id1") & (df.some_qty == 0) & (df.some_qty2 != 0), df.some_qty2)
                                 .when((lower(df.transaction_id) == "id1") & (df.some_qty == 0) & (df.some_qty2 == 0), 0)
                                 .when((lower(df.transaction_id) == "id2") & (df.some_qty3 != 0), df.some_qty3)
                                 .when((lower(df.transaction_id) == "id2") & (df.some_qty3 == 0) & (df.some_qty4 != 0), df.some_qty4)
                                 .when((lower(df.transaction_id) == "id2") & (df.some_qty3 == 0) & (df.some_qty4 == 0), 0))

In this expression, I'm modifying the value of a column based on the values of other columns. I want to understand how the statement above is executed: are all the conditions checked for every row of the DataFrame, and if so, what happens when more than one when condition is true? Or is the order of the chain followed, with the first condition that evaluates to true being the one used?


Solution

  • Yes, every row is going to be checked. But Spark takes care of optimizing that, so it's not like looping over each cell: the whole chain is evaluated as a single conditional (CASE WHEN) expression per row (see the explain() sketch at the end of this answer).

    As for the order, an example shows that the first matching condition is the one taken into account:

    from pyspark.sql import functions as F
    from pyspark.sql.functions import when, lower, lit

    df = spark.createDataFrame(
        [
        ('id2','70.07','22.1','0','1'),
        ('id1','0','0','1','3'),
        ('id2','80.7','0','1','3'),
        ('id2','0','0','1','3'),
        ('id1','22.2','0','1','3')
        ],
        ['transaction_id','some_qty','some_qty2', 'some_qty3','some_qty4']
    )\
        .withColumn('some_qty', F.col('some_qty').cast('double'))\
        .withColumn('some_qty2', F.col('some_qty2').cast('double'))\
        .withColumn('some_qty3', F.col('some_qty3').cast('double'))\
        .withColumn('some_qty4', F.col('some_qty4').cast('double'))
    
    df = df.withColumn("some_column",when((lower(df.transaction_id) == "id1") & (df.some_qty != 0),lit('first_true'))
                                     .when((lower(df.transaction_id) == "id1") & (df.some_qty != 0),lit('second_true')))
    df.show()
    
    # +--------------+--------+---------+---------+---------+-----------+
    # |transaction_id|some_qty|some_qty2|some_qty3|some_qty4|some_column|
    # +--------------+--------+---------+---------+---------+-----------+
    # |           id2|   70.07|     22.1|      0.0|      1.0|       null|
    # |           id1|     0.0|      0.0|      1.0|      3.0|       null|
    # |           id2|    80.7|      0.0|      1.0|      3.0|       null|
    # |           id2|     0.0|      0.0|      1.0|      3.0|       null|
    # |           id1|    22.2|      0.0|      1.0|      3.0| first_true|
    # +--------------+--------+---------+---------+---------+-----------+
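
    To make the optimization point concrete, here is a minimal sketch (reusing the df built above) that inspects the query plan: the chained when is planned as a single CASE WHEN expression evaluated top to bottom, so only the first matching branch produces a value, and rows matching no branch get null unless an otherwise is added. The 'no_match' default below is purely illustrative and not part of the original code.

    df.select(
        when((lower(df.transaction_id) == "id1") & (df.some_qty != 0), lit('first_true'))
        .when((lower(df.transaction_id) == "id1") & (df.some_qty != 0), lit('second_true'))
        .otherwise(lit('no_match'))   # optional default; without it, unmatched rows stay null
        .alias("some_column")
    ).explain()

    # The physical plan shows one expression of the form
    # CASE WHEN ... THEN first_true WHEN ... THEN second_true ELSE no_match END AS some_column,
    # i.e. the whole chain is a single conditional evaluated in order for each row.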