python · apache-spark · pyspark · union · pyspark-schema

Weird behaviour in Pyspark dataframe


I have the following pyspark dataframe that contains two fields, ID and QUARTER:

import pandas as pd

pandas_df = pd.DataFrame({"ID": [1, 2, 3, 4, 5, 3, 5, 6, 3, 7, 2, 6, 8, 9, 1, 7, 5, 1, 10],
                          "QUARTER": [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 5, 5, 5, 5]})
spark_df = spark.createDataFrame(pandas_df)
spark_df.createOrReplaceTempView('spark_df')

and I have the following list that contains the number of entries I want from each of the 5 quarters:

numbers=[2,1,3,1,2]

I want to select from each quarter a number of rows equal to the number indicated in the list 'numbers'. The IDs must be unique at the end: if an ID was selected in one quarter, it should not be selected again in another quarter.

For that I used the following pyspark code:


from pyspark.sql import Window
from pyspark.sql.functions import lit, row_number, col

quart=1 # the first quarter
liste_unique=[] # an empty list that will contain the unique ID values to compare against
for i in range(0,len(numbers)):
  tmp=spark_df.where(spark_df.QUARTER==quart) # select only rows with the chosen quarter
  tmp=tmp.where(tmp.ID.isin(liste_unique)==False) # the selected ids were not selected before
  w = Window().partitionBy(lit('col_count0')).orderBy(lit('col_count0')) # dummy column
  df_final=tmp.withColumn("row_num", row_number().over(w)).filter(col("row_num").between(1,numbers[i])) # number of rows needed from the 'numbers' list
  df_final=df_final.drop(col("row_num")) # drop the row_num column
  liste_tempo=df_final.select(['ID']).rdd.map(lambda x : x[0]).collect() # turn the selected ids into a list

  liste_unique.extend(liste_tempo) # extend the list of unique ids each time we select new rows from a quarter

  df0=df0.union(df_final) # union the accumulator dataframe with the rows selected from each quarter

  quart=quart+1 # increment the quarter

df0 is simply an empty dataframe at the beginning. It will contain all the data at the end; it can be declared as follows:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType

spark = SparkSession.builder.appName('Empty_Dataframe').getOrCreate()

# Create the schema for the empty dataframe
columns = StructType([StructField('ID', StringType(), True),
                      StructField('QUARTER', StringType(), True)])

df0 = spark.createDataFrame(data = [], schema = columns)

The code runs without errors, except that I find duplicate IDs across different quarters, which is not correct. Another weird behavior appears when I try to count the number of unique IDs in the df0 dataframe (in a new, separate cell):

print(df0.select('ID').distinct().count())

It gives a different value at each execution, even though the dataframe is not touched by any other process (it is more noticeable with a larger dataset than this example). I cannot understand this behavior. I tried to clear the cache and the temporary variables using unpersist(True), but nothing changed. I suspect that the union function is wrongly used, but I did not find any alternative in pyspark.


Solution

  • I took a stab at cleaning up this logic and Python code since it was confusing to me. Here is my version, with the logic and reasoning behind the statements. In section 1, I create a dataframe and a temporary view. Please note that pandas does not need to be used. This code was tested on Azure Databricks.

    #
    # 1 - create data frame + temporary view
    #
    
    # two lists
    data1 = [1, 2, 3, 4, 5, 3, 5, 6, 3, 7, 2, 6, 8, 9, 1, 7, 5, 1, 10]
    data2 = [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 5, 5, 5, 5]
      
    # columns
    columns1 = ['id', 'quarter']
      
    # creating df
    df1 = spark.createDataFrame(zip(data1, data2), columns1)
    
    # create view
    df1.createOrReplaceTempView('sample_data')
    
    # show data
    display(df1)
    

    [image: display() output showing the 19 rows of sample_data]

    The above image shows all 19 records. What is interesting to me is the fact that you have 5 quarters. If you are counting fiscal quarters, there are only 4. In fact, the old code keeps incrementing the quarter for as many entries as there are in the numbers array.

    The business logic states: pick n(0) ids from quarter 1, then pick n(1) ids from quarter 2 without re-picking ids already chosen, and so on. Thus, the final result will contain a number of ids equal to or less than the total number of distinct ids. Again, n[] is the input picks array.

    [image: query output listing the distinct ids 1 to 10]

    If we execute a simple query against the temporary view, we see that the ids 1 to 10 are listed.
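
    For example, a query along these lines reproduces that list (the exact statement used for the screenshot is not shown above, so treat this as an illustration):

    # list the distinct ids present in the temporary view
    spark.sql("select distinct id from sample_data order by id").show()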

    To make life easier, I packaged the business logic into a user defined function.

    #
    # 2 - write function for business logic
    #
    
    def pick_logic(num, qtr, ary):
      
      # which qtr to select?
      stmt = "select id from sample_data where quarter = '{}'".format(qtr)
      
      # which ids to not repick
      if (len(ary) > 0):
        stmt = stmt + " and id not in ({})".format(','.join([str(i) for i in ary]))
    
      # how many to limit
      stmt = stmt + " order by rand() limit {};".format(num)
       
      # debugging
      print(stmt)
      
      # create df
      df = spark.sql(stmt)
      
      # convert df -> lst + return
      return list(df.select('id').toPandas()['id'])
      
    

    The logic selects all data from a given quarter. The input array of previously picked ids is used to make sure we are choosing new numbers without replacement. I use the rand() function to order the data before limiting the result.
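
    To make the behavior concrete, here is a hypothetical single call (one pick from quarter 2, with ids 1 and 3 already taken) and the kind of statement the debugging print would emit:

    # hypothetical call: pick 1 id from quarter 2, excluding previously picked ids 1 and 3
    r = pick_logic(1, 2, [1, 3])
    # the debugging print would show something like:
    # select id from sample_data where quarter = '2' and id not in (1,3) order by rand() limit 1;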

    #
    # 3 - call functions with tuples
    #
    
    # picks per qtr
    picks_per_qtr = [(2,1), (1,2), (3,3), (1,4), (2,5), (3,1)]
    
    # list of ids
    i = []
    
    # for each tuple, call business logic
    for items in picks_per_qtr:
      
      # choose n values from quarter q
      n, q = items
      
      # get an array of new ids
      r = pick_logic(n, q, i)
      
      # show picking
      print(r)
      
      # append to id list
      i = i + r
    
    print("final result is {}".format(i))
    

    I took the liberty of converting the numbers array to an array of tuples (picks, quarter). I left the debugging code in place so that you can see the dynamic queries that are being executed and the results that are being returned.
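
    If you would rather derive those tuples from the original numbers list (one pick count per quarter, with quarters starting at 1), a minimal sketch is:

    # build (picks, quarter) tuples from the original numbers list
    numbers = [2, 1, 3, 1, 2]
    picks_per_qtr = list(zip(numbers, range(1, len(numbers) + 1)))
    # -> [(2, 1), (1, 2), (3, 3), (1, 4), (2, 5)]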

    Unlike my first pass at this Spark SQL without the rand() ordering, one will get different results each time the code is executed. However, the final array will have a number of elements equal to or less than the number of distinct ids.

    [image: printed picks and final result; the id 9 has not been picked]

    The above image shows that the number 9 has not been picked yet. Also, I was not told whether the quarter in which a number was picked needs to be recorded. If that is the case, modify the code to return a list of tuples (id, qtr), as in the sketch below.
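
    A minimal sketch of that variation, assuming only the return statement of pick_logic changes; note the caller would then need to strip the quarter back off before passing the ids into the exclusion list:

    # return (id, quarter) tuples instead of bare ids
    return [(i, qtr) for i in df.select('id').toPandas()['id']]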

    In short, the logic is sound and I hope it helps solve your business problem.