
I have a dataset like this:

id    category     value
1     A            NaN
2     B            NaN
3     A            10.5
5     A            2.0
6     B            1.0

I want to fill the NaN values with the mean of their respective category, as shown below:

id    category     value
1     A            4.16
2     B            0.5
3     A            10.5
5     A            2.0
6     B            1.0

First, I tried to calculate the mean value of each category using groupBy:

val df2 = dataFrame.groupBy("category").agg(mean("value")).rdd.map {
    case r: Row => (r.getAs[String]("category"), r.get(1))
  }.collect().toMap
println(df2)

This gave me a map of each category to its mean value. Output: `Map(A -> 4.16, B -> 0.5)`. I then tried an UPDATE query in Spark SQL to fill the column, but it seems Spark SQL doesn't support UPDATE queries. I also tried to fill the null values directly in the DataFrame, but failed. What can I do? The same thing can be done in pandas, as shown in Pandas: How to fill null values with mean of a groupby? But how can I do it with a Spark DataFrame?

kush

3 Answers


The simplest solution would be to use groupBy and join:

val df2 = df.filter(!isnan($"value")).groupBy("category").agg(avg($"value").as("avg"))

df.join(df2, "category")
  .withColumn("value", when($"value".isNaN, $"avg").otherwise($"value"))
  .drop("avg")

Note that if a category contains only NaN values, it will be removed from the result (the inner join has no average row to match for it).
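If you need to keep those rows, a left join preserves them (their value simply stays NaN, since there is no per-category average to substitute). A minimal sketch along the lines of the code above; `avgPerCategory` is just an illustrative name:

val avgPerCategory = df.filter(!isnan($"value")).groupBy("category").agg(avg($"value").as("avg"))

df.join(avgPerCategory, Seq("category"), "left")
  // only substitute when an average actually exists for the category
  .withColumn("value", when($"value".isNaN && $"avg".isNotNull, $"avg").otherwise($"value"))
  .drop("avg")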

Assaf Mendelson

Indeed, you cannot update DataFrames, but you can transform them using functions like select and join. In this case, you can keep the grouping result as a DataFrame, join it to the original one on the category column, and then perform the mapping that replaces NaNs with the mean values:

import org.apache.spark.sql.functions._
import spark.implicits._

// calculate mean per category:
val meanPerCategory = dataFrame.groupBy("category").agg(mean("value") as "mean")

// use join, select and "nanvl" function to replace NaNs with the mean values:
val result = dataFrame
  .join(meanPerCategory, "category")
  .select($"category", $"id", nanvl($"value", $"mean")).show()
Tzach Zohar
  • To replace nulls you'll have to replace the `nanvl` function with `coalesce`. Or to handle both: `coalesce($"value", nanvl($"value", $"mean"))` – Tzach Zohar Feb 21 '17 at 11:12
  • Sorry that should be `coalesce(nanvl($"value", $"mean"), $"mean")` – Tzach Zohar Feb 21 '17 at 11:54
  • Why am I not able to import `spark.implicits._`? – kush Feb 21 '17 at 13:02
  • `spark` is the `SparkSession` - if it's named differently, replace the name; If you don't have a SparkSession you should have an `SQLContext` - import that context's implicits (e.g. `import sqlContext.implicits._` if it's named `sqlContext` – Tzach Zohar Feb 21 '17 at 13:04
  • It works great for null values using `coalesce($"value", $"mean")`. But when I try `coalesce(nanvl($"value", $"mean"), $"mean")`, it doesn't fill null values either. – kush Feb 21 '17 at 19:52
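Putting the comments together, a version of the snippet above that handles both nulls and NaNs could look roughly like this (a sketch, reusing `dataFrame` and `meanPerCategory` from the answer):

// nanvl replaces NaN with the mean; coalesce then replaces null with the mean
val result = dataFrame
  .join(meanPerCategory, "category")
  .select($"category", $"id", coalesce(nanvl($"value", $"mean"), $"mean") as "value")

result.show()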

I stumbled upon the same problem and came across this post, but tried a different solution, i.e. using window functions. The code below is tested on PySpark 2.4.3 (window functions are available from Spark 1.4). I believe this is a slightly cleaner solution. This post is quite old, but I hope this answer will be helpful for others.

from pyspark.sql import Window
from pyspark.sql.functions import *

df = spark.createDataFrame([(1,"A", None), (2,"B", None), (3,"A",10.5), (5,"A",2.0), (6,"B",1.0)], ['id', 'category', 'value'])

category_window = Window.partitionBy("category")
value_mean = mean("value0").over(category_window)

result = df\
  .withColumn("value0", coalesce("value", lit(0)))\
  .withColumn("value_mean", value_mean)\
  .withColumn("new_value", coalesce("value", "value_mean"))\
  .select("id", "category", "new_value")

result.show()

Output will be as expected (in question):

id    category    new_value
1     A           4.166666666666667
2     B           0.5
3     A           10.5
5     A           2
6     B           1
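For readers working in Scala (as in the question), a roughly equivalent sketch of the same window-function approach, assuming a DataFrame `df` with the same schema and a `SparkSession` named `spark`:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

val categoryWindow = Window.partitionBy("category")

val result = df
  .withColumn("value0", coalesce($"value", lit(0.0)))  // treat missing values as 0 when computing the mean, as above
  .withColumn("value_mean", mean($"value0").over(categoryWindow))
  .withColumn("new_value", coalesce($"value", $"value_mean"))
  .select("id", "category", "new_value")

result.show()

Note that, like the PySpark version, `coalesce` only fills nulls; for NaN values you would combine it with `nanvl` as discussed in the comments on the previous answer.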
nileshg