
I want to use the following RDD

rdd = sc.parallelize([("K1", "e", 9), ("K1", "aaa", 9), ("K1", "ccc", 3), ("K1", "ddd", 9),
                      ("B1", "qwe", 4), ("B1", "rty", 7), ("B1", "iop", 8), ("B1", "zxc", 1)])

to get the output

[('K1', 'aaa', 9),
 ('K1', 'ddd', 9),
 ('K1', 'e', 9),
 ('B1', 'iop', 8),
 ('B1', 'rty', 7),
 ('B1', 'qwe', 4)]

I referred to "Get Top 3 values for every key in a RDD in Spark" and used the following code:

from heapq import nlargest
rdd.groupBy(
    lambda x: x[0]
).flatMap(
    lambda g: nlargest(3, g[1], key=lambda x: (x[2],x[1]))
).collect()

However, this only gives me

[('K1', 'e', 9),
 ('K1', 'ddd', 9),
 ('K1', 'aaa', 9),
 ('B1', 'iop', 8),
 ('B1', 'rty', 7),
 ('B1', 'qwe', 4)]

How can I do this?

Samson
  • Do you want the keys to be sorted? Sorting is a computationally expensive operation. Spark shuffles items for performance, so the order is arbitrary most of the time. – pissall Oct 15 '19 at 04:06
  • Yeah, I want them to be sorted. How can I convert the output of that code into the one I need? – Samson Oct 15 '19 at 04:12

2 Answers

1

It is actually a sorting problem, but sorting is a computationally expensive process because it requires shuffling. Still, you can try:

from heapq import nlargest

rdd2 = rdd.groupBy(
    lambda x: x[0]
).flatMap(
    lambda g: nlargest(3, g[1], key=lambda x: (x[2], x[1]))
)

# sortBy takes a single key function, so combine both fields into a tuple
rdd2.sortBy(lambda x: (x[1], x[2])).collect()
# [('K1', 'aaa', 9), ('K1', 'ddd', 9), ('K1', 'e', 9), ('B1', 'iop', 8), ('B1', 'qwe', 4), ('B1', 'rty', 7)]

I have sorted it using the second and third elements of each tuple (the name, then the value).

Also note that q comes before r alphabetically, so the expected output you listed is off and misleading.
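
If instead you want each key's rows ordered by value descending, with ties broken alphabetically as in your expected output, here is a minimal sketch (not from the original answer, and it assumes the third field is numeric so it can be negated):

# Sort by key ascending, then value descending (negated), then name ascending
rdd2.sortBy(lambda x: (x[0], -x[2], x[1])).collect()
# Keys come out in ascending order (B1 before K1); within each key the rows are
# ordered by value descending, then name ascending:
# [('B1', 'iop', 8), ('B1', 'rty', 7), ('B1', 'qwe', 4),
#  ('K1', 'aaa', 9), ('K1', 'ddd', 9), ('K1', 'e', 9)]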

pissall
  • Thank you, you are right. I have revised my question in line with your comment. Following your code, it works well. – Samson Oct 15 '19 at 16:43
0

If you are open to using a DataFrame, you can use a window function with rank.

Inspired from here

from pyspark.sql import functions as F
from pyspark.sql import SparkSession
from pyspark.sql import Window

spark = SparkSession.builder.appName('test').master("local[*]").getOrCreate()

df = spark.createDataFrame([
    ("K1", "e", 9),
    ("K1", "aaa", 9),
    ("K1", "ccc", 3),
    ("K1", "ddd", 9),
    ("B1", "qwe", 4),
    ("B1", "rty", 7),
    ("B1", "iop", 8),
    ("B1", "zxc", 1)], ['A', 'B', 'C']
    )

w = Window.partitionBy('A').orderBy(df.C.desc())
df.select('*', F.rank().over(w).alias('rank')).filter("rank<4").drop('rank').show()


+---+---+---+
|  A|  B|  C|
+---+---+---+
| B1|iop|  8|
| B1|rty|  7|
| B1|qwe|  4|
| K1|  e|  9|
| K1|aaa|  9|
| K1|ddd|  9|
+---+---+---+
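
Note that rank() keeps ties: all three K1 rows with value 9 share rank 1 here, and a fourth tied row would also pass the filter. A minimal sketch, assuming you want at most three rows per key regardless of ties (the tie-break on column B is my own addition, not part of the original answer), swaps in row_number():

# row_number() numbers rows 1, 2, 3, ... within each key, so ties never add extra rows
w2 = Window.partitionBy('A').orderBy(df.C.desc(), df.B.asc())
df.select('*', F.row_number().over(w2).alias('rn')).filter("rn <= 3").drop('rn').show()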
PIG