
My requirement is to get the top N items from a dataframe.

I have this DataFrame:

import spark.implicits._  // provides .toDF on local collections

val df = List(
      ("MA", "USA"),
      ("MA", "USA"),
      ("OH", "USA"),
      ("OH", "USA"),
      ("OH", "USA"),
      ("OH", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("NY", "USA"),
      ("CT", "USA"),
      ("CT", "USA"),
      ("CT", "USA"),
      ("CT", "USA"),
      ("CT", "USA")).toDF("value", "country")

I was able to map it to an RDD[((Int, String), Long)] called colValCount, where each element reads as ((colIdx, value), count):

((0,CT),5)
((0,MA),2)
((0,OH),4)
((0,NY),6)
((1,USA),17)
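The question doesn't show how colValCount was built. One way to get that shape is sketched below; it is an assumption, not necessarily what the asker did. It emits one ((columnIndex, cellValue), 1L) pair per cell and sums them with reduceByKey. The local SparkSession and the rebuilt sample data are only there to make the block runnable on its own.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

// Local session for illustration; in spark-shell `spark` already exists.
val spark = SparkSession.builder.master("local[*]").appName("colValCount").getOrCreate()
import spark.implicits._

// Same data as in the question: MA x2, OH x4, NY x6, CT x5, all with country USA.
val df = (Seq.fill(2)(("MA", "USA")) ++ Seq.fill(4)(("OH", "USA")) ++
  Seq.fill(6)(("NY", "USA")) ++ Seq.fill(5)(("CT", "USA"))).toDF("value", "country")

// Emit ((columnIndex, cellValue), 1L) for every cell, then sum the ones per key.
// Nulls are turned into the string "null" so they are counted like any other value.
val colValCount: RDD[((Int, String), Long)] = df.rdd
  .flatMap(row => (0 until row.length).map(i => ((i, Option(row.get(i)).fold("null")(_.toString)), 1L)))
  .reduceByKey(_ + _)

colValCount.collect().sortBy(_._1).foreach(println)
```

This reads the data once for all columns, so the cost does not grow with one job per column.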

Now I need to get the top 2 items for each column index. So my expected output is this:

RDD[((Int, String), Long)]

((0,CT),5)
((0,NY),6)
((1,USA),17)

I've tried using the freqItems API on the DataFrame, but it's slow.

Any suggestions are welcome.

Sam
  • I think you need some combination of `sort()` and `limit()`, but TBH I don't understand how you get your output. – pault Feb 13 '18 at 20:30

4 Answers


For example:

import org.apache.spark.sql.functions._
import spark.implicits._  // for the $"..." column syntax

df.select(lit(0).alias("index"), $"value")
  .union(df.select(lit(1), $"country"))
  .groupBy($"index", $"value")
  .count
  .orderBy($"count".desc)
  .limit(3)
  .show

// +-----+-----+-----+
// |index|value|count|
// +-----+-----+-----+
// |    1|  USA|   17|
// |    0|   NY|    6|
// |    0|   CT|    5|
// +-----+-----+-----+

where:

df.select(lit(0).alias("index"), $"value")
  .union(df.select(lit(1), $"country"))

creates a two-column DataFrame:

// +-----+-----+
// |index|value|
// +-----+-----+
// |    0|   MA|
// |    0|   MA|
// |    0|   OH|
// |    0|   OH|
// |    0|   OH|
// |    0|   OH|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   NY|
// |    0|   CT|
// |    0|   CT|
// |    0|   CT|
// |    0|   CT|
// |    0|   CT|
// |    1|  USA|
// |    1|  USA|
// |    1|  USA|
// +-----+-----+
// only showing top 20 rows

If you specifically want the top two values for each column:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, lit}

// Count the values of one column and keep the n most frequent ones.
def topN(df: DataFrame, key: String, n: Int): DataFrame =
  df.select(
      lit(df.columns.indexOf(key)).alias("index"),
      col(key).alias("value"))
    .groupBy("index", "value")
    .count
    .orderBy(col("count").desc)  // descending, so the most frequent values come first
    .limit(n)

topN(df, "value", 2).union(topN(df, "country", 2)).show
// +-----+-----+-----+
// |index|value|count|
// +-----+-----+-----+
// |    0|   NY|    6|
// |    0|   CT|    5|
// |    1|  USA|   17|
// +-----+-----+-----+

So, as pault said, it is just "some combination of sort() and limit()".
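The global `.limit(3)` in the first example only matches the expected output because of this particular data, and the per-column `topN` runs one job per column. A single-query alternative is sketched below, assuming the standard `Window` and `row_number` functions: it explodes every row into one (index, value) struct per column, counts, and ranks within each column index. `topNPerColumn` is a hypothetical helper name, and the local session and sample data exist only to make the block runnable.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{array, col, explode, lit, row_number, struct}

// Local session for illustration; in spark-shell `spark` already exists.
val spark = SparkSession.builder.master("local[*]").appName("topNPerColumn").getOrCreate()
import spark.implicits._

// Same data as in the question: MA x2, OH x4, NY x6, CT x5, all with country USA.
val df = (Seq.fill(2)(("MA", "USA")) ++ Seq.fill(4)(("OH", "USA")) ++
  Seq.fill(6)(("NY", "USA")) ++ Seq.fill(5)(("CT", "USA"))).toDF("value", "country")

// Hypothetical helper: top n most frequent values of every column, computed in one query.
def topNPerColumn(df: DataFrame, n: Int): DataFrame = {
  // One (index, value) struct per column; explode turns each row into one row per cell.
  val cells = array(df.columns.zipWithIndex.map { case (c, i) =>
    struct(lit(i).alias("index"), col(c).cast("string").alias("value"))
  }: _*)
  // Rank values by count within each column index (ties are broken arbitrarily).
  val byCount = Window.partitionBy("index").orderBy(col("count").desc)
  df.select(explode(cells).alias("cell"))
    .select(col("cell.index").alias("index"), col("cell.value").alias("value"))
    .groupBy("index", "value")
    .count()
    .withColumn("rank", row_number().over(byCount))
    .where(col("rank") <= n)
    .drop("rank")
}

topNPerColumn(df, 2).orderBy("index", "value").show()
```

The window only runs over the already-aggregated (index, value, count) rows, which are usually far fewer than the input rows.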

Alper t. Turker
  • Some combination of `orderBy()` and `limit()`...basically what I said :-) – pault Feb 13 '18 at 21:05
  • @pault Yeah, I guess it is either that or some combination of `groupBy` and `agg` here with a window function from time to time :-) – Alper t. Turker Feb 13 '18 at 21:13
  • `.orderBy($"count".desc).limit(3)` gives the specified result in this case, but it doesn't give the top 2 items for each column index in the general case. – Kirk Broadhurst Feb 13 '18 at 21:14
  • Thanks let me try it out! – Sam Feb 13 '18 at 21:34
  • @user8371915 this works, but it is still slow for TBs of data, since I'm running it over 170+ columns. The performance is comparable to the freqItems API I mentioned earlier; it takes over an hour to process the entire data set. – Sam Feb 14 '18 at 20:44
  • This process throws a GC exception after a while: Exception in thread "dispatcher-event-loop-39" java.lang.OutOfMemoryError: GC overhead limit exceeded – Sam Feb 14 '18 at 20:59

The easiest way to do this is with SQL, since the problem is a natural fit for a window function. Spark supports SQL syntax, and SQL is an expressive tool for this problem.

Register your dataframe as a temp table, and then group and window on it.

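df.createOrReplaceTempView("df")  // register the DataFrame so the query can refer to it as df

spark.sql("""SELECT idx, value, c FROM (
               SELECT idx, value, c, ROW_NUMBER() OVER (PARTITION BY idx ORDER BY c DESC) AS r
               FROM (
                 SELECT idx, value, COUNT(*) AS c
                 FROM (SELECT 0 AS idx, value FROM df UNION ALL SELECT 1, country FROM df) AS u
                 GROUP BY idx, value) AS counts) AS ranked
             WHERE r <= 2""").show()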

I'd like to see whether any of the procedural (Scala) approaches can perform the window function without an iteration or loop. I'm not aware of anything in the Spark API that would support it.

Incidentally, if you have an arbitrary number of columns to include, you can easily generate the inner section (SELECT 0 as idx, value ... UNION ALL SELECT 1, country, etc.) dynamically from the list of columns.
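Generating that inner section from a column list could look like the sketch below. `topNQuery` is a hypothetical helper, and casting every column to STRING is an assumption so that the UNION ALL branches have compatible types when the columns differ in type. The function only builds the query text, so it needs no Spark to run.

```scala
// Hypothetical helper: builds the top-n-per-column query for any list of column names.
def topNQuery(view: String, columns: Seq[String], n: Int): String = {
  // Column i contributes "SELECT i AS idx, CAST(`col` AS STRING) AS value FROM view".
  val stacked = columns.zipWithIndex
    .map { case (c, i) => s"SELECT $i AS idx, CAST(`$c` AS STRING) AS value FROM $view" }
    .mkString(" UNION ALL ")
  s"""SELECT idx, value, c FROM (
     |  SELECT idx, value, c, ROW_NUMBER() OVER (PARTITION BY idx ORDER BY c DESC) AS r
     |  FROM (
     |    SELECT idx, value, COUNT(*) AS c
     |    FROM ($stacked) AS u
     |    GROUP BY idx, value) AS counts) AS ranked
     |WHERE r <= $n""".stripMargin
}

// With Spark: df.createOrReplaceTempView("df"); spark.sql(topNQuery("df", df.columns.toSeq, 2)).show()
println(topNQuery("df", Seq("value", "country"), 2))
```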

Kirk Broadhurst
  • Thanks, I tried using window functions, but it failed with a GC overhead error. I'll try out your solution too. – Sam Feb 13 '18 at 21:31

Given your last RDD:

val rdd =
  sc.parallelize(
    List(
      ((0, "CT"), 5),
      ((0, "MA"), 2),
      ((0, "OH"), 4),
      ((0, "NY"), 6),
      ((1, "USA"), 17)
    ))

rdd.filter(_._1._1 == 0).sortBy(-_._2).take(2).foreach(println)
> ((0,NY),6)
> ((0,CT),5)
rdd.filter(_._1._1 == 1).sortBy(-_._2).take(2).foreach(println)
> ((1,USA),17)

We first keep the items for a given column index (.filter(_._1._1 == 0)). Then we sort them by decreasing count (.sortBy(-_._2)). Finally, we take at most the first 2 elements (.take(2)); take returns only 1 element if there are fewer than 2 records.
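The filter-per-index approach launches one job per column, which adds up with many columns. As an alternative, sketched here under the assumption that the RDD has the shape above, aggregateByKey can keep a running top-n list per column index in a single shuffle. `topNByIndex` is a hypothetical helper name.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

// Local session for illustration; in spark-shell `sc` already exists.
val spark = SparkSession.builder.master("local[*]").appName("topByIndex").getOrCreate()
val sc = spark.sparkContext

val rdd: RDD[((Int, String), Int)] = sc.parallelize(List(
  ((0, "CT"), 5), ((0, "MA"), 2), ((0, "OH"), 4), ((0, "NY"), 6), ((1, "USA"), 17)))

// Hypothetical helper: keep at most n (value, count) pairs per column index.
def topNByIndex(rdd: RDD[((Int, String), Int)], n: Int): RDD[((Int, String), Int)] =
  rdd
    .map { case ((idx, value), count) => (idx, (value, count)) }
    .aggregateByKey(List.empty[(String, Int)])(
      (acc, vc) => (vc :: acc).sortBy(-_._2).take(n), // add one element within a partition
      (a, b) => (a ++ b).sortBy(-_._2).take(n))       // merge partial top-n lists across partitions
    .flatMap { case (idx, top) => top.map { case (value, count) => ((idx, value), count) } }

topNByIndex(rdd, 2).collect().foreach(println)
```

Each per-key list never grows beyond n + 1 entries, so the shuffled data stays small no matter how many distinct values a column has.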

Xavier Guihot

You can apply this helper function, defined in Sparkz, to each partition and then combine the partial results:

package sparkz.utils

import scala.reflect.ClassTag

object TopElements {
  def topN[T: ClassTag](elems: Iterable[T])(scoreFunc: T => Double, n: Int): List[T] =
    elems.foldLeft((Set.empty[(T, Double)], Double.MaxValue)) {
      case (accumulator@(topElems, minScore), elem) =>
        val score = scoreFunc(elem)
        if (topElems.size < n)
          (topElems + (elem -> score), math.min(minScore, score))
        else if (score > minScore) {
          val newTopElems = topElems - topElems.minBy(_._2) + (elem -> score)
          (newTopElems, newTopElems.map(_._2).min)
        }
        else accumulator
    }
      ._1.toList.sortBy(_._2).reverse.map(_._1)
}

Source: https://github.com/gm-spacagna/sparkz/blob/master/src/main/scala/sparkz/utils/TopN.scala
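The "apply per partition, then combine" step can be sketched as follows, assuming the colValCount pair RDD from the question. `PartitionTop` is a hypothetical stand-in with a simpler sort-based body and the same shape as `TopElements.topN`; in practice you would call the Sparkz helper instead. It works because every entry in the overall top n for an index is also in the top n of its own partition.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical stand-in for TopElements.topN (same signature shape, sort-based body).
object PartitionTop extends Serializable {
  def topN[T](elems: Iterable[T])(scoreFunc: T => Double, n: Int): List[T] =
    elems.toList.sortBy(e => -scoreFunc(e)).take(n)

  // Keep at most n ((colIdx, value), count) entries per column index.
  def perIndex(elems: Iterable[((Int, String), Long)], n: Int): List[((Int, String), Long)] =
    elems.groupBy(_._1._1).values.flatMap(group => topN(group)(_._2.toDouble, n)).toList
}

val spark = SparkSession.builder.master("local[*]").appName("topNPerPartition").getOrCreate()
val colValCount = spark.sparkContext.parallelize(Seq(
  ((0, "CT"), 5L), ((0, "MA"), 2L), ((0, "OH"), 4L), ((0, "NY"), 6L), ((1, "USA"), 17L)), numSlices = 2)

// 1) Shrink every partition to at most n candidates per index (runs on the executors).
val candidates = colValCount.mapPartitions(it => PartitionTop.perIndex(it.toList, 2).iterator).collect()
// 2) Merge the small candidate lists on the driver.
val top2 = PartitionTop.perIndex(candidates.toList, 2)
top2.foreach(println)
```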

Gianmario Spacagna