I want to use the following RDD
rdd = sc.parallelize([("K1", "e", 9), ("K1", "aaa", 9), ("K1", "ccc", 3), ("K1", "ddd", 9),
("B1", "qwe", 4), ("B1", "rty", 7), ("B1", "iop", 8), ("B1", "zxc", 1)])
to get the output
[('K1', 'aaa', 9),
('K1', 'ddd', 9),
('K1', 'e', 9),
('B1', 'iop', 8),
('B1', 'rty', 7),
('B1', 'qwe', 4)]
I referred to "Get Top 3 values for every key in a RDD in Spark" and used the following code:
from heapq import nlargest

rdd.groupBy(
    lambda x: x[0]
).flatMap(
    lambda g: nlargest(3, g[1], key=lambda x: (x[2], x[1]))
).collect()
However, this only gives me
[('K1', 'e', 9),
('K1', 'ddd', 9),
('K1', 'aaa', 9),
('B1', 'iop', 8),
('B1', 'rty', 7),
('B1', 'qwe', 4)]
The values per key are correct, but the ordering of the ties at 9 is reversed: I want rows with equal values sorted ascending by the middle string, while nlargest with key (x[2], x[1]) orders both fields descending. How can I get the desired ordering?
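To make the expected behavior concrete, here is a plain-Python sketch (no Spark) of the per-group logic I am after: sort each group descending by value but ascending by name on ties, then keep the first three. The helper name `top3` is mine, just for illustration.

```python
from itertools import groupby

data = [("K1", "e", 9), ("K1", "aaa", 9), ("K1", "ccc", 3), ("K1", "ddd", 9),
        ("B1", "qwe", 4), ("B1", "rty", 7), ("B1", "iop", 8), ("B1", "zxc", 1)]

def top3(items):
    # Sort descending by value (negate x[2]); ties fall back to
    # ascending order on the middle string; keep the first three.
    return sorted(items, key=lambda x: (-x[2], x[1]))[:3]

# groupby relies on rows for the same key being contiguous,
# which holds for this sample data.
result = [row
          for _, grp in groupby(data, key=lambda x: x[0])
          for row in top3(list(grp))]
print(result)
```

This prints the ordering I want, so presumably the same key function could be used per group in Spark (e.g. inside the flatMap instead of nlargest), though I have not verified that end to end.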