The title may not be very clear, so let me explain what I want to achieve with an example. I start with the following DataFrame:
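import spark.implicits._ // for toDF and $; already in scope in spark-shell
// Seq rather than a tuple, so that "array" is a real array column and not a struct
val df = Seq((1, "CS", 0, Seq(0.1, 0.2, 0.4, 0.5)),
             (4, "Ed", 0, Seq(0.4, 0.8, 0.3, 0.6)),
             (7, "CS", 0, Seq(0.2, 0.5, 0.4, 0.7)),
             (101, "CS", 1, Seq(0.5, 0.7, 0.3, 0.8)),
             (5, "CS", 1, Seq(0.4, 0.2, 0.6, 0.9)))
  .toDF("id", "dept", "test", "array")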
+---+----+----+--------------------+
| id|dept|test|               array|
+---+----+----+--------------------+
|  1|  CS|   0|[0.1, 0.2, 0.4, 0.5]|
|  4|  Ed|   0|[0.4, 0.8, 0.3, 0.6]|
|  7|  CS|   0|[0.2, 0.5, 0.4, 0.7]|
|101|  CS|   1|[0.5, 0.7, 0.3, 0.8]|
|  5|  CS|   1|[0.4, 0.2, 0.6, 0.9]|
+---+----+----+--------------------+
I want to remove some elements of the array column based on the information in the id, dept and test columns. Specifically, the 4 elements in each array correspond to the four ids in the CS dept, taken in ascending id order (1, 5, 7, 101). I want to remove the elements of each array that correspond to the ids whose test column is 1. In this example those are ids 5 and 101, so the 2nd and 4th elements are removed and the end result looks like this:
+---+----+----+----------+
| id|dept|test|     array|
+---+----+----+----------+
|  1|  CS|   0|[0.1, 0.4]|
|  4|  Ed|   0|[0.4, 0.3]|
|  7|  CS|   0|[0.2, 0.4]|
|101|  CS|   1|[0.5, 0.3]|
|  5|  CS|   1|[0.4, 0.6]|
+---+----+----+----------+
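To make the rule concrete, here is a minimal plain-Scala model of it on the sample data, with no Spark involved. It is only meant to pin down the intended semantics; `Record`, `RuleSketch` and the other names are made up for illustration, and positions are 0-based here.

```scala
object RuleSketch {
  // Hypothetical local stand-in for one DataFrame row.
  final case class Record(id: Int, dept: String, test: Int, array: Seq[Double])

  val rows: Seq[Record] = Seq(
    Record(1, "CS", 0, Seq(0.1, 0.2, 0.4, 0.5)),
    Record(4, "Ed", 0, Seq(0.4, 0.8, 0.3, 0.6)),
    Record(7, "CS", 0, Seq(0.2, 0.5, 0.4, 0.7)),
    Record(101, "CS", 1, Seq(0.5, 0.7, 0.3, 0.8)),
    Record(5, "CS", 1, Seq(0.4, 0.2, 0.6, 0.9))
  )

  // CS ids in ascending order: position k of every array belongs to csIds(k).
  val csIds: Seq[Int] = rows.filter(_.dept == "CS").map(_.id).sorted

  // Ids whose test flag is 1.
  val testedIds: Set[Int] = rows.filter(_.test == 1).map(_.id).toSet

  // 0-based positions to drop, i.e. the positions of the tested ids in csIds.
  val dropPositions: Set[Int] =
    csIds.zipWithIndex.collect { case (id, i) if testedIds(id) => i }.toSet

  // Keep only the elements whose position is not in dropPositions.
  val trimmed: Seq[Record] = rows.map { r =>
    r.copy(array = r.array.zipWithIndex.collect {
      case (v, i) if !dropPositions(i) => v
    })
  }

  def main(args: Array[String]): Unit = trimmed.foreach(println)
}
```

For the sample data, `dropPositions` is `Set(1, 3)` (the 2nd and 4th elements), and `trimmed` holds the same arrays as the expected result above.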
To avoid collecting all the results and doing the manipulation in Scala, I would like to keep the operation in the Spark DataFrame API if possible. My plan for tackling this problem has two steps:
- Figure out the indices of the array elements that need to be removed
- Apply the remove/drop operation
So far, I think I have figured out step 1 as follows:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val w = Window.partitionBy("dept").orderBy("id")
val studentIdIdx = df.select("id", "dept")
.withColumn("Index", row_number().over(w))
.where("dept = 'CS'").drop("dept")
studentIdIdx.show()
+---+-----+
| id|Index|
+---+-----+
|  1|    1|
|  5|    2|
|  7|    3|
|101|    4|
+---+-----+
val testIds = df.where("test = 1")
.select($"id".as("test_id"))
val testMask = studentIdIdx
.join(testIds, studentIdIdx("id") === testIds("test_id"))
.drop("id","test_id")
testMask.show()
+-----+
|Index|
+-----+
|    2|
|    4|
+-----+
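One possible way to use these indices for step 2 is sketched below, under a few assumptions: the index list is small enough to collect to the driver (only the mask is collected, not the data), the `array` column is a real `array<double>` built from `Seq` values rather than tuples, and a plain UDF is acceptable. The names `DropByIndexSketch`, `dropAt`, `dropSet` and `dropUdf` are made up for illustration, and `testMask` is rebuilt by hand so that the block runs on its own.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object DropByIndexSketch {
  // Pure helper: drop the elements at the given 1-based positions.
  def dropAt(arr: Seq[Double], oneBased: Set[Int]): Seq[Double] =
    arr.zipWithIndex.collect { case (v, i) if !oneBased(i + 1) => v }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("drop-by-index-sketch")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(
      (1, "CS", 0, Seq(0.1, 0.2, 0.4, 0.5)),
      (4, "Ed", 0, Seq(0.4, 0.8, 0.3, 0.6)),
      (7, "CS", 0, Seq(0.2, 0.5, 0.4, 0.7)),
      (101, "CS", 1, Seq(0.5, 0.7, 0.3, 0.8)),
      (5, "CS", 1, Seq(0.4, 0.2, 0.6, 0.9))
    ).toDF("id", "dept", "test", "array")

    // Hand-built stand-in for the testMask DataFrame computed in step 1.
    val testMask = Seq(2, 4).toDF("Index")

    // Assumption: the mask is tiny, so collecting it to the driver is cheap.
    val dropSet: Set[Int] = testMask.as[Int].collect().toSet

    // The UDF closes over dropSet and applies the same mask to every row.
    val dropUdf = udf((arr: Seq[Double]) => dropAt(arr, dropSet))

    df.withColumn("array", dropUdf(col("array"))).show()

    spark.stop()
  }
}
```

With the sample data this should reproduce the first expected result, since positions 2 and 4 are removed from every array. Keeping `dropAt` as a pure function separate from the UDF makes the index logic easy to check without a Spark session.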
So my two related questions are:
1. How do I apply the remove/drop operation to the array in each row using the Index? (I am also open to suggestions for a better way to compute the Index.)
2. The real final DataFrame that I want should have a few more elements removed on top of the above result. Specifically, for rows with test=0 and dept=CS, it should also remove the array element that corresponds to the Index of the row's own id. In this example, the 1st element in the row with id=1 and the 3rd element (by original index, before any removal) in the row with id=7 should be removed, and the real final result is:
+---+----+----+----------+
| id|dept|test|     array|
+---+----+----+----------+
|  1|  CS|   0|     [0.4]|
|  4|  Ed|   0|[0.4, 0.3]|
|  7|  CS|   0|     [0.2]|
|101|  CS|   1|[0.5, 0.3]|
|  5|  CS|   1|[0.4, 0.6]|
+---+----+----+----------+
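One way to think about doing both removals in a single pass is to build a per-row drop set: the shared positions, plus the row's own position when it is a CS row with test=0, all expressed as original 1-based indices. The plain-Scala sketch below models only that logic; `CombinedDropSketch`, `indexOfId`, `sharedDrop`, `positionsFor` and `trim` are hypothetical names, and the two lookup values are copied by hand from the step 1 output.

```scala
object CombinedDropSketch {
  // 1-based position of each CS id in ascending id order (the id/Index table from step 1).
  val indexOfId: Map[Int, Int] = Map(1 -> 1, 5 -> 2, 7 -> 3, 101 -> 4)

  // Positions removed from every row (the ids with test = 1).
  val sharedDrop: Set[Int] = Set(2, 4)

  // Drop the elements at the given 1-based positions.
  def dropAt(arr: Seq[Double], oneBased: Set[Int]): Seq[Double] =
    arr.zipWithIndex.collect { case (v, i) if !oneBased(i + 1) => v }

  // Shared positions, plus the row's own position for CS rows with test = 0.
  def positionsFor(id: Int, dept: String, test: Int): Set[Int] =
    if (dept == "CS" && test == 0) sharedDrop ++ indexOfId.get(id).toSet
    else sharedDrop

  // Both removals at once, always indexing into the original array.
  def trim(id: Int, dept: String, test: Int, arr: Seq[Double]): Seq[Double] =
    dropAt(arr, positionsFor(id, dept, test))

  def main(args: Array[String]): Unit = {
    println(trim(1, "CS", 0, Seq(0.1, 0.2, 0.4, 0.5)))
    println(trim(7, "CS", 0, Seq(0.2, 0.5, 0.4, 0.7)))
  }
}
```

Because all positions refer to the original array, the order of the two removals does not matter. In Spark, the same idea could presumably be expressed by joining each row to its own Index and passing it to a UDF alongside the array, but that wiring is not shown here.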
I mention the second point in case there is a more efficient way to perform both removals together. If not, I think I can figure out the second removal once I know how to use the Index information for the remove operation. Thanks!