
I wouldn't expect this to be difficult, but I'm having trouble understanding how to take the average of a column in my Spark DataFrame.

The dataframe looks like:

+-------+------------+--------+------------------+
|Private|Applications|Accepted|              Rate|
+-------+------------+--------+------------------+
|    Yes|         417|     349|0.8369304556354916|
|    Yes|        1899|    1720|0.9057398630858347|
|    Yes|        1732|    1425|0.8227482678983834|
|    Yes|         494|     313|0.6336032388663968|
|     No|        3540|    2001|0.5652542372881356|
|     No|        7313|    4664|0.6377683577191303|
|    Yes|         619|     516|0.8336025848142165|
|    Yes|         662|     513|0.7749244712990937|
|    Yes|         761|     725|0.9526938239159002|
|    Yes|        1690|    1366| 0.808284023668639|
|    Yes|        6075|    5349|0.8804938271604938|
|    Yes|         632|     494|0.7816455696202531|
|     No|        1208|     877|0.7259933774834437|
|    Yes|       20192|   13007|0.6441660063391442|
|    Yes|        1436|    1228|0.8551532033426184|
|    Yes|         392|     351|0.8954081632653061|
|    Yes|       12586|    3239|0.2573494358811378|
|    Yes|        1011|     604|0.5974282888229476|
|    Yes|         848|     587|0.6922169811320755|
|    Yes|        8728|    5201|0.5958982584784601|
+-------+------------+--------+------------------+

I want to return the average of the Rate column when Private is equal to "Yes". How can I do this?

Jacob Myer
4 Answers


Try

df.filter(df['Private'] == 'Yes').agg({'Rate': 'avg'}).collect()[0]
Vishnudev Krishnadas
  • I'm not sure filter works this way. Filter works on the index and column headers. – cs95 Feb 09 '20 at 18:33
  • I have tried `privateRate = df.filter(df['Private'] == 'Yes').agg({'Rate': 'avg'})`, but `print(privateRate)` returns `DataFrame[avg(Rate): double]`. Seems like it's close, but I need to see the actual number @Vishnudev – Jacob Myer Feb 09 '20 at 18:46
  • spyder is warning me "undefined name *display*" – Jacob Myer Feb 09 '20 at 18:53

A third version to do the same would be:

from pyspark.sql.functions import col, avg
df_avg = df.filter(df["Private"] == "Yes").agg(avg(col("Rate")))
df_avg.show()
BICube
  • I think this, along with Vishnudev's answer, is working, but the answer isn't displaying properly. Like I mentioned before, printing this gives me `DataFrame[avg(Rate): double]` and I need to see the number. I'm curious about the display function you mentioned, it seems like I need to import something first for that to work. – Jacob Myer Feb 09 '20 at 18:56
  • df_avg.show() gives me a very long error (one I have been seeing a lot when trying very different things) – Jacob Myer Feb 09 '20 at 19:00

This would work in Scala; the PySpark code should be very similar.

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val df = List(
  ("yes", 10),
  ("yes", 30),
  ("No", 40)).toDF("private", "rate")

val window = Window.partitionBy($"private")

df.
  withColumn("avg",
    when($"private" === "No", null).
      otherwise(avg($"rate").over(window))
  ).
  show()

Input DF

+-------+----+
|private|rate|
+-------+----+
|    yes|  10|
|    yes|  30|
|     No|  40|
+-------+----+

Output DF

+-------+----+----+
|private|rate| avg|
+-------+----+----+
|     No|  40|null|
|    yes|  10|20.0|
|    yes|  30|20.0|
+-------+----+----+
Gaurang Shah

Try:

from pyspark.sql.functions import col, mean, lit

df.where(col("Private")==lit("Yes")).select(mean(col("Rate"))).collect()
Grzegorz Skibinski
  • spyder gives me a code analysis warning, "'lit' may be undefined or defined from star imports". This also causes an error while running the code – Jacob Myer Feb 09 '20 at 18:51
  • Ou, try now - ```F.lit(...)``` should do – Grzegorz Skibinski Feb 09 '20 at 19:06
  • I'm receiving a big long error. I've seen it several times on this project but it doesn't always come up. I'm beginning to wonder if the problem isn't this part of my code but something else? Which would be odd because I can run up to these lines without issue. – Jacob Myer Feb 09 '20 at 19:11
  • Maybe it's this: https://stackoverflow.com/questions/46799137/how-to-supress-star-imports-warnings-from-spyder-ide – Grzegorz Skibinski Feb 09 '20 at 19:13
  • Out of curiosity- try it now (without star import) – Grzegorz Skibinski Feb 09 '20 at 19:14
  • I don't have the same warning from spyder, but I still get the tremendously long error – Jacob Myer Feb 09 '20 at 19:16
  • I'm starting to think - ```from pyspark.sql import functions as F``` is the problem here (which shouldn't be the case in general). You can check now - but if it's still not it, that's some environmental discrepancy, which I just cannot reproduce... – Grzegorz Skibinski Feb 09 '20 at 19:26
  • if I remove the `.collect()` and print there is no error, but I am back to the `DataFrame[avg(Rate): double]` output – Jacob Myer Feb 09 '20 at 19:31
  • Because you haven't evaluated it. I.e. until ```.collect()``` nothing material really happens - your query is executed only once you call an action (```.collect()```, or ```.show()``` for instance) – Grzegorz Skibinski Feb 09 '20 at 19:39
  • Let us [continue this discussion in chat](https://chat.stackoverflow.com/rooms/207505/discussion-between-jacob-myer-and-grzegorz-skibinski). – Jacob Myer Feb 09 '20 at 19:41