
I found the following Scala code for selecting the top n rows from a DataFrame, grouped by a unique ID:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number

val window = Window.partitionBy("userId").orderBy($"rating".desc)

dataframe.withColumn("r", row_number.over(window)).where($"r" <= n)

I have tried the following:

from pyspark.sql.functions import row_number, desc
from pyspark.sql.window import Window

w = Window.partitionBy(post_tags.EntityID).orderBy(post_tags.Weight)
newdata=post_tags.withColumn("r", row_number.over(w)).where("r" <= 3)

I get the following error:

AttributeError: 'function' object has no attribute 'over'

How can I fix this?

Nikhil Baby
  • The $ shows an error, but I found the answer in another post:

        from pyspark.sql.window import Window
        from pyspark.sql.functions import rank, col
        window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())
        df.select('*', rank().over(window).alias('rank')).filter(col('rank') <= 2).show()

    – Nikhil Baby Oct 23 '17 at 05:58
  • 1
    In the `where($"r" <= n)` you removed the `$` but didn't replace it with anything. Try changing to `col('r') <= 3`. – Shaido Oct 23 '17 at 05:58

1 Answer


I found the answer to this:

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) \
  .filter(col('rank') <= 2) \
  .show()

Credits to @mtoto for his answer https://stackoverflow.com/a/38398563/5165377

Nikhil Baby