
Pyspark use partition or groupby with agg and datediff


I'm new to PySpark. I would like to find the products that were not seen after 10 days from the first day they entered the store, and create a column in the DataFrame that is set to 1 for these products and 0 for the rest.

First I need to group the data by product_id, then find the maximum of seen_date. Then I calculate the difference between import_date and max(seen_date) within each group, and finally create a new column based on the value of date_diff in each group.
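In groupBy/agg terms, what I have in mind is roughly the sketch below (the names max_seen, max_seen_date and df_with_max are just placeholders), but I was not sure this is the right way to get the per-group maximum back onto every row, which is why I tried a window instead:

from pyspark.sql import functions as F

# rough idea: aggregate per product, then join the group maximum back onto each row
max_seen = df.groupBy("product_id").agg(F.max("seen_date").alias("max_seen_date"))
df_with_max = df.join(max_seen, on="product_id", how="left")
# date_diff and the 0/1 flag would then be computed from max_seen_date and import_date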

The following is the code I used to first get the difference between the import_date and seen_date, but it gives an error:

from pyspark.sql.window import Window
from pyspark.sql import functions as F

w = (Window()
    .partitionBy(df.product_id)
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))

df.withColumn("date_diff", F.datediff(F.max(F.from_unixtime(F.col("import_date")).over(w)), F.from_unixtime(F.col("seen_date"))))

Error:

AnalysisException: It is not allowed to use a window function inside an aggregate function. Please use the inner window function in a sub-query.

This is the rest of my code to define a new column based on the date_diff:

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

not_seen = udf(lambda x: 0 if x > 10 else 1, IntegerType())
df = df.withColumn('not_seen', not_seen("date_diff"))

Q: Can someone provide a fix for this code or a better approach to solve this problem?

Sample data generation:

columns = ["product_id","import_date", "seen_date"]
data = [("123", "2014-05-06", "2014-05-07"),
        ("123", "2014-05-06", "2014-06-11"),
        ("125", "2015-01-02", "2015-01-03"),
        ("125", "2015-01-02", "2015-01-04"),
        ("128", "2015-08-06", "2015-08-25")]
dfFromData2 = spark.createDataFrame(data).toDF(*columns)
dfFromData2 = dfFromData2.withColumn("import_date",F.unix_timestamp(F.col("import_date"),'yyyy-MM-dd'))
dfFromData2 = dfFromData2.withColumn("seen_date",F.unix_timestamp(F.col("seen_date"),'yyyy-MM-dd'))

+----------+-----------+----------+
|product_id|import_date| seen_date|
+----------+-----------+----------+
|       123| 1399334400|1399420800|
|       123| 1399334400|1402444800|
|       125| 1420156800|1420243200|
|       125| 1420156800|1420329600|
|       128| 1438819200|1440460800|
+----------+-----------+----------+

Solution

from pyspark.sql.window import Window
from pyspark.sql import functions as F

columns = ["product_id", "import_date", "seen_date"]
data = [("123", "2014-05-06", "2014-05-07"),
        ("123", "2014-05-06", "2014-06-11"),
        ("125", "2015-01-02", "2015-01-03"),
        ("125", "2015-01-02", "2015-01-04"),
        ("128", "2015-08-06", "2015-08-25")]

df = spark.createDataFrame(data).toDF(*columns)
df = df.withColumn("import_date", F.to_date(F.col("import_date"), 'yyyy-MM-dd'))
df = df.withColumn("seen_date", F.to_date(F.col("seen_date"), 'yyyy-MM-dd'))

w = (Window()
    .partitionBy(df.product_id)
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))

df\
    .withColumn('max_import_date', F.max(F.col("import_date")).over(w))\
    .withColumn("date_diff", F.datediff(F.col('seen_date'), F.col('max_import_date')))\
    .withColumn('not_seen', F.when(F.col('date_diff') > 10, 0).otherwise(1))\
    .show()

+----------+-----------+----------+---------------+---------+--------+
|product_id|import_date| seen_date|max_import_date|date_diff|not_seen|
+----------+-----------+----------+---------------+---------+--------+
|       123| 2014-05-06|2014-05-07|     2014-05-06|        1|       1|
|       123| 2014-05-06|2014-06-11|     2014-05-06|       36|       0|
|       125| 2015-01-02|2015-01-03|     2015-01-02|        1|       1|
|       125| 2015-01-02|2015-01-04|     2015-01-02|        2|       1|
|       128| 2015-08-06|2015-08-25|     2015-08-06|       19|       0|
+----------+-----------+----------+---------------+---------+--------+
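For reference, the AnalysisException in the question comes from the placement of .over(w): it is attached to from_unixtime(...), so the window function ends up nested inside the max() aggregate. As a minimal sketch of a structural fix that keeps the unix-timestamp columns of dfFromData2 from the question (the variable name fixed below is just a placeholder), the aggregate itself can be made the window function:

from pyspark.sql.window import Window
from pyspark.sql import functions as F

# an overall per-product max needs no ordering or explicit frame
w = Window.partitionBy("product_id")

fixed = dfFromData2.withColumn(
    "date_diff",
    F.datediff(
        F.to_date(F.from_unixtime(F.col("seen_date"))),                   # this row's seen date
        F.max(F.to_date(F.from_unixtime(F.col("import_date")))).over(w),  # .over() wraps the aggregate, not its argument
    ),
)

The not_seen flag can then be added with the same F.when expression as above, or with the UDF from the question.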