from pyspark.sql import SparkSession
import pyspark.sql.functions as f
from pyspark.sql.types import StructType, StructField, StringType, DoubleType
import random
# get a spark session
spark = SparkSession.builder.appName('learn').getOrCreate()
# create dataset
schema = StructType([
    StructField('c1', StringType(), nullable=True),
    StructField('c2', StringType(), nullable=True),
    StructField('value', DoubleType(), nullable=True),
])
data = [(random.choice('ABC'), random.choice('abc'), random.random()) for _ in range(100)]
df = spark.createDataFrame(data, schema=schema).drop_duplicates()
df.createOrReplaceTempView('tmp_view')
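As a quick, optional sanity check before ranking, the size of each c1 group can be inspected:
# optional: check how many rows ended up in each c1 group
df.groupBy('c1').count().show()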
# window function (SQL)
query ="""
SELECT * FROM
(SELECT c1, c2, value, dense_rank() OVER (PARTITION BY c1 ORDER BY value ASC) as rank
FROM tmp_view) x
WHERE x.rank <= 3
"""
res = spark.sql(query).orderBy(['c1', 'rank'], ascending=True)
res.show()
# window function (dataframe API)
from pyspark.sql.window import Window
w = Window.partitionBy('c1').orderBy('value')
res = (
    df.withColumn('rank', f.dense_rank().over(w))
    .filter(f.col('rank') <= 3)
    .orderBy(['c1', 'rank'], ascending=True)
)
res.show()
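Note that dense_rank() is only one of three ranking functions: rank() leaves gaps after ties, row_number() breaks ties arbitrarily, and dense_rank() assigns consecutive ranks. A minimal sketch comparing them over the same window w:
# compare the three ranking functions over the same window;
# with random doubles, ties are unlikely, so the outputs usually agree here
df.select(
    'c1', 'value',
    f.rank().over(w).alias('rank'),
    f.dense_rank().over(w).alias('dense_rank'),
    f.row_number().over(w).alias('row_number'),
).orderBy('c1', 'value').show(n=9)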
The resulting DataFrames should look like the following (exact values will differ between runs, since the input data is random):
df.show(n=5, truncate=False)
+---+---+-------------------+
|c1 |c2 |value |
+---+---+-------------------+
|C |a |0.38262849793622566|
|B |b |0.4117068824287389 |
|C |a |0.2622908081454347 |
|A |b |0.5371458199115897 |
|B |a |0.18469916475187298|
+---+---+-------------------+
only showing top 5 rows
res.show(n=5, truncate=False)
+---+---+--------------------+----+
|c1 |c2 |value |rank|
+---+---+--------------------+----+
|A |c |0.005163400336233082|1 |
|A |c |0.005434604534087506|2 |
|A |c |0.025183426923212848|3 |
|B |a |0.02414753220919985 |1 |
|B |b |0.04950658758922166 |2 |
+---+---+--------------------+----+
only showing top 5 rows
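Both approaches should return exactly the same rows. One way to verify this (a sketch, with res_sql and res_api as hypothetical names for the two unsorted results) is exceptAll(), which returns the rows of one DataFrame that are missing from the other:
# verify that the SQL and DataFrame API results match, ignoring row order
res_sql = spark.sql(query)
res_api = df.withColumn('rank', f.dense_rank().over(w)).filter(f.col('rank') <= 3)
assert res_sql.exceptAll(res_api).count() == 0
assert res_api.exceptAll(res_sql).count() == 0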