I wrote a function that I'd like to modify so that one of its parameters can accept one or more values, but I'm having trouble making it work correctly.
def get_recent_date(input_df, *partion_col, order_col):
    """Add a 'DenseRank' column: dense_rank over order_col (descending) per partition.

    NOTE(review): this is the non-working version from the question.
    - The signature collects the extra positional arguments into
      ``partion_col`` (typo), but the body reads ``partition_col``, a name
      this function never defines.
    - Even with matching names, the tuple of column names is handed to
      ``partitionBy`` as a single argument instead of being unpacked with ``*``.
    - ``order_col`` follows ``*partion_col``, so it is keyword-only and must be
      passed as ``order_col=...``.
    """
    # Window: partition by the given column(s), newest/largest order_col first.
    w = Window().partitionBy(partition_col)\
        .orderBy(desc(order_col))
    output_df= input_df.withColumn('DenseRank', dense_rank().over(w))
    return output_df
I want partition_col to accept a variable number of column names. In ex 1 below, partition_col = 'event_category'; in ex 2, partition_col = 'event_category' and 'participant_category'. I've tried this several ways and often get the error "TypeError: can only concatenate str (not "tuple") to str". Thank you in advance for your help!
ex 1: get_recent_date(input, 'event_category', 'event_date')
ex 2: get_recent_date(input, 'event_category', 'participant_category', 'event_date')
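`*partion_col` lets you pass a variable number of positional arguments. Inside the function, that name holds all the extra positional arguments (everything after `input_df`) as a tuple. Passing that tuple as-is to `partitionBy` does not work, so you need to unpack it with `*` for PySpark to partition by each column name.

Also note the typo in your signature. It declares `*partion_col`, but the body uses `partition_col`, which would raise a `NameError`. Use the same name (`partition_col`) in both places.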
Replace
w = Window().partitionBy(partition_col)\
with
w = Window().partitionBy(*partition_col)\
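and you should be good to go.

One more thing: because `order_col` comes after `*partition_col` in the signature, it is a keyword-only parameter and must be passed by name, e.g. `get_recent_date(input, 'event_category', 'participant_category', order_col='event_date')`. Called as in your ex 1 and ex 2, `'event_date'` would simply be collected into `partition_col`, and Python would raise a `TypeError` for the missing keyword-only argument `order_col`.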
Replicable example:
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import desc, dense_rank
# Start (or reuse) a Spark session for the demo.
spark = (
    SparkSession.builder
    .appName('spark_session')
    .getOrCreate()
)

# Toy data: col_2 is identical on every row, col_1 splits the rows into two
# groups (100 and 200), and order_col is the column the examples below rank by.
data = [
    (100, 1, 2, 1),
    (100, 1, 1, -1),
    (200, 1, 3, 1),
    (200, 1, 3, 4),
]
df = spark.createDataFrame(
    data,
    ("col_1", "col_2", "col_3", 'order_col'),
)
df.show()
# +-----+-----+-----+---------+
# |col_1|col_2|col_3|order_col|
# +-----+-----+-----+---------+
# |  100|    1|    2|        1|
# |  100|    1|    1|       -1|
# |  200|    1|    3|        1|
# |  200|    1|    3|        4|
# +-----+-----+-----+---------+
def get_recent_date(input_df, *partition_col, order_col):
    """Return ``input_df`` with an extra ``DenseRank`` column.

    Rows are ranked with ``dense_rank`` in descending ``order_col`` order,
    separately within each group defined by the partition columns.

    Args:
        input_df: The Spark DataFrame to rank.
        *partition_col: Column names to partition by. They arrive here as a
            tuple and are unpacked into ``partitionBy``.
        order_col: Column name to order by (descending). Keyword-only,
            because it follows ``*partition_col`` in the signature.

    Returns:
        The DataFrame produced by ``withColumn`` with the ``DenseRank`` column.
    """
    # Unpack the tuple so each name becomes its own partitionBy argument.
    window_spec = Window().partitionBy(*partition_col).orderBy(desc(order_col))
    rank_expr = dense_rank().over(window_spec)
    return input_df.withColumn('DenseRank', rank_expr)
# Example 1 -- one partition column. Every row has col_2 == 1, so all rows
# fall into a single group and are ranked together by descending order_col;
# equal order_col values share a rank. order_col is passed by name because it
# is keyword-only (it follows *partition_col in the signature).
# Row order in show() output may vary.
new_df = get_recent_date(df, 'col_2', order_col='order_col')
new_df.show()
# +-----+-----+-----+---------+---------+
# |col_1|col_2|col_3|order_col|DenseRank|
# +-----+-----+-----+---------+---------+
# |  200|    1|    3|        4|        1|
# |  100|    1|    2|        1|        2|
# |  200|    1|    3|        1|        2|
# |  100|    1|    1|       -1|        3|
# +-----+-----+-----+---------+---------+

# Example 2 -- two partition columns. (col_2, col_1) splits the rows into a
# col_1 == 100 group and a col_1 == 200 group, each ranked independently.
new_df = get_recent_date(df, 'col_2', 'col_1', order_col='order_col')
new_df.show()
# +-----+-----+-----+---------+---------+
# |col_1|col_2|col_3|order_col|DenseRank|
# +-----+-----+-----+---------+---------+
# |  100|    1|    2|        1|        1|
# |  100|    1|    1|       -1|        2|
# |  200|    1|    3|        4|        1|
# |  200|    1|    3|        1|        2|
# +-----+-----+-----+---------+---------+