Tags: apache-spark, pyspark, apache-spark-sql, spark-structured-streaming

Sum over an array of dictionaries depending on a value condition in PySpark (Spark Structured Streaming)


I have the following schema:

        from pyspark.sql.types import StructType, StructField, StringType, ArrayType

        tick_by_tick_schema = StructType([
            StructField('localSymbol', StringType()),
            StructField('time', StringType()),
            StructField('open', StringType()),
            StructField('previous_price', StringType()),
            StructField('tickByTicks', ArrayType(StructType([
                StructField('price', StringType()),
                StructField('size', StringType()),
                StructField('specialConditions', StringType()),
            ])))
        ])
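
For reference, a small static DataFrame with this schema can be used to prototype the aggregation before wiring it into the streaming query (the rows below are illustrative only, reusing tick_by_tick_schema from above):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Illustrative rows matching tick_by_tick_schema (all fields are strings);
    # each struct tuple is (price, size, specialConditions).
    sample_rows = [
        ('BABA', '2021-06-10 19:25:38.155229+00:00', None, '213.76',
         [('213.75', '100', ''), ('213.78', '100', ''), ('213.78', '200', '')]),
        ('BABA', '2021-06-10 19:25:39.662655+00:00', None, '213.72',
         [('213.72', '100', ''), ('213.73', '100', '')]),
    ]
    static_df = spark.createDataFrame(sample_rows, schema=tick_by_tick_schema)
    static_df.show(truncate=False)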

and I have the following DataFrame (in Spark Structured Streaming):

+-----------+--------------------------------+--------------+----------------------------------------------------+
|localSymbol|time                            |previous_price|tickByTicks                                         |
+-----------+--------------------------------+--------------+----------------------------------------------------+
|BABA       |2021-06-10 19:25:38.154245+00:00|213.76        |[{213.75, 100, }]                                   |
|BABA       |2021-06-10 19:25:38.155229+00:00|213.76        |[{213.75, 100, }, {213.78, 100, }, {213.78, 200, }] |
|BABA       |2021-06-10 19:25:39.662033+00:00|213.73        |[{213.72, 100, }]                                   |
|BABA       |2021-06-10 19:25:39.662655+00:00|213.72        |[{213.72, 100, }, {213.73, 100, }]                  |                                                                                
+-----------+--------------------------------+--------------+----------------------------------------------------+

I would like to create two columns based on the following logic:

Column_low: WHEN tickByTicks.price < previous_price THEN sum(tickByTicks.size)
Column_high: WHEN tickByTicks.price > previous_price THEN sum(tickByTicks.size)

The expected result would be:

+-----------+--------------------------------+--------------+----------------------------------------------------+----------+-----------+
|localSymbol|time                            |previous_price|tickByTicks                                         |Column_low|Column_high|
+-----------+--------------------------------+--------------+----------------------------------------------------+----------+-----------+
|BABA       |2021-06-10 19:25:38.154245+00:00|213.76        |[{213.75, 100, }]                                   |100       |0          |
|BABA       |2021-06-10 19:25:38.155229+00:00|213.76        |[{213.75, 100, }, {213.78, 100, }, {213.78, 200, }] |100       |300        |
|BABA       |2021-06-10 19:25:39.662033+00:00|213.73        |[{213.72, 100, }]                                   |100       |0          |
|BABA       |2021-06-10 19:25:39.662655+00:00|213.72        |[{213.72, 100, }, {213.73, 100, }]                  |0         |100        |                                                                  
+-----------+--------------------------------+--------------+----------------------------------------------------+----------+-----------+

I have tried something along these lines, but it does not produce the expected result:

        tick_by_tick_data_processed = kafka_df_structured_with_tick_by_tick_data_values.select(
            f.col('localSymbol'),
            f.col('time'),
            f.col('previous_price'),
            f.col('tickByTicks'),
            f.expr("aggregate(filter(tickByTicks.size, x -> x > previous_price), 0D, (x, acc) -> acc + x)")
        ).show(30,False)
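
One thing I noticed about this attempt is that filter is applied to tickByTicks.size (an array of sizes), so the lambda compares a size against previous_price rather than a price. A variant that filters the array of structs by price and then sums the sizes would look roughly like this (untested sketch; CAST is used because all fields in the schema are strings):

    from pyspark.sql import functions as f

    # Filter the ticks whose price is below previous_price, then fold over the
    # remaining structs and sum their sizes.
    tick_by_tick_data_processed = kafka_df_structured_with_tick_by_tick_data_values.select(
        f.col('localSymbol'),
        f.col('time'),
        f.col('previous_price'),
        f.col('tickByTicks'),
        f.expr("""
            aggregate(
                filter(tickByTicks, t -> CAST(t.price AS DOUBLE) < CAST(previous_price AS DOUBLE)),
                0D,
                (acc, t) -> acc + CAST(t.size AS DOUBLE)
            )
        """).alias('Column_low')
    )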

Solution

  • I can't test my solution, but I think this may work:

    # aggregate() folds over the array: the accumulator comes first, then the
    # element. Only ticks whose price is below (or above) previous_price add
    # their size to the running total; otherwise the accumulator is passed
    # through unchanged.
    tick_by_tick_data_processed = kafka_df_structured_with_tick_by_tick_data_values.select(
        f.col('localSymbol'),
        f.col('time'),
        f.col('previous_price'),
        f.col('tickByTicks'),
        f.expr("aggregate(tickByTicks, 0D, (acc, tick) -> IF(tick.price < previous_price, acc + tick.size, acc))").alias("Column_low"),
        f.expr("aggregate(tickByTicks, 0D, (acc, tick) -> IF(tick.price > previous_price, acc + tick.size, acc))").alias("Column_high"))