Tags: python, azure, databricks, azure-databricks

Group by value within range in Azure Databricks


Consider following data:

EventDate,Value
1.1.2019,11
1.2.2019,5
1.3.2019,6
1.4.2019,-15
1.5.2019,-20
1.6.2019,-30
1.7.2019,12
1.8.2019,20

I want to create groups of when these values are within thresholds:

 1. > 10
 2. <= 10 and >= -10
 3. < -10

The result should contain the start date, the end date, and the values of each run of consecutive rows in the same state:

1.1.2019, 1.1.2019, [11]
1.2.2019, 1.3.2019, [5, 6]
1.4.2019, 1.6.2019, [-15, -20, -30]
1.7.2019, 1.8.2019, [12, 20]

I believe the answer lies in window functions, but I am fairly new to Databricks and can't work out how to use them (yet).

Here is a working Python solution that loops over the DataFrame as a list; however, I would prefer a solution that works directly on the DataFrame for performance.

STATETHRESHOLDCHARGE = 10

# Collect the sorted DataFrame to the driver as a list of dicts
rows = [{"eventDateTime": x["EventDate"], "value": x["Value"]}
        for x in dataframe.sort(dataframe.EventDate).rdd.collect()]

cycles = []
previous = None
for row in rows:
  # Classify each value: charge (>= 10), idle (-10 < v < 10), discharge (<= -10)
  currentState = 'charge'
  if row["value"] < STATETHRESHOLDCHARGE and row["value"] > (STATETHRESHOLDCHARGE * -1):
    currentState = 'idle'
  if row["value"] <= (STATETHRESHOLDCHARGE * -1):
    currentState = 'discharge'

  # Start a new cycle whenever the state changes, otherwise extend the current one
  if previous is None or previous["state"] != currentState:
    previous = {"start": row["eventDateTime"], "end": row["eventDateTime"],
                "values": [row["value"]], "timestamps": [row["eventDateTime"]],
                "state": currentState}
    cycles.append(previous)
  else:
    previous["end"] = row["eventDateTime"]
    previous["values"].append(row["value"])
    previous["timestamps"].append(row["eventDateTime"])

display(cycles)

Solution

  • Assuming you have the above data in a DataFrame called df, let's take this piece by piece:

    from pyspark.sql.functions import col, last, lag, udf, when, collect_list
    from pyspark.sql.types import StringType
    value = 'value'
    date = 'EventDate'
    valueBag = 'valueBag'
    
    def bagTransform(v):
      if v > 10:
        return 'charging'
      elif v < -10:
        return 'discharging'
      else:
        return 'idle'
    
    bagTransformUDF = udf(bagTransform, StringType())  
    
    withBaggedValue = df.withColumn(valueBag, bagTransformUDF(col(value)))
    
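    As a side note, the same bagging could also be expressed with the built-in when/otherwise expressions instead of a Python UDF, which keeps the logic in native Spark; a minimal, equivalent sketch:

    withBaggedValue = df.withColumn(valueBag,
        when(col(value) > 10, 'charging')
        .when(col(value) < -10, 'discharging')
        .otherwise('idle'))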

    So first we bagged the values into the ranges you declared; now we can use lag to pull the previous row's bag into each row:

    from pyspark.sql import Window
    windowSpec = Window.orderBy(date)
    prevValueBag = 'prevValueBag'
    bagBeginning = 'bagBeginning'
    
    withLag = (withBaggedValue
      .withColumn(prevValueBag, lag(withBaggedValue[valueBag]).over(windowSpec)))
    
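    With the sample data, each row of withLag should now carry its own bag plus the bag of the previous row (null on the first row), roughly:

    # EventDate   valueBag      prevValueBag
    # 1.1.2019    charging      null
    # 1.2.2019    idle          charging
    # 1.3.2019    idle          idle
    # 1.4.2019    discharging   idle
    # ...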

    Now the fun part starts: we detect change points and temporarily assign them the current event date or null:

    withInitialBeginnings = withLag.withColumn(bagBeginning,
        when((col(prevValueBag) != col(valueBag)) | col(prevValueBag).isNull(), col(date))
        .otherwise(None))
    

    and fill them in with the last non-null value found:

    withFilledBeginnings = (withInitialBeginnings.withColumn(bagBeginning, 
                     last(col(bagBeginning), ignorenulls=True)
                     .over(windowSpec)))
    display(withFilledBeginnings)
    
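    With the sample data, every row should now carry the start date of the run it belongs to, roughly:

    # EventDate   value   valueBag      bagBeginning
    # 1.1.2019     11     charging      1.1.2019
    # 1.2.2019      5     idle          1.2.2019
    # 1.3.2019      6     idle          1.2.2019
    # 1.4.2019    -15     discharging   1.4.2019
    # 1.5.2019    -20     discharging   1.4.2019
    # 1.6.2019    -30     discharging   1.4.2019
    # 1.7.2019     12     charging      1.7.2019
    # 1.8.2019     20     charging      1.7.2019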

    With that set, we can simply aggregate over the starting point:

    aggregate = withFilledBeginnings.groupby(col(bagBeginning)).agg(collect_list(value))
    
    display(aggregate)
    
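    For the sample data this collapses to one row per run, matching the groups asked for in the question (note that groupBy does not guarantee row order):

    # bagBeginning   collect_list(value)
    # 1.1.2019       [11]
    # 1.2.2019       [5, 6]
    # 1.4.2019       [-15, -20, -30]
    # 1.7.2019       [12, 20]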


    If you also need the end date, you can do similar preprocessing using pyspark.sql.functions.lead, which works symmetrically to lag but in the forward direction.
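    For example, a minimal sketch of that idea (the names nextValueBag, bagEnding and withFilledEndings are introduced here purely for illustration): mark the last row of each bag by peeking at the next row with lead, then propagate that marker backwards with first over a forward-looking window:

    from pyspark.sql.functions import lead, first

    nextValueBag = 'nextValueBag'
    bagEnding = 'bagEnding'

    withFilledEndings = (withFilledBeginnings
      # bag of the following row (null on the last row)
      .withColumn(nextValueBag, lead(col(valueBag)).over(windowSpec))
      # a row ends a bag when the next row belongs to a different bag (or there is no next row)
      .withColumn(bagEnding,
                  when((col(nextValueBag) != col(valueBag)) | col(nextValueBag).isNull(), col(date)))
      # back-fill that end date over the rest of the bag
      .withColumn(bagEnding,
                  first(col(bagEnding), ignorenulls=True)
                  .over(windowSpec.rowsBetween(Window.currentRow, Window.unboundedFollowing))))

    Alternatively, since bagBeginning already identifies each run, adding min(EventDate) and max(EventDate) to the final aggregation would also give you the start and end dates, provided the date column sorts correctly.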