Tags: pyspark, apache-spark-sql, sparkr

Identifying values that revert in Spark


I have a Spark DataFrame of customers as shown below.

#SparkR code
customers <- data.frame(custID = c("001", "001", "001", "002", "002", "002", "002"),
  date = c("2017-02-01", "2017-03-01", "2017-04-01", "2017-01-01", "2017-02-01", "2017-03-01", "2017-04-01"),
  value = c('new', 'good', 'good', 'new', 'good', 'new', 'bad'))
customers <- createDataFrame(customers)
display(customers)

custID|  date     | value
--------------------------
001   | 2017-02-01| new
001   | 2017-03-01| good
001   | 2017-04-01| good
002   | 2017-01-01| new
002   | 2017-02-01| good
002   | 2017-03-01| new
002   | 2017-04-01| bad

In a customer's first observed month, the custID gets a value of 'new'. Thereafter they are classified as 'good' or 'bad'. However, a customer can revert from 'good' or 'bad' back to 'new' if they open a second account. When this happens, I want to tag the customer with '2' instead of '1' to indicate that they opened a second account, as shown below. How can I do this in Spark? Either SparkR or PySpark commands work.

#What I want to get 
custID|  date     | value | tag
--------------------------------
001   | 2017-02-01| new   | 1
001   | 2017-03-01| good  | 1
001   | 2017-04-01| good  | 1
002   | 2017-01-01| new   | 1
002   | 2017-02-01| good  | 1
002   | 2017-03-01| new   | 2
002   | 2017-04-01| bad   | 2

Solution

  • In pyspark:

    from pyspark.sql import SparkSession, functions as f
    from pyspark.sql.window import Window
    from pyspark.sql.types import IntegerType
    
    spark = SparkSession.builder.getOrCreate()
    
    # df is equal to your customers dataframe
    df = spark.read.load('file:///home/zht/PycharmProjects/test/text_file.txt', format='csv', header=True, sep='|').cache()
    
    w = Window.partitionBy('custID').orderBy('date')
    
    # Rank the 'new' rows per customer (1st 'new' -> 1, 2nd 'new' -> 2, ...),
    # then add the remaining rows back with a null tag (cast so the union schemas match)
    df_new = df.filter(df['value'] == 'new').withColumn('tag', f.rank().over(w))
    df = df_new.union(df.filter(df['value'] != 'new').withColumn('tag', f.lit(None).cast(IntegerType())))
    
    # collect_list skips nulls, so the last element of the running list
    # is the rank of the most recent 'new' row for that customer
    df = df.withColumn('tag', f.collect_list('tag').over(w)) \
        .withColumn('tag', f.udf(lambda x: x[-1], IntegerType())('tag'))
    
    df.show()
    

    And output:

    +------+----------+-----+---+                                                   
    |custID|      date|value|tag|
    +------+----------+-----+---+
    |   001|2017-02-01|  new|  1|
    |   001|2017-03-01| good|  1|
    |   001|2017-04-01| good|  1|
    |   002|2017-01-01|  new|  1|
    |   002|2017-02-01| good|  1|
    |   002|2017-03-01|  new|  2|
    |   002|2017-04-01|  bad|  2|
    +------+----------+-----+---+
    
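    The rank/collect_list round trip can also be avoided: since every 'new' row starts an account, a running count of 'new' indicators per customer yields the same tag in one window expression. A sketch, reusing the sample data from the question:

    ```python
    from pyspark.sql import SparkSession, functions as f
    from pyspark.sql.window import Window

    spark = SparkSession.builder.getOrCreate()

    df = spark.createDataFrame(
        [("001", "2017-02-01", "new"),
         ("001", "2017-03-01", "good"),
         ("001", "2017-04-01", "good"),
         ("002", "2017-01-01", "new"),
         ("002", "2017-02-01", "good"),
         ("002", "2017-03-01", "new"),
         ("002", "2017-04-01", "bad")],
        ["custID", "date", "value"])

    # Cumulative count of 'new' rows per customer, ordered by date
    w = Window.partitionBy("custID").orderBy("date")
    df = df.withColumn("tag", f.sum((f.col("value") == "new").cast("int")).over(w))
    df.show()
    ```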

    By the way, pandas can do this easily.
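    To back up that claim, a pandas sketch of the same idea (cumulative count of 'new' per customer):

    ```python
    import pandas as pd

    customers = pd.DataFrame({
        "custID": ["001", "001", "001", "002", "002", "002", "002"],
        "date": ["2017-02-01", "2017-03-01", "2017-04-01",
                 "2017-01-01", "2017-02-01", "2017-03-01", "2017-04-01"],
        "value": ["new", "good", "good", "new", "good", "new", "bad"]})

    # Running count of 'new' rows within each customer (rows are already date-sorted)
    customers["tag"] = (customers["value"] == "new").astype(int) \
        .groupby(customers["custID"]).cumsum()
    ```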