Data preparation:
val df = Seq(("Josh",94),("Josh",87),("Amanda",96),("Karen",78),("Amanda",90),("Josh",88)).toDF("Name","Grade")
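If you are running this outside spark-shell (which pre-imports these), add the implicits and SQL functions first; a minimal sketch, assuming your SparkSession is named spark:
import org.apache.spark.sql.functions._   // rand, round, collect_list, slice, ...
import spark.implicits._                  // enables .toDF and the $"col" syntax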
Perform the following only if your data is skewed for a Name: add a random number column, then keep only the first 3 rows per Name when ordered by that random value.
val df2 = df.withColumn("random", round(rand()*10))   // random salt per row
import org.apache.spark.sql.expressions.Window
val windowSpec = Window.partitionBy("Name").orderBy("random")
val df3 = df2.withColumn("row_number", row_number().over(windowSpec))
  .filter($"row_number" <= 3)   // keep at most 3 rows per Name
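If you want this step to be reproducible across runs, rand also accepts a seed (42 below is an arbitrary choice):
val df2 = df.withColumn("random", round(rand(42)*10))   // seeded variant of the line above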
Now aggregate the values for each Name and duplicate the list 3 times, so that every Name ends up with at least 3 records. Then take the first 3 values and explode them back into rows:
df3.groupBy("Name").agg(collect_list("Grade") as "grade_list")   // df3 from the skew step; use df directly if you skipped it
  .withColumn("temp_list", slice(flatten(array_repeat($"grade_list", 3)), 1, 3))
  .select($"Name", explode($"temp_list") as "Grade").show
Notes:
- Since grade_list will hold at most 3 values after the code above, duplicating it 3 times is cheap and won't hurt.
- In case you don't use the Window step, you can use a when(size($"grade_list") === n, ...).otherwise(...) combination to avoid the unnecessary duplication, as sketched below.
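A minimal sketch of that idea, skipping the Window step and working on df directly (df5 and the literal 3 standing in for n are illustrative names/values):
val df5 = df.groupBy("Name").agg(collect_list("Grade") as "grade_list")
  .withColumn("temp_list",
    when(size($"grade_list") >= 3, slice($"grade_list", 1, 3))            // already 3 or more values: just trim
      .otherwise(slice(flatten(array_repeat($"grade_list", 3)), 1, 3)))   // fewer than 3: pad by repetition
  .select($"Name", explode($"temp_list") as "Grade")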