Tags: pyspark, spark-structured-streaming

Create a column to accumulate the data in an array (PySpark)


I need to create a column that accumulates the data in an array. I have the following dataframe:

+----------------+-------------------+----------+---------------------+
|localSymbol_drop|end_window_drop    |detect_DRM|last_detect_price_DRM|
+----------------+-------------------+----------+---------------------+
|BABA            |2021-06-15 16:36:30|NO        |NA                   |
|BABA            |2021-06-15 16:37:00|NO        |NA                   |
|BABA            |2021-06-15 16:37:30|YES       |211.85               |
|BABA            |2021-06-15 16:38:00|NO        |NA                   |
|BABA            |2021-06-15 16:38:30|NO        |NA                   |
|BABA            |2021-06-15 16:40:30|NO        |NA                   |
|BABA            |2021-06-15 16:41:00|YES       |211.91               |
|BABA            |2021-06-15 16:42:00|NO        |NA                   |
|BABA            |2021-06-15 16:42:30|YES       |211.83               |
+----------------+-------------------+----------+---------------------+

and the expected result is:

+----------------+-------------------+----------+---------------------+----------------------------------------+
|localSymbol_drop|end_window_drop    |detect_DRM|last_detect_price_DRM|accum_array                             |
+----------------+-------------------+----------+---------------------+----------------------------------------+
|BABA            |2021-06-15 16:36:30|NO        |NA                   |[NA]                                    |
|BABA            |2021-06-15 16:37:00|NO        |NA                   |[NA,NA]                                 |
|BABA            |2021-06-15 16:37:30|YES       |211.85               |[NA,NA,211.85]                          |
|BABA            |2021-06-15 16:38:00|NO        |NA                   |[NA,NA,211.85,NA]                       |
|BABA            |2021-06-15 16:38:30|NO        |NA                   |[NA,NA,211.85,NA,NA]                    |
|BABA            |2021-06-15 16:40:30|NO        |NA                   |[NA,NA,211.85,NA,NA,NA]                 |
|BABA            |2021-06-15 16:41:00|YES       |211.91               |[NA,NA,211.85,NA,NA,NA,211.91]          |
|BABA            |2021-06-15 16:42:00|NO        |NA                   |[NA,NA,211.85,NA,NA,NA,211.91,NA]       |
|BABA            |2021-06-15 16:42:30|YES       |211.83               |[NA,NA,211.85,NA,NA,NA,211.91,NA,211.83]|
+----------------+-------------------+----------+---------------------+----------------------------------------+

Any idea? Thank you!!


Solution

  • For this solution, you first need to create an index column on your dataframe:

    1)
    from pyspark.sql.functions import row_number
    from pyspark.sql.window import Window

    # Order by the timestamp so the index follows the order in which
    # the values should be accumulated.
    w = Window.orderBy("end_window_drop")
    df = df.withColumn("index", row_number().over(w))


    Once the dataframe has an index, you need to collect all values from the column you want to accumulate, in the same order as the index (so that list position i-1 corresponds to the row with index i):

    2)
    from pyspark.sql import functions as f

    my_list = (df.orderBy("end_window_drop")
                 .select(f.collect_list("last_detect_price_DRM"))
                 .first()[0])

    Now you just need to create a user-defined function (UDF) that takes the index as input and returns all elements of the list up to that index. Then add the result with withColumn('columnName', udf(...)) on your dataframe:

    3)
    from pyspark.sql.functions import col, udf
    from pyspark.sql.types import StringType, ArrayType

    def custom_func(index):
        # Return the prefix of the collected list up to this row's index.
        return my_list[0:index]

    custom_udf = udf(custom_func, ArrayType(StringType()))

    df = df.withColumn('acc', custom_udf(col('index')))
    

    That will accumulate all values in the given column. The three numbered tables below show each step on a small example dataframe:

    1.
    +------+------+
    |    _1|    _2|
    +------+------+
    |  Java| 20000|
    |Python|100000|
    | Scala|  3000|
    +------+------+
    
    2.
    +------+------+-----+
    |    _1|    _2|index|
    +------+------+-----+
    |Python|100000|    1|
    |  Java| 20000|    2|
    | Scala|  3000|    3|
    +------+------+-----+
    
    3.
    +------+------+-----+--------------------+
    |    _1|    _2|index|                 acc|
    +------+------+-----+--------------------+
    |Python|100000|    1|            [100000]|
    |  Java| 20000|    2|     [100000, 20000]|
    | Scala|  3000|    3|[100000, 20000, 3...|
    +------+------+-----+--------------------+