I have a PySpark dataframe containing multiple rows for each user:
userId | action | time |
---|---|---|
1 | buy | 8 AM |
1 | buy | 9 AM |
1 | sell | 2 PM |
1 | sell | 3 PM |
2 | sell | 10 AM |
2 | buy | 11 AM |
2 | sell | 2 PM |
2 | sell | 3 PM |
My goal is to split this dataset into a training and a test set such that, for each userId, N% of the rows are in the training set and the remaining (100-N)% are in the test set. For example, given N=75, the training set will be
userId | action | time |
---|---|---|
1 | buy | 8 AM |
1 | buy | 9 AM |
1 | sell | 2 PM |
2 | sell | 10 AM |
2 | buy | 11 AM |
2 | sell | 2 PM |
and the test set will be
userId | action | time |
---|---|---|
1 | sell | 3 PM |
2 | sell | 3 PM |
Any suggestions? The rows are ordered by the time column, and I don't think Spark's randomSplit
will help, since I can't stratify the split on specific columns.
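For reference, randomSplit only draws a global random sample, with no per-user guarantees and no respect for time order:
df_train, df_test = df.randomSplit([0.75, 0.25], seed=42)  # roughly 75/25 overall, but per-user proportions are not preserved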
We had a similar requirement and solved it in the following way:
data = [
    (1, "buy", 8),
    (1, "buy", 9),
    (1, "sell", 14),
    (1, "sell", 15),
    (2, "sell", 10),
    (2, "buy", 11),
    (2, "sell", 14),
    (2, "sell", 15),
]
# time is encoded as hour-of-day (24h) so it sorts chronologically
df = spark.createDataFrame(data, ["userId", "action", "time"])
Use Window functionality to assign sequential row numbers within each userId, ordered by time. Also compute the count of records for each userId; this is needed to work out how many rows fall within the requested percentage.
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Number each user's rows in chronological order
window = Window.partitionBy("userId").orderBy("time")

# Per-user record counts, joined back onto the original rows
df_count = df.groupBy("userId").count().withColumnRenamed("userId", "userId_grp")
df = df.join(df_count, col("userId") == col("userId_grp"), "left").drop("userId_grp")
df = df.select("userId", "action", "count", row_number().over(window).alias("row_number"))
df.show()
+------+------+-----+----------+
|userId|action|count|row_number|
+------+------+-----+----------+
| 1| buy| 4| 1|
| 1| buy| 4| 2|
| 1| sell| 4| 3|
| 1| sell| 4| 4|
| 2| sell| 4| 1|
| 2| buy| 4| 2|
| 2| sell| 4| 3|
| 2| sell| 4| 4|
+------+------+-----+----------+
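As a side note, the per-user count can also be computed with a window aggregate instead of the groupBy-plus-join; a minimal sketch, applied to the original df before the join:
from pyspark.sql.functions import count

# An unordered window spans the whole partition, so count("*") yields the per-user total
df = df.withColumn("count", count("*").over(Window.partitionBy("userId")))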
Filter the training records by the required percentage. With count = 4 and n = 75 the cutoff is 4 * 75 / 100 = 3, so rows 1 through 3 of each user go to training:
n = 75  # percentage of each user's rows to keep for training
df_train = df.filter(col("row_number") <= col("count") * n / 100)
df_train.show()
+------+------+-----+----------+
|userId|action|count|row_number|
+------+------+-----+----------+
| 1| buy| 4| 1|
| 1| buy| 4| 2|
| 1| sell| 4| 3|
| 2| sell| 4| 1|
| 2| buy| 4| 2|
| 2| sell| 4| 3|
+------+------+-----+----------+
And the remaining records go to the test set:
df_test = df.alias("df").join(
    df_train.alias("tr"),
    (col("df.userId") == col("tr.userId")) & (col("df.row_number") == col("tr.row_number")),
    "leftanti",
)
df_test.show()
+------+------+-----+----------+
|userId|action|count|row_number|
+------+------+-----+----------+
| 1| sell| 4| 4|
| 2| sell| 4| 4|
+------+------+-----+----------+
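Alternatively, since row_number and count are already attached to every row, the test set is just the complement of the training filter, so the leftanti join can be skipped:
# Rows past the N% cutoff form the test set
df_test = df.filter(col("row_number") > col("count") * n / 100)
Note that col("count") * n / 100 can be fractional (e.g. 5 * 75 / 100 = 3.75); the <= / > pair still cleanly partitions the rows, effectively flooring the training size.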