We had a similar requirement and solved it in the following way:
data = [
(1, "buy"),
(1, "buy"),
(1, "sell"),
(1, "sell"),
(2, "sell"),
(2, "buy"),
(2, "sell"),
(2, "sell"),
]
df = spark.createDataFrame(data, ["userId", "action"])
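The snippet above assumes spark is already available (e.g. in a notebook or pyspark shell). If you run it as a standalone script, create the session first; a minimal sketch (the appName is arbitrary):
from pyspark.sql import SparkSession

# Local session for a standalone run; adjust builder options for your cluster
spark = SparkSession.builder.appName("train-test-split").getOrCreate()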
Use the Window functionality to create sequential row numbers per user, and also compute the count of records for each userId. The count is needed later to work out the percentage cut-off when filtering.
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Ordering by the partition key gives an arbitrary order within each user;
# order by a timestamp column instead if the split must be reproducible.
window = Window.partitionBy(df["userId"]).orderBy(df["userId"])

# Per-user record count, renamed to avoid a column clash and joined back on
df_count = df.groupBy("userId").count().withColumnRenamed("userId", "userId_grp")
df = df.join(df_count, col("userId") == col("userId_grp"), "left").drop("userId_grp")
df = df.select("userId", "action", "count", row_number().over(window).alias("row_number"))
df.show()
+------+------+-----+----------+
|userId|action|count|row_number|
+------+------+-----+----------+
| 1| buy| 4| 1|
| 1| buy| 4| 2|
| 1| sell| 4| 3|
| 1| sell| 4| 4|
| 2| sell| 4| 1|
| 2| buy| 4| 2|
| 2| sell| 4| 3|
| 2| sell| 4| 4|
+------+------+-----+----------+
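As an aside, the per-user count can also be computed with a window aggregate instead of the groupBy plus join, which keeps everything in one pass. A sketch of that variant (df_alt is just an illustrative name, starting again from the userId/action columns):
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Unordered window per user: the aggregate runs over the whole partition
count_window = Window.partitionBy("userId")
# Ordered window per user for the sequential row number
row_window = Window.partitionBy("userId").orderBy("userId")

df_alt = (
    df.select("userId", "action")
      .withColumn("count", F.count("userId").over(count_window))
      .withColumn("row_number", F.row_number().over(row_window))
)
df_alt.show()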
Filter the training records by the required percentage:
n = 75
df_train = df.filter(col("row_number") <= col("count") * n / 100)
df_train.show()
+------+------+-----+----------+
|userId|action|count|row_number|
+------+------+-----+----------+
| 1| buy| 4| 1|
| 1| buy| 4| 2|
| 1| sell| 4| 3|
| 2| sell| 4| 1|
| 2| buy| 4| 2|
| 2| sell| 4| 3|
+------+------+-----+----------+
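One detail to be aware of: when count * n / 100 is not a whole number, this comparison effectively floors the cut-off (a user with 5 rows and n = 75 keeps rows 1 to 3, since 4 > 3.75). If rounding up is preferred, a small variant (df_train_ceil is just an illustrative name):
from pyspark.sql.functions import ceil, col

n = 75
# Keep ceil(count * n / 100) rows per user instead of flooring the threshold
df_train_ceil = df.filter(col("row_number") <= ceil(col("count") * n / 100))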
And the remaining records go to the test set:
df_test = df.alias("df").join(
    df_train.alias("tr"),
    (col("df.userId") == col("tr.userId")) & (col("df.row_number") == col("tr.row_number")),
    "leftanti",
)
df_test.show()
+------+------+-----+----------+
|userId|action|count|row_number|
+------+------+-----+----------+
| 1| sell| 4| 4|
| 2| sell| 4| 4|
+------+------+-----+----------+
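Since count and row_number are already columns of df, the anti-join can also be avoided by simply negating the training filter; a simpler equivalent sketch:
# Everything beyond the per-user cut-off goes to the test set
df_test_alt = df.filter(col("row_number") > col("count") * n / 100)
df_test_alt.show()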