Tags: scala, apache-spark-sql, user-defined-functions, window-functions, distinct-values

Identifying recurring values in a column over a Window (Scala)


I have a data frame with two columns: "ID" and "Amount", each row representing a transaction of a particular ID and the transacted amount. My example uses the following DF:

val df = sc.parallelize(Seq((1, 120),(1, 120),(2, 40),
  (2, 50),(1, 30),(2, 120))).toDF("ID","Amount")

I want to create a new column identifying whether said amount is a recurring value, i.e. whether it occurs in any other transaction for the same ID.

I have found a way to do this more generally, i.e. across the entire column "Amount", not taking into account the ID, using the following function:

import org.apache.spark.sql.DataFrame

def recurring_amounts(df: DataFrame, col: String): DataFrame = {
  // collect the column to the driver and count how often each value occurs
  // (the example data holds Int amounts, so they are counted as Ints)
  val df_to_arr = df.select(col).rdd.map(r => r(0).asInstanceOf[Int]).collect()
  val arr_to_map = df_to_arr.groupBy(identity).mapValues(_.size)
  // toDF needs spark.implicits._ in scope, as in the spark-shell
  val map_to_df = arr_to_map.toSeq.toDF(col, "Count")
  // join the per-value counts back onto the original frame
  val df_out = df.join(map_to_df, Seq(col))
  df_out
}

val df_output = recurring_amounts(df, "Amount")

This returns:

+---+------+-----+
|ID |Amount|Count|
+---+------+-----+
| 1 | 120  |  3  |
| 1 | 120  |  3  |
| 2 |  40  |  1  |
| 2 |  50  |  1  | 
| 1 |  30  |  1  |
| 2 | 120  |  3  |
+---+------+-----+

which I can then use to create my desired binary variable to indicate whether the amount is recurring or not (yes if > 1, no otherwise).
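
For illustration, that last step could look roughly like the sketch below; the column name "Recurring" is just an illustrative choice, and the $-syntax assumes spark.implicits._ is in scope, as in the spark-shell:

import org.apache.spark.sql.functions.when

// derive a yes/no flag from the per-value counts computed above
val df_flagged = df_output.withColumn("Recurring", when($"Count" > 1, "yes").otherwise("no"))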

However, my problem is illustrated in this example by the value 120, which is recurring for ID 1 but not for ID 2. My desired output therefore is:

+---+------+-----+
|ID |Amount|Count|
+---+------+-----+
| 1 | 120  |  2  |
| 1 | 120  |  2  |
| 2 |  40  |  1  |
| 2 |  50  |  1  | 
| 1 |  30  |  1  |
| 2 | 120  |  1  |
+---+------+-----+

I've been trying to think of a way to apply a function using .over(Window.partitionBy("ID")), but I'm not sure how to go about it. Any hints would be much appreciated.


Solution

  • If you are comfortable with SQL, you can write a SQL query against your DataFrame. The first thing you need to do is register your DataFrame as a table in Spark's memory. After that you can write SQL on top of that table. Note that spark is the Spark session variable.

    val df = sc.parallelize(Seq((1, 120),(1, 120),(2, 40),(2, 50),(1, 30),(2, 120))).toDF("ID","Amount")
    df.createOrReplaceTempView("transactions")   // registerTempTable is deprecated
    spark.sql("select *, count(*) over (partition by ID, Amount) as Count from transactions").show()
    

    (The original answer showed the output of show() here: each transaction with its per-(ID, Amount) Count, matching the desired output above.)

    Please let me know if you have any questions.
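
    For completeness, here is a sketch of the same count written with the DataFrame API and Window.partitionBy, which the question hints at. It is not part of the original answer, but it should yield the same Count column as the SQL query:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{count, lit}

    // mirror "partition by ID, Amount": one window per (ID, Amount) pair
    val w = Window.partitionBy("ID", "Amount")
    df.withColumn("Count", count(lit(1)).over(w)).show()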