
Suppose we have a CSV file that has been imported as a DataFrame in PySpark as follows:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Read the CSV with a header row and inferred column types
df = spark.read.csv("file path and name.csv", inferSchema=True, header=True)
df.show()

Output:

+-----+----+----+
|lable|year|val |
+-----+----+----+
|    A|2003| 5.0|
|    A|2003| 6.0|
|    A|2003| 3.0|
|    A|2004|null|
|    B|2000| 2.0|
|    B|2000|null|
|    B|2009| 1.0|
|    B|2000| 6.0|
|    B|2009| 6.0|
+-----+----+----+

Now, we want to add another column to df that contains the standard deviation of val, grouped by the two columns lable and year. So the output should look like this:

+-----+----+----+-----+
|lable|year|val | std |
+-----+----+----+-----+
|    A|2003| 5.0| 1.53|
|    A|2003| 6.0| 1.53|
|    A|2003| 3.0| 1.53|
|    A|2004|null| null|
|    B|2000| 2.0| 2.83|
|    B|2000|null| 2.83|
|    B|2009| 1.0| 3.54|
|    B|2000| 6.0| 2.83|
|    B|2009| 6.0| 3.54|
+-----+----+----+-----+

I have the following code, which works for a small DataFrame but fails on the very large DataFrame (about 40 million rows) that I am working with now.

import pyspark.sql.functions as f

# Compute the per-group standard deviation, then join it back onto the original rows
a = df.groupby('lable', 'year').agg(f.round(f.stddev("val"), 2).alias('std'))
df = df.join(a, on=['lable', 'year'], how='inner')

I get a Py4JJavaError (Traceback (most recent call last)) when running this on my large DataFrame.

Does anyone know an alternative way? I hope it works on my dataset.

I am using Python 3.7.1, PySpark 2.4, and Jupyter 4.4.0.

Monirrad

1 Answer


The join on the DataFrame causes a lot of data shuffling between executors. In your case you can do without the join: use a window specification to partition the data by 'lable' and 'year', and aggregate over that window.

import pyspark.sql.functions as f
from pyspark.sql.window import Window

# Partition by the grouping columns; with an unbounded preceding/following frame,
# the window spans the whole partition, so every row gets its group's std
windowSpec = Window.partitionBy('lable', 'year')\
                   .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

df = df.withColumn("std", f.round(f.stddev("val").over(windowSpec), 2))
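For completeness, here is a minimal end-to-end sketch of this window approach, assuming the same column names as in the question and the same placeholder CSV path (when no orderBy is given, the frame already defaults to the whole partition, so the explicit rowsBetween can be dropped):

import pyspark.sql.functions as f
from pyspark.sql import SparkSession
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()
# "file path and name.csv" is a placeholder for the real file path
df = spark.read.csv("file path and name.csv", inferSchema=True, header=True)

# One window per (lable, year) group; every row in a group receives the
# same rounded standard deviation of val
window_spec = Window.partitionBy('lable', 'year')
df = df.withColumn("std", f.round(f.stddev("val").over(window_spec), 2))
df.show()

This still shuffles the data once to partition it by ('lable', 'year'), but it avoids the extra shuffle and the temporary aggregated DataFrame that the groupBy-plus-join version requires.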
Manoj Singh
  • Done. Could you please look at my other question in https://stackoverflow.com/questions/54192113/how-to-add-a-column-to-a-pyspark-dataframe-which-contains-the-nth-quantile-of-an?noredirect=1#comment95251427_54192113. You might have some answer. – Monirrad Jan 19 '19 at 19:30