
Say I have the following table/dataframe:

Id Col1 Col2 Col3
1 100 aaa xxx
2 200 aaa yyy
3 300 ccc zzz

I need to calculate an extra column CalculatedValue which could have one or multiple values based on other columns' values.

I have tried with a regular CASE WHEN statement like:

df_out = (df_source
    .withColumn('CalculatedValue',
        expr("""CASE WHEN Col1 = 100 THEN 'AAA111'
                     WHEN Col2 = 'aaa' THEN 'BBB222'
                     WHEN Col3 = 'zzz' THEN 'CCC333'
                END""")
    )
)

Note I'm doing it with expr() because the actual CASE WHEN statement is a very long string built dynamically.
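
For context, the expression string is assembled along these lines (a minimal sketch; the `rules` list of condition/value pairs is illustrative and stands in for whatever comes out of the SAP table):

```python
# Illustrative only: build a CASE WHEN expression string from
# (condition, value) pairs, as described above.
rules = [
    ("Col1 = 100", "AAA111"),
    ("Col2 = 'aaa'", "BBB222"),
    ("Col3 = 'zzz'", "CCC333"),
]
case_sql = (
    "CASE "
    + " ".join(f"WHEN {cond} THEN '{val}'" for cond, val in rules)
    + " END"
)
print(case_sql)
# CASE WHEN Col1 = 100 THEN 'AAA111' WHEN Col2 = 'aaa' THEN 'BBB222' WHEN Col3 = 'zzz' THEN 'CCC333' END
```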

This results in a table/dataframe like this:

Id Col1 Col2 Col3 CalculatedValue
1 100 aaa xxx AAA111
2 200 aaa yyy BBB222
3 300 ccc zzz CCC333

However, what I need looks more like this, where the CASE WHEN statement doesn't stop evaluating after the first match, but instead evaluates all conditions and accumulates every match into, say, an array:

Id Col1 Col2 Col3 CalculatedValue
1 100 aaa xxx [AAA111, BBB222]
2 200 aaa yyy BBB222
3 300 ccc zzz CCC333

Any ideas? Thanks

Martin

2 Answers

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, when}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._
val df = Seq(
  (1, 100, "aaa", "xxx"),
  (2, 200, "aaa", "yyy"),
  (3, 300, "ccc", "zzz")
).toDF("Id", "Col1", "Col2", "Col3")

val resDF = df
  .withColumn(
    "CalculatedValue",
    when(
      col("Col1") === 100 && col("Col2") === "aaa" && col("Col3") === "zzz",
      Array("AAA111", "BBB222", "CCC333")
    ).when(
      col("Col1") === 100 && col("Col2") === "aaa" && col("Col3") =!= "zzz",
      Array("AAA111", "BBB222")
    ).when(
      col("Col1") === 100 && col("Col2") =!= "aaa" && col("Col3") === "zzz",
      Array("AAA111", "CCC333")
    ).when(
      col("Col1") =!= 100 && col("Col2") === "aaa" && col("Col3") === "zzz",
      Array("BBB222", "CCC333")
    ).when(
      col("Col1") =!= 100 && col("Col2") =!= "aaa" && col("Col3") === "zzz",
      Array("CCC333")
    ).when(
      col("Col1") === 100 && col("Col2") =!= "aaa" && col("Col3") =!= "zzz",
      Array("AAA111")
    ).when(
      col("Col1") =!= 100 && col("Col2") === "aaa" && col("Col3") =!= "zzz",
      Array("BBB222")
    ).otherwise(Array("unknown"))
  )
resDF.show(false)
/*
+---+----+----+----+----------------+
|Id |Col1|Col2|Col3|CalculatedValue |
+---+----+----+----+----------------+
|1  |100 |aaa |xxx |[AAA111, BBB222]|
|2  |200 |aaa |yyy |[BBB222]        |
|3  |300 |ccc |zzz |[CCC333]        |
+---+----+----+----+----------------+
*/
mvasyliv
  • Hey, thanks for taking the time. That doesn't seem it will work for me unfortunately. The actual case-when/when-otherwise statement is very long (dynamically built from a SAP table) and making all the combinations would probably make it insanely complex – Martin Nov 30 '22 at 10:19

In the end I just adjusted the case when statement as:

df_out = (df_source
    .withColumn('CalculatedValue',
        expr("""CONCAT(
            CASE WHEN Col1 = 100 THEN 'AAA111,' ELSE '' END,
            CASE WHEN Col2 = 'aaa' THEN 'BBB222,' ELSE '' END,
            CASE WHEN Col3 = 'zzz' THEN 'CCC333,' ELSE '' END)""")
    )
)

This way I end up with a value like AAA111,BBB222, for the example I provided. Then I remove the trailing comma, split() on the comma, and explode() the resulting array into rows.
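
The trim-and-split step can be sketched on a plain Python string to show what each row ends up as (in Spark itself this would be a split() on the comma after stripping the trailing one, followed by explode()); `to_rows` is a hypothetical helper, not part of the original answer:

```python
def to_rows(calculated_value: str) -> list[str]:
    # Drop the trailing comma left by CONCAT, then split into one
    # entry per matched condition; an empty string yields no rows.
    trimmed = calculated_value.rstrip(',')
    return trimmed.split(',') if trimmed else []

print(to_rows('AAA111,BBB222,'))  # ['AAA111', 'BBB222']
```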

Martin