Tags: python, function, pyspark, arguments, variadic-functions

How to write a PySpark function that can accept an argument with a variable number of parameters?


I wrote a function that I'd like to modify so that one argument can take one or more parameters, but I'm having trouble making it work correctly.

def get_recent_date(input_df, *partition_col, order_col):
    w = Window().partitionBy(partition_col)\
        .orderBy(desc(order_col))
    output_df = input_df.withColumn('DenseRank', dense_rank().over(w))
    return output_df

I want the function to work so that partition_col can take a variable number of parameters. In ex 1 below, partition_col = 'event_category', and in ex 2, partition_col = 'event_category' and 'participant_category'. I've tried running this a variety of ways and often get the error "TypeError: can only concatenate str (not "tuple") to str". Thank you in advance for your help!

ex 1: get_recent_date(input, 'event_category', 'event_date')

ex 2: get_recent_date(input, 'event_category', 'participant_category', 'event_date')


Solution

  • *partition_col allows you to pass a variable number of non-keyword arguments; inside your function, partition_col holds all of them as a single tuple. You therefore need to unpack that tuple of column names for PySpark to apply your variable-length partitioning correctly.
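
    As a quick, Spark-free illustration of that packing behaviour (demo is just a throwaway name for this sketch):

    def demo(*partition_col):
        # All positional arguments are collected into a single tuple.
        print(partition_col)

    demo('event_category')
    # ('event_category',)
    demo('event_category', 'participant_category')
    # ('event_category', 'participant_category')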

    Replace

    w = Window().partitionBy(partition_col)\
    

    with

    w = Window().partitionBy(*partition_col)\
    

    and you should be good to go.
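
    One more thing worth noting about the calls in ex 1 and ex 2: because order_col is declared after *partition_col, Python makes it keyword-only, so it has to be passed by name (input below stands in for your DataFrame):

    # Fails: 'event_date' gets swallowed into partition_col,
    # so order_col is never bound:
    # get_recent_date(input, 'event_category', 'event_date')
    # TypeError: get_recent_date() missing 1 required keyword-only argument: 'order_col'

    # Works: pass the ordering column by keyword.
    get_recent_date(input, 'event_category', order_col='event_date')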

    Replicable example:

    from pyspark.sql import SparkSession, Window
    from pyspark.sql.functions import desc, dense_rank
    
    spark = SparkSession.builder.appName('spark_session').getOrCreate()
    
    data = [
        (100, 1, 2, 1),
        (100, 1, 1, -1),
        (200, 1, 3, 1),
        (200, 1, 3, 4)   
    ]
    
    df = spark.createDataFrame(data, ("col_1", "col_2", "col_3", "order_col"))
    
    df.show()
    
    # +-----+-----+-----+---------+
    # |col_1|col_2|col_3|order_col|
    # +-----+-----+-----+---------+
    # |  100|    1|    2|        1|
    # |  100|    1|    1|       -1|
    # |  200|    1|    3|        1|
    # |  200|    1|    3|        4|
    # +-----+-----+-----+---------+
    
    def get_recent_date(input_df, *partition_col, order_col):
        w = Window().partitionBy(*partition_col)\
            .orderBy(desc(order_col))
        output_df = input_df.withColumn('DenseRank', dense_rank().over(w))
        return output_df
    
    
    new_df = get_recent_date(
        df, 'col_2', order_col='order_col'
    )
    
    new_df.show()
    
    # +-----+-----+-----+---------+---------+
    # |col_1|col_2|col_3|order_col|DenseRank|
    # +-----+-----+-----+---------+---------+
    # |  200|    1|    3|        4|        1|
    # |  100|    1|    2|        1|        2|
    # |  200|    1|    3|        1|        2|
    # |  100|    1|    1|       -1|        3|
    # +-----+-----+-----+---------+---------+
    
    new_df = get_recent_date(
        df, 'col_2', 'col_1', order_col='order_col'
    )
    
    new_df.show()
    
    # +-----+-----+-----+---------+---------+
    # |col_1|col_2|col_3|order_col|DenseRank|
    # +-----+-----+-----+---------+---------+
    # |  100|    1|    2|        1|        1|
    # |  100|    1|    1|       -1|        2|
    # |  200|    1|    3|        4|        1|
    # |  200|    1|    3|        1|        2|
    # +-----+-----+-----+---------+---------+
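
    Since the function is named get_recent_date, you will likely want only the top-ranked row per group afterwards. A minimal follow-up sketch (most_recent_df is just an illustrative name), filtering on the DenseRank column the function adds:

    # Keep only the most recent row(s) within each partition group.
    most_recent_df = new_df.filter(new_df['DenseRank'] == 1)
    most_recent_df.show()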