I have the following PySpark dataframe that contains two fields, ID and QUARTER:
import pandas as pd
pandas_df = pd.DataFrame({"ID": [1, 2, 3, 4, 5, 3, 5, 6, 3, 7, 2, 6, 8, 9, 1, 7, 5, 1, 10], "QUARTER": [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 5, 5, 5, 5]})
spark_df = spark.createDataFrame(pandas_df)
spark_df.createOrReplaceTempView('spark_df')
and I have the following list, which contains the number of entries I want from each of the 5 quarters:
numbers=[2,1,3,1,2]
From each quarter, I want to select a number of rows equal to the corresponding number in the list 'numbers'. The IDs must also be unique in the end: if I select an ID in a certain quarter, I should not select it again in another quarter.
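Without Spark, the requirement reduces to a single greedy pass over the quarters. The sketch below is a minimal pure-Python version and assumes that "a number of rows" means the first unused IDs in row order, because the question does not say which rows to prefer. `select_unique` is a hypothetical helper name. It can serve as a reference for what a correct result on the sample data looks like.

```python
# Sample data from the question: (ID, QUARTER) pairs in their original row order.
rows = list(zip(
    [1, 2, 3, 4, 5, 3, 5, 6, 3, 7, 2, 6, 8, 9, 1, 7, 5, 1, 10],
    [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 5, 5, 5, 5],
))
numbers = [2, 1, 3, 1, 2]  # rows wanted from quarters 1..5


def select_unique(rows, numbers):
    """Hypothetical helper: take numbers[q-1] not-yet-used IDs from quarter q, in row order."""
    chosen = []   # (ID, QUARTER) pairs kept so far
    seen = set()  # IDs already used in any quarter
    for quarter, wanted in enumerate(numbers, start=1):
        taken = 0
        for id_, q in rows:
            if taken == wanted:
                break
            if q == quarter and id_ not in seen:
                chosen.append((id_, q))
                seen.add(id_)
                taken += 1
    return chosen


result = select_unique(rows, numbers)
print(result)
# [(1, 1), (2, 1), (3, 2), (7, 3), (6, 3), (8, 3), (9, 4), (5, 5), (10, 5)]
```

Every ID appears once, and each quarter contributes exactly the requested count because enough unused IDs remain at every step in this sample.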
For that I used the following pyspark code:
from pyspark.sql import Window
from pyspark.sql.functions import col, lit, row_number

quart = 1  # the first quarter
liste_unique = []  # IDs already selected, used to avoid reselecting them
for i in range(0, len(numbers)):
    tmp = spark_df.where(spark_df.QUARTER == quart)  # select only rows of the chosen quarter
    tmp = tmp.where(~tmp.ID.isin(liste_unique))  # keep only IDs that were not selected before
    w = Window.partitionBy(lit('col_count0')).orderBy(lit('col_count0'))  # dummy constant column
    df_final = tmp.withColumn("row_num", row_number().over(w)).filter(col("row_num").between(1, numbers[i]))  # number of rows needed from the 'numbers' list
    df_final = df_final.drop(col("row_num"))  # drop the row_num column
    liste_tempo = df_final.select(['ID']).rdd.map(lambda x: x[0]).collect()  # collect the selected IDs into a list
    liste_unique.extend(liste_tempo)  # remember every ID selected so far
    df0 = df0.union(df_final)  # append this quarter's rows to the accumulated dataframe
    quart = quart + 1  # move to the next quarter
df0 is simply an empty dataframe at the beginning. It will contain all the data at the end, and it can be declared as follows:
from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, StructField, StructType

spark = SparkSession.builder.appName('Empty_Dataframe').getOrCreate()
# Schema of the empty accumulator; LongType matches the integer columns of spark_df
columns = StructType([StructField('ID', LongType(), True),
                      StructField('QUARTER', LongType(), True)])
df0 = spark.createDataFrame(data=[], schema=columns)
The code runs without errors, except that I find duplicate IDs in different quarters, which is not correct. There is also a weird behavior when I try to count the number of unique IDs in the df0 dataframe (in a new, separate cell):
print(df0.select('ID').distinct().count())
It gives a different value at each execution, even though the dataframe is not touched by any other process (this is more visible with a larger dataset than the example). I cannot understand this behavior. I tried to clear the cache and the temporary variables using unpersist(True), but nothing changed. I suspect that the union function is used wrongly, but I did not find any alternative in PySpark.
I took a shot at cleaning up this logic and Python code, since it was confusing to me. Here is my version, with the reasoning behind each statement. In section 1, I create a dataframe and a temporary view. Please note that pandas is not needed. This code was tested on Azure Databricks.
#
# 1 - create data frame + temporary view
#
# two lists
data1 = [1, 2, 3, 4, 5, 3, 5, 6, 3, 7, 2, 6, 8, 9, 1, 7, 5, 1, 10]
data2 = [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 5, 5, 5, 5]
# columns
columns1 = ['id', 'quarter']
# creating df
df1 = spark.createDataFrame(list(zip(data1, data2)), columns1)
# create view
df1.createOrReplaceTempView('sample_data')
# show data
display(df1)  # Databricks notebook helper; use df1.show() outside Databricks
The above image shows all 19 records. What is interesting to me is that you have 5 quarters; if you are counting fiscal quarters, there are only 4. The original code simply increments the quarter once for each element of the numbers array.
The business logic states: pick n[0] ids from quarter 1, then pick n[1] ids from quarter 2 without replacement, and so on. Thus, the final result will contain at most as many ids as there are distinct ids in the data. Again, n[] is the input array of picks.
If we execute a simple query against the temporary view, we see that the ids 1 to 10 are listed.
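The query itself is not shown in the text. Below is a plausible version of it, run against SQLite's in-memory database instead of Spark so that it executes without a cluster. The swap to SQLite is an assumption. In Databricks, the same SQL string would go through spark.sql against the sample_data view.

```python
import sqlite3

# Rebuild the sample_data view as an in-memory SQLite table (stand-in for Spark SQL).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE sample_data (id INTEGER, quarter INTEGER)")
ids = [1, 2, 3, 4, 5, 3, 5, 6, 3, 7, 2, 6, 8, 9, 1, 7, 5, 1, 10]
quarters = [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 5, 5, 5, 5]
conn.executemany("INSERT INTO sample_data VALUES (?, ?)", zip(ids, quarters))

# The "simple query": every distinct id in the table, sorted.
distinct_ids = [row[0] for row in conn.execute(
    "SELECT DISTINCT id FROM sample_data ORDER BY id")]
print(distinct_ids)  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```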
To make life easier, I packaged the business logic into a plain Python function (not a Spark UDF).
#
# 2 - write function for business logic
#
def pick_logic(num, qtr, ary):
    # which qtr to select?
    stmt = "select id from sample_data where quarter = {}".format(qtr)
    # which ids not to repick
    if ary:
        stmt = stmt + " and id not in ({})".format(','.join(str(i) for i in ary))
    # shuffle, then limit to num rows
    stmt = stmt + " order by rand() limit {}".format(num)
    # debugging
    print(stmt)
    # create df
    df = spark.sql(stmt)
    # convert df -> list of ids + return
    return [row['id'] for row in df.collect()]
The function selects the ids from a given quarter. The input array of already picked numbers ensures that we choose new numbers, without replacement. I use a rand() call to shuffle the data before limiting the result.
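The same function can be written with bound parameters instead of string formatting, so the SQL stays safe if the ids ever come from user input. The sketch below again uses SQLite as a stand-in for Spark. It assumes that SQLite's RANDOM() plays the role of Spark's rand(), and its pick_logic takes an extra connection argument that the answer's version does not have.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE sample_data (id INTEGER, quarter INTEGER)")
conn.executemany(
    "INSERT INTO sample_data VALUES (?, ?)",
    zip([1, 2, 3, 4, 5, 3, 5, 6, 3, 7, 2, 6, 8, 9, 1, 7, 5, 1, 10],
        [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 5, 5, 5, 5]),
)


def pick_logic(conn, num, qtr, already_picked):
    """Pick up to `num` random ids from quarter `qtr`, skipping already_picked."""
    sql = "SELECT id FROM sample_data WHERE quarter = ?"
    params = [qtr]
    if already_picked:
        # one placeholder per excluded id instead of pasting values into the string
        sql += " AND id NOT IN ({})".format(",".join("?" * len(already_picked)))
        params.extend(already_picked)
    sql += " ORDER BY RANDOM() LIMIT ?"  # SQLite's RANDOM() stands in for Spark's rand()
    params.append(num)
    return [row[0] for row in conn.execute(sql, params)]


picked = []
for n, q in [(2, 1), (1, 2), (3, 3), (1, 4), (2, 5), (3, 1)]:
    picked += pick_logic(conn, n, q, picked)
print(picked)
```

RANDOM() is re-drawn on every run, so both the contents and the length of picked change between runs. No id ever appears twice.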
#
# 3 - call functions with tuples
#
# (picks, quarter) pairs; the final (3, 1) is an extra pass over quarter 1
picks_per_qtr = [(2,1), (1,2), (3,3), (1,4), (2,5), (3,1)]
# list of ids
i = []
# for each tuple, call business logic
for items in picks_per_qtr:
    # choose n values from quarter q
    n, q = items
    # get an array of new ids
    r = pick_logic(n, q, i)
    # show picking
    print(r)
    # append to id list
    i = i + r
print("final result is {}".format(i))
I took the liberty of converting the numbers array to an array of tuples (picks, quarter). I left the debugging code in place so that you can see the dynamic queries being executed and the results being returned.
Unlike my first pass at this Spark SQL, which had no rand() ordering, the code now gives different results each time it is executed. However, the final array will always have at most as many elements as there are distinct ids.
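If the tuples should come from the question's numbers list instead of being typed by hand, enumerate can pair each count with its 1-based quarter. This produces only the five tuples implied by numbers. It does not include the extra (3, 1) pass used above.

```python
numbers = [2, 1, 3, 1, 2]  # picks per quarter, from the question

# Pair each pick count with its 1-based quarter number: (picks, quarter)
picks_per_qtr = [(n, q) for q, n in enumerate(numbers, start=1)]
print(picks_per_qtr)  # [(2, 1), (1, 2), (3, 3), (1, 4), (2, 5)]
```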
The image above shows that the number 9 has not been picked yet. Also, I was not told whether the quarter in which a number was picked needs to be recorded. If it does, modify the code to return a list of tuples (id, qtr).
In short, the logic is sound, and I hope it helps solve your business problem.
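Below is a minimal sketch of that modification. To run the bookkeeping without Spark, it uses a hypothetical deterministic `pick_ids` in place of pick_logic. That stand-in takes the first unused ids of a quarter instead of random ones. Only the id part of each tuple is fed back as the exclusion list.

```python
# Hypothetical stand-in for pick_logic: ids per quarter from the sample data,
# returning the first n ids of quarter q that are not in `exclude` (no randomness).
SAMPLE = {1: [1, 2, 3, 4, 5], 2: [3, 5, 6], 3: [3, 7, 2, 6, 8],
          4: [9, 1], 5: [7, 5, 1, 10]}


def pick_ids(n, q, exclude):
    return [i for i in SAMPLE.get(q, []) if i not in exclude][:n]


picks = []  # (id, qtr) tuples, so the quarter of each pick is kept
for n, q in [(2, 1), (1, 2), (3, 3), (1, 4), (2, 5), (3, 1)]:
    already = [i for i, _ in picks]  # ids only, used for the NOT IN filter
    picks += [(i, q) for i in pick_ids(n, q, already)]
print(picks)
# [(1, 1), (2, 1), (3, 2), (7, 3), (6, 3), (8, 3), (9, 4), (5, 5), (10, 5), (4, 1)]
```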