
I have a DataFrame like the one below:

+----+----+----+
|colA|colB|colC|
+----+----+----+
|1   |1   |23  |
|1   |2   |63  |
|1   |3   |null|
|1   |4   |32  |
|2   |2   |56  |
+----+----+----+

I apply the following code to create a sequence of the values of column C:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._
df.withColumn("colD", 
collect_list("colC").over(Window.partitionBy("colA").orderBy("colB")))

The result is shown below: column D is created and contains the values of column C as a sequence, but the null value has been removed:

+----+----+----+------------+
|colA|colB|colC|colD        |
+----+----+----+------------+
|1   |1   |23  |[23]        |
|1   |2   |63  |[23, 63]    |
|1   |3   |null|[23, 63]    |
|1   |4   |32  |[23, 63, 32]|
|2   |2   |56  |[56]        |
+----+----+----+------------+

However, I would like to keep the null values in the new column and get the result below:

+----+----+----+------------------+
|colA|colB|colC|colD              |
+----+----+----+------------------+
|1   |1   |23  |[23]              |
|1   |2   |63  |[23, 63]          |
|1   |3   |null|[23, 63, null]    |
|1   |4   |32  |[23, 63, null, 32]|
|2   |2   |56  |[56]              |
+----+----+----+------------------+

As you can see, the null values are kept in this result. Do you know how I can achieve this?

WhoAmI

2 Answers


As Leo C mentioned, collect_list will drop null values. There seems to be a workaround for this behavior: wrapping each scalar in an array before calling collect_list yields [[23], [63], [], [32]], and flattening that gives [23, 63,, 32]. The values that appear missing in those arrays are nulls; show() simply renders a null element as a blank.

The flatten built-in SQL function was, I believe, introduced in Spark 2.4 (collect_list has been available longer). I didn't look into the implementation to verify that this is expected behavior, so I don't know how reliable this solution is.

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._
import spark.implicits._  // needed for .toDF on a local Seq (assumes a SparkSession named `spark`)

val df = Seq(
  (Some(1), Some(1), Some(23)),
  (Some(1), Some(2), Some(63)),
  (Some(1), Some(3), None),
  (Some(1), Some(4), Some(32)),
  (Some(2), Some(2), Some(56))
).toDF("colA", "colB", "colC")

// Wrap colC in a single-element array so nulls survive collect_list, then flatten.
val newDf = df.withColumn("colD", flatten(collect_list(array("colC"))
    .over(Window.partitionBy("colA").orderBy("colB"))))

newDf.show()


+----+----+----+-------------+
|colA|colB|colC|         colD|
+----+----+----+-------------+
|   1|   1|  23|         [23]|
|   1|   2|  63|     [23, 63]|
|   1|   3|null|    [23, 63,]|
|   1|   4|  32|[23, 63,, 32]|
|   2|   2|  56|         [56]|
+----+----+----+-------------+
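To confirm that the blanks in the displayed arrays are real nulls rather than dropped elements, one quick check is to look at the element directly (a sketch, assuming Spark 2.4+ for element_at and the newDf defined above):

import org.apache.spark.sql.functions.{col, element_at}

// For colA = 1, colB = 3 the window so far is [23, 63, null],
// so the third element of colD should be null.
newDf
  .where(col("colA") === 1 && col("colB") === 3)
  .select(element_at(col("colD"), 3).isNull.alias("thirdIsNull"))
  .show()
// expected: thirdIsNull = true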
marcin_koss

Since collect_list automatically removes all nulls, one approach would be to temporarily replace null with a designated number, say Int.MinValue, before applying the method, and use a UDF to restore those numbers back to null afterward:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._
import spark.implicits._  // needed for $ and .toDF (assumes a SparkSession named `spark`)

val df = Seq(
  (Some(1), Some(1), Some(23)),
  (Some(1), Some(2), Some(63)),
  (Some(1), Some(3), None),
  (Some(1), Some(4), Some(32)),
  (Some(2), Some(2), Some(56))
).toDF("colA", "colB", "colC")

def replaceWithNull(n: Int) = udf( (arr: Seq[Int]) =>
  arr.map( i => if (i != n) Some(i) else None )
)

df.withColumn( "colD", replaceWithNull(Int.MinValue)(
    collect_list(when($"colC".isNull, Int.MinValue).otherwise($"colC")).
      over(Window.partitionBy("colA").orderBy("colB"))
  )
).show
// +----+----+----+------------------+
// |colA|colB|colC|              colD|
// +----+----+----+------------------+
// |   1|   1|  23|              [23]|
// |   1|   2|  63|          [23, 63]|
// |   1|   3|null|    [23, 63, null]|
// |   1|   4|  32|[23, 63, null, 32]|
// |   2|   2|  56|              [56]|
// +----+----+----+------------------+
Leo C
  • You don't need a UDF for that, you can use `transform` with `when` to implement the same logic and it will probably be faster. Depending on the data size this may or may not matter to you. https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.transform.html#pyspark.sql.functions.transform Also, be careful not to have the replacement value in your input dataset as it will corrupt the results. For example, in my case I have floats and I know there are no NaN values in the input so I can use it as a replacement, but you have to check first. – adamt06 Jun 05 '23 at 06:46
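A minimal sketch of that suggestion, translated to the Scala API used in this answer (it assumes Spark 3.0+ for the Column-based transform function, the df defined above, and that Int.MinValue never occurs in colC, per the caveat in the comment):

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._

val sentinel = Int.MinValue  // placeholder for null; must not appear in the real data

val result = df.withColumn("colD",
  transform(
    // collect the values with the sentinel standing in for null...
    collect_list(when(col("colC").isNull, sentinel).otherwise(col("colC")))
      .over(Window.partitionBy("colA").orderBy("colB")),
    // ...then map the sentinel back to null, without a UDF
    x => when(x === sentinel, lit(null).cast("int")).otherwise(x)
  )
)

Because transform stays within Catalyst expressions, it avoids the serialization overhead of a UDF, which is why the comment expects it to be faster.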