Tags: dataframe, apache-spark, pyspark

How can I create a new field in Pyspark using withColumn, for loop, and UDF?


I have a slightly complex case of logic in a PySpark dataframe. I need to create a new field computed from many input fields. Given this dataframe:

df = spark.createDataFrame(
    [(1, 100, 100, 'A', 'A'),
     (2, 1000, 200, 'A', 'A'),
     (3, 1000, 300, 'B', 'A'),
     (4, 1000, 1000, 'B', 'B')],
    "id int, days1 int, days2 int, code1 string, code2 string")

df.show()

+---+-----+-----+-----+-----+
| id|days1|days2|code1|code2|
+---+-----+-----+-----+-----+
|  1|  100|  100|    A|    A|
|  2| 1000|  200|    A|    A|
|  3| 1000|  300|    B|    A|
|  4| 1000| 1000|    B|    B|
+---+-----+-----+-----+-----+

I need to add a new column that counts, across all the days/code pairs, how many satisfy daysN > 500 and codeN = 'B'. Each pair that satisfies the condition adds 1 to sum_fields; otherwise it contributes 0.

For example, in row 3 the combination of days1 = 1000 and code1 = 'B' satisfies the logic, so sum_fields is 1; in row 4 both pairs satisfy it, so sum_fields is 2.

+---+-----+-----+-----+-----+----------+
| id|days1|days2|code1|code2|sum_fields|
+---+-----+-----+-----+-----+----------+
|  1|  100|  100|    A|    A|         0|
|  2| 1000|  200|    A|    A|         0|
|  3| 1000|  300|    B|    A|         1|
|  4| 1000| 1000|    B|    B|         2|
+---+-----+-----+-----+-----+----------+

I did this, which works, but it requires manually specifying the fields. I need to do this for 30 pairs of fields.

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

def udf_test(x, y):
  cnt = 0
  if x > 500 and y == 'B':
    cnt += 1

  return cnt

myUDF = F.udf(udf_test, IntegerType())

df.withColumn("sum_fields", myUDF("days1", "code1")).show()

I know a list comprehension can be used with select, like below. How can I apply a similar loop with withColumn and the logic above?

df.select(*[F.col(f'days{i+1}') for i in range(30)])

Solution

  • Use a list comprehension containing one entry for each days/code pair. Each entry is a when expression like

    F.when((F.col('daysX') > 500) & (F.col('codeX') == 'B'), 1).otherwise(0)
    

    Then use reduce to sum up all the entries.

    from pyspark.sql import functions as F
    from functools import reduce

    cols = [F.when((F.col(f'days{c[4:]}') > 500) & (F.col(f'code{c[4:]}') == F.lit('B')), 1)
             .otherwise(0)
            for c in df.columns if c.startswith('days')]
    # avoid naming the result `sum`, which would shadow the Python builtin
    sum_fields = reduce(lambda x, y: x + y, cols)

    df.withColumn('sum_fields', sum_fields).show()
    

    Result:

    +---+-----+-----+-----+-----+----------+
    | id|days1|days2|code1|code2|sum_fields|
    +---+-----+-----+-----+-----+----------+
    |  1|  100|  100|    A|    A|         0|
    |  2| 1000|  200|    A|    A|         0|
    |  3| 1000|  300|    B|    A|         1|
    |  4| 1000| 1000|    B|    B|         2|
    +---+-----+-----+-----+-----+----------+
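
    As a variation, the original UDF can also be reused in a loop: build one UDF call per days/code pair and sum the resulting columns with reduce, which answers the literal "for loop plus withColumn" question. A minimal self-contained sketch, assuming 2 pairs instead of the real 30 and a local Spark session (the names `pair_flag`, `flag_udf`, and `n_pairs` are illustrative, not from the original post):

    ```python
    from functools import reduce

    from pyspark.sql import SparkSession, functions as F
    from pyspark.sql.types import IntegerType

    spark = SparkSession.builder.master("local[1]").appName("sum_fields").getOrCreate()

    df = spark.createDataFrame(
        [(1, 100, 100, 'A', 'A'),
         (2, 1000, 200, 'A', 'A'),
         (3, 1000, 300, 'B', 'A'),
         (4, 1000, 1000, 'B', 'B')],
        "id int, days1 int, days2 int, code1 string, code2 string")

    def pair_flag(days, code):
        # 1 if this days/code pair satisfies the condition, else 0;
        # the None guard keeps the UDF safe on null inputs
        return 1 if days is not None and days > 500 and code == 'B' else 0

    flag_udf = F.udf(pair_flag, IntegerType())

    n_pairs = 2  # would be 30 in the real data
    flags = [flag_udf(f'days{i}', f'code{i}') for i in range(1, n_pairs + 1)]
    result = df.withColumn('sum_fields', reduce(lambda a, b: a + b, flags))
    result.show()
    ```

    Note that the when/otherwise version above stays entirely in Spark's native expressions, so it will generally be faster than a Python UDF; the UDF loop is mainly useful when the per-pair logic is too awkward to express with built-in functions.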