
Split rows into train and test sets by user ID in PySpark

I have a PySpark dataframe containing multiple rows for each user:

userId action time
1 buy 8 AM
1 buy 9 AM
1 sell 2 PM
1 sell 3 PM
2 sell 10 AM
2 buy 11 AM
2 sell 2 PM
2 sell 3 PM

My goal is to split this dataset into a training set and a test set such that, for each userId, N% of the rows are in the training set and the remaining (100-N)% are in the test set. For example, with N = 75%, the training set would be

userId action time
1 buy 8 AM
1 buy 9 AM
1 sell 2 PM
2 sell 10 AM
2 buy 11 AM
2 sell 2 PM

and the test set will be

userId action time
1 sell 3 PM
2 sell 3 PM

Any suggestions? The rows are ordered by the time column, and I don't think Spark's randomSplit can help, because it cannot stratify the split on specific columns.


  • We had a similar requirement and solved it as follows:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    data = [
      (1, "buy"),
      (1, "buy"),
      (1, "sell"),
      (1, "sell"),
      (2, "sell"),
      (2, "buy"),
      (2, "sell"),
      (2, "sell"),
    ]
    df = spark.createDataFrame(data, ["userId", "action"])

    Use a window function to assign sequential row numbers within each userId, and compute the number of records per userId. Together these determine which rows fall within the first N% of each user's records.

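    from pyspark.sql.window import Window
    from pyspark.sql.functions import col, row_number

    # Note: ordering by the partition key itself leaves the order of rows within each
    # userId arbitrary; if the data has a time column, use .orderBy("time") instead.
    window = Window.partitionBy("userId").orderBy("userId")
    # Number of records per userId, joined back onto every row
    df_count = df.groupBy("userId").count().withColumnRenamed("userId", "userId_grp")
    df = df.join(df_count, col("userId") == col("userId_grp"), "left").drop("userId_grp")
    # Sequential row number (1, 2, ...) within each userId
    df = df.select("userId", "action", "count", row_number().over(window).alias("row_number"))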
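    df.show()
    +------+------+-----+----------+
    |userId|action|count|row_number|
    +------+------+-----+----------+
    |     1|   buy|    4|         1|
    |     1|   buy|    4|         2|
    |     1|  sell|    4|         3|
    |     1|  sell|    4|         4|
    |     2|  sell|    4|         1|
    |     2|   buy|    4|         2|
    |     2|  sell|    4|         3|
    |     2|  sell|    4|         4|
    +------+------+-----+----------+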

    Select the training records using the required percentage n, keeping each row whose row number is at most n% of its user's count:

    n = 75
    df_train = df.filter(col("row_number") <= col("count") * n / 100)
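    df_train.show()
    +------+------+-----+----------+
    |userId|action|count|row_number|
    +------+------+-----+----------+
    |     1|   buy|    4|         1|
    |     1|   buy|    4|         2|
    |     1|  sell|    4|         3|
    |     2|  sell|    4|         1|
    |     2|   buy|    4|         2|
    |     2|  sell|    4|         3|
    +------+------+-----+----------+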
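    Because row_number takes the integer values 1 to count, the filter keeps floor(count * n / 100) rows per user. The small plain-Python sketch below illustrates that arithmetic (the `train_rows` helper is hypothetical, not part of PySpark) and shows that users with very few rows can end up with no training rows at all.

```python
import math


def train_rows(count: int, n: float) -> int:
    """Rows kept for one user by the filter row_number <= count * n / 100 (hypothetical helper)."""
    # row_number runs over the integers 1..count, so the rows that pass are those
    # numbered up to the largest integer not exceeding count * n / 100 (capped at count).
    return min(count, math.floor(count * n / 100))


for k in (1, 2, 3, 4, 10):
    print(k, train_rows(k, 75))
```

    For example, with n = 75, a user with 4 rows contributes 3 training rows, while a user with a single row goes entirely to the test set.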

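    The remaining records go to the test set. A left anti join keeps the rows of df that have no match in df_train:

    df_test = df.alias("df").join(
        df_train.alias("tr"),
        (col("df.userId") == col("tr.userId")) & (col("df.row_number") == col("tr.row_number")),
        "leftanti",
    )
    # Equivalent without a join: df.filter(col("row_number") > col("count") * n / 100)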
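    df_test.show()
    +------+------+-----+----------+
    |userId|action|count|row_number|
    +------+------+-----+----------+
    |     1|  sell|    4|         4|
    |     2|  sell|    4|         4|
    +------+------+-----+----------+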