
I have a CSV with multiple columns, but I am interested in three of them: age_group, gender and cause_of_death. It looks like the following:

    +---------+------+--------------+
    |age_group|gender|cause_of_death|
    +---------+------+--------------+
    |        7|     F|             1|
    |        8|     M|             2|
    |       10|     F|             3|
    |        7|     M|             2|
    |        9|     F|             3|

age_group represents an age class (7 means 70-80), and cause_of_death holds numeric codes for the reasons (1 is heart attack, 2 is accident, etc.).

I have to find the top n causes of death for each gender and age_group. This is what I have tried so far:

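    from pyspark.sql.functions import desc

    # Count deaths for every (gender, age_group, cause_of_death) combination
    data.select("age_group", "gender", "cause_of_death") \
        .groupBy("gender", "age_group", "cause_of_death").count() \
        .sort(desc("age_group")).show()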

But this only gives me the count for every combination, sorted in descending order:

    +------+---------+--------------+------+
    |gender|age_group|cause_of_death| count|
    +------+---------+--------------+------+
    |     F|       11|             7|308181|
    |     F|       10|             7|231168|
    |     M|       10|             7|221172|
    |     M|       11|             7|157693|
    |     F|       11|          null|149345|
    |     M|        9|             7|146186|
    |     F|        9|             7|114424|
    |     F|       10|          null|114107|
    |     M|        8|             7|106339|
    |     M|       10|          null|105508|
    |     M|       11|          null| 75934|
    |     F|        8|             7| 70390|
    |     M|        9|          null| 69363|
    |     M|        7|             7| 65634|

What I want is the top n causes of death for each age_group and gender. How can I do that? For example, the top 3 causes of death would look something like this:

    +------+---------+--------------+------+
    |gender|age_group|cause_of_death| count|
    +------+---------+--------------+------+
    |     F|       11|             7|308181|
    |      |         |             1|291242|
    |      |         |             4|234231|
    |     F|       10|             7|231168|
    |      |         |             3|221232|
    |      |         |             2|192323|
    |     M|       10|             7|221172|
    |      |         |             2|142323|
    |      |         |             9| 12312|
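The desired computation has two steps: count deaths per (gender, age_group, cause_of_death), then keep the n largest counts inside each (gender, age_group) group. As a plain-Python illustration of that logic only (this is not Spark code, and the `top_n_causes` helper and the sample rows are hypothetical), it might look like this:

```python
from collections import Counter, defaultdict


def top_n_causes(rows, n):
    """Return {(gender, age_group): [(cause, count), ...]} holding the n most
    frequent causes per group, highest count first.

    `rows` is assumed to be (age_group, gender, cause_of_death) tuples,
    mirroring the column order of the CSV sample above."""
    # Step 1: count deaths per (gender, age_group, cause_of_death),
    # the same table that groupBy(...).count() produces.
    counts = Counter((gender, age, cause) for age, gender, cause in rows)

    # Step 2: gather the counted causes under their (gender, age_group) key.
    by_group = defaultdict(list)
    for (gender, age, cause), cnt in counts.items():
        by_group[(gender, age)].append((cause, cnt))

    # Step 3: within each group, sort by count descending and keep n entries.
    # sorted() is stable, so tied counts keep the order they were first seen.
    return {
        key: sorted(causes, key=lambda item: item[1], reverse=True)[:n]
        for key, causes in by_group.items()
    }


# Hypothetical sample rows: (age_group, gender, cause_of_death)
rows = [
    (7, "F", 1), (7, "F", 1), (7, "F", 3),
    (7, "M", 2), (7, "M", 2), (7, "M", 2), (7, "M", 4),
    (8, "M", 2),
]
print(top_n_causes(rows, 2))
# {('F', 7): [(1, 2), (3, 1)], ('M', 7): [(2, 3), (4, 1)], ('M', 8): [(2, 1)]}
```

In Spark, step 1 is the existing groupBy/count and steps 2 and 3 correspond to a window partitioned by (gender, age_group) and ordered by the count.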

EDIT: The question linked in the comments does not answer my question; I tried its approach and I am not getting the correct results.

Code:

    from pyspark.sql import Window
    from pyspark.sql.functions import col, rank

    window = Window.partitionBy(data['age_group']).orderBy(data['cause_of_death'].desc())
    data = data.select("age_group", "gender", "cause_of_death")
    data.select('*', rank().over(window).alias('rank')).filter(col('rank') <= 5).show()

Results

    +---------+------+--------------+----+
    |age_group|gender|cause_of_death|rank|
    +---------+------+--------------+----+
    |       12|     M|             7|   1|
    |       12|     M|             7|   1|
    |       12|     F|             7|   1|
    |       12|     M|             7|   1|
    |       12|     F|             7|   1|
    |       12|     M|             7|   1|
    |       12|     M|             7|   1|
    |       12|     M|             7|   1|
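The attempt above ranks individual rows by the code value itself rather than by how often each code occurs, and it partitions only by age_group. Because `rank()` gives tied rows the same rank, every row carrying the largest code in a partition ends up at rank 1, which is consistent with the output above. A small plain-Python emulation of SQL `RANK()` under a descending order (an illustration of the semantics, not Spark's implementation; `sql_rank_desc` and the sample codes are hypothetical) shows the effect:

```python
def sql_rank_desc(values):
    """Return (value, rank) pairs ordered by value descending, using SQL
    RANK() semantics: rank = 1 + number of strictly greater values, so
    tied values share a rank and the next rank is skipped (1, 1, 3, ...)."""
    ordered = sorted(values, reverse=True)
    return [(v, 1 + sum(1 for other in ordered if other > v)) for v in ordered]


# Hypothetical cause_of_death codes for the rows of one age_group partition.
codes = [7, 2, 7, 3, 7, 1, 7, 7, 7, 2]

# Ranking raw rows by the code value gives every row with the largest
# code rank 1, no matter how often that code occurs.
ranked = sql_rank_desc(codes)
top = [pair for pair in ranked if pair[1] <= 5]
print(top)  # six (7, 1) pairs; codes 3, 2, 2, 1 get ranks 7, 8, 8, 10
```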
  • The answer to that question compares the values directly, whereas I have to perform a computation (a count) to get the top values. – sid0972 Feb 22 '19 at 12:52
  • The window should be `Window.partitionBy("gender", "age_group").orderBy(col("count").desc())` and applied on already aggregated data, i.e. `data.select("age_group", "gender", "cause_of_death").groupBy("gender", "age_group", "cause_of_death").count().withColumn("rank", rank().over(window))` – 10465355 Feb 22 '19 at 13:06
  • Thanks. Now I get the correct output. Can you post it as an answer so that I can accept it? Also, how do I limit the output to the top n ranks, for example 5? – sid0972 Feb 22 '19 at 13:18
  • The filter should stay as is in the linked question. – 10465355 Feb 22 '19 at 13:21
