I have a CSV with multiple columns, but I am interested in three: age_group, gender, and cause_of_death. It looks like the following.
+---------+------+--------------+
|age_group|gender|cause_of_death|
+---------+------+--------------+
|        7|     F|             1|
|        8|     M|             2|
|       10|     F|             3|
|        7|     M|             2|
|        9|     F|             3|
+---------+------+--------------+
age_group encodes age classes (7 means 70-80), and cause_of_death is a numeric code for the reason (1 is heart attack, 2 is accident, etc.).
I have to find the top n causes of death for each gender and age_group. What I have tried so far is the following.
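from pyspark.sql.functions import desc

# Count deaths per (gender, age_group, cause_of_death), then sort by age_group.
data.select("age_group", "gender", "cause_of_death") \
    .groupBy("gender", "age_group", "cause_of_death").count() \
    .sort(desc("age_group")).show()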
But it gives me the counts for every combination, arranged in descending order, rather than the top n per group.
+------+---------+--------------+------+
|gender|age_group|cause_of_death| count|
+------+---------+--------------+------+
|     F|       11|             7|308181|
|     F|       10|             7|231168|
|     M|       10|             7|221172|
|     M|       11|             7|157693|
|     F|       11|          null|149345|
|     M|        9|             7|146186|
|     F|        9|             7|114424|
|     F|       10|          null|114107|
|     M|        8|             7|106339|
|     M|       10|          null|105508|
|     M|       11|          null| 75934|
|     F|        8|             7| 70390|
|     M|        9|          null| 69363|
|     M|        7|             7| 65634|
+------+---------+--------------+------+
What I want is the top n causes of death for each age_group and gender. How can I do that? Something like the following, showing the top 3 causes of death.
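+------+---------+--------------+------+
|gender|age_group|cause_of_death| count|
+------+---------+--------------+------+
|     F|       11|             7|308181|
|      |         |             1|291242|
|      |         |             4|234231|
|     F|       10|             7|231168|
|      |         |             3|221232|
|      |         |             2|192323|
|     M|       10|             7|221172|
|      |         |             2|142323|
|      |         |             9| 12312|
+------+---------+--------------+------+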
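To make the expected result concrete, here is a plain-Python sketch (not Spark) of the logic I am after: count deaths per (gender, age_group, cause_of_death), then keep only the n largest counts within each (gender, age_group) group. The function name `top_n_causes` and the sample rows are made up for illustration; they are not taken from my real data.

```python
from collections import Counter, defaultdict

def top_n_causes(rows, n):
    """Return {(gender, age_group): [(cause, count), ...]} with at most n causes per group."""
    # Step 1: count deaths per (gender, age_group, cause_of_death),
    # which is what groupBy(...).count() does in my Spark attempt.
    counts = Counter((gender, age_group, cause) for age_group, gender, cause in rows)

    # Step 2: collect the per-cause counts under their (gender, age_group) group.
    grouped = defaultdict(list)
    for (gender, age_group, cause), count in counts.items():
        grouped[(gender, age_group)].append((cause, count))

    # Step 3: within each group, sort by count (largest first) and keep the first n.
    return {
        group: sorted(causes, key=lambda pair: pair[1], reverse=True)[:n]
        for group, causes in grouped.items()
    }

# Made-up rows in (age_group, gender, cause_of_death) order, for illustration only.
sample_rows = [
    (7, "F", 1), (7, "F", 1), (7, "F", 3),
    (7, "F", 2), (7, "F", 2), (7, "F", 2),
    (8, "M", 2), (8, "M", 2), (8, "M", 5),
]

result = top_n_causes(sample_rows, n=2)
print(result)
```

The step my Spark attempt is missing is Step 3: the ranking has to happen per (gender, age_group) group and has to be ordered by the count, not by the cause code.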
EDIT: The question in the comments does not answer my question. I tried it and I am not getting the correct results.
Code:
from pyspark.sql import Window
from pyspark.sql.functions import col, rank

window = Window.partitionBy(data['age_group']).orderBy(data['cause_of_death'].desc())
data = data.select("age_group", "gender", "cause_of_death")
data.select('*', rank().over(window).alias('rank')).filter(col('rank') <= 5).show()
Results:
+---------+------+--------------+----+
|age_group|gender|cause_of_death|rank|
+---------+------+--------------+----+
|       12|     M|             7|   1|
|       12|     M|             7|   1|
|       12|     F|             7|   1|
|       12|     M|             7|   1|
|       12|     F|             7|   1|
|       12|     M|             7|   1|
|       12|     M|             7|   1|
|       12|     M|             7|   1|
+---------+------+--------------+----+