As @m_vemuri pointed out in the comments, there will be some performance impact from keeping track of what is filtered out.
After some experimentation, the best method I can think of is to first track which rows meet each of your filtering conditions, and then apply the filtering at the end (without having to re-run the filters).
This isn't as elegant as calling .filter([filter1]).filter([filter2]).filter([filter3])
as in your question, but it is the most concise approach I could come up with.
TL;DR
import org.apache.spark.sql.functions.{col, count, lit}

val df = (1 to 100).toDF("n")
def firstFilter = col("n") % 2 === 0
def secondFilter = col("n") % 4 === 0
def thirdFilter = col("n") % 8 === 0
val filteredDF = df
.withColumn("firstFilter", firstFilter)
.withColumn("secondFilter", secondFilter)
.withColumn("thirdFilter", thirdFilter)
import org.apache.spark.sql.DataFrame
def countFilter(df: DataFrame, c: String): DataFrame =
df.filter(col(c) === false)
.groupBy(col(c))
.agg(count(c).as(s"filterCount"))
.withColumn("filter", lit(c))
.select("filter", "filterCount")
def countFilters(df: DataFrame, cs: String*): DataFrame =
cs.tail
.foldLeft(countFilter(df, cs.head))(
(acc, c) => acc.union(countFilter(df, c))
)
// produces a DataFrame with a count, per filter, of the rows removed
countFilters(filteredDF, "firstFilter", "secondFilter", "thirdFilter")
.show()
filteredDF
.filter(col("firstFilter") &&
col("secondFilter") &&
col("thirdFilter"))
.select("n")
.show(5)
gives:
+------------+-----------+
| filter|filterCount|
+------------+-----------+
| firstFilter| 50|
|secondFilter| 75|
| thirdFilter| 88|
+------------+-----------+
+---+
| n|
+---+
| 8|
| 16|
| 24|
| 32|
| 40|
+---+
only showing top 5 rows
Full solution
Starting with some example data:
val df = (1 to 100).toDF("n")
df.show(5)
gives:
+---+
| n|
+---+
| 1|
| 2|
| 3|
| 4|
| 5|
+---+
only showing top 5 rows
Then with three filters we create three new columns, tracking the result of each filter for each row (alternatively you could keep just a single column, overwriting it each time to combine the results; a sketch of that variant follows the output below):
import org.apache.spark.sql.functions.{col, count, lit}

def firstFilter = col("n") % 2 === 0
def secondFilter = col("n") % 4 === 0
def thirdFilter = col("n") % 8 === 0
val filteredDF = df
.withColumn("firstFilter", firstFilter)
.withColumn("secondFilter", secondFilter)
.withColumn("thirdFilter", thirdFilter)
filteredDF.show(8)
gives:
+---+-----------+------------+-----------+
| n|firstFilter|secondFilter|thirdFilter|
+---+-----------+------------+-----------+
| 1| false| false| false|
| 2| true| false| false|
| 3| false| false| false|
| 4| true| true| false|
| 5| false| false| false|
| 6| true| false| false|
| 7| false| false| false|
| 8| true| true| true|
+---+-----------+------------+-----------+
only showing top 8 rows
At this point we still have all of the original rows, but we have already run each filter, storing the results in new columns of true/false values.
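As an aside, here is a minimal sketch of the single-column variant mentioned above (the column name keep is just an example); it keeps the DataFrame narrower, but you give up the per-filter breakdown that the counting step below relies on:

val combinedDF = df
.withColumn("keep", firstFilter)
.withColumn("keep", col("keep") && secondFilter)
.withColumn("keep", col("keep") && thirdFilter)

// keep only the rows that passed every filter
combinedDF.filter(col("keep")).select("n").show(5)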
Now we can count the number of rows that did not match each condition:
import org.apache.spark.sql.DataFrame
def countFilter(df: DataFrame, c: String): DataFrame =
df.filter(col(c) === false)
.groupBy(col(c))
.agg(count(c).as(s"filterCount"))
.withColumn("filter", lit(c))
.select("filter", "filterCount")
def countFilters(df: DataFrame, cs: String*): DataFrame =
cs.tail
.foldLeft(countFilter(df, cs.head))(
(acc, c) => acc.union(countFilter(df, c))
)
// produces a DataFrame with a count, per filter, of the rows removed
countFilters(filteredDF, "firstFilter", "secondFilter", "thirdFilter")
.show()
gives:
+------------+-----------+
| filter|filterCount|
+------------+-----------+
| firstFilter| 50|
|secondFilter| 75|
| thirdFilter| 88|
+------------+-----------+
and we can apply the filtering:
filteredDF
.filter(col("firstFilter") &&
col("secondFilter") &&
col("thirdFilter"))
.select("n")
.show(5)
gives:
+---+
| n|
+---+
| 8|
| 16|
| 24|
| 32|
| 40|
+---+
only showing top 5 rows
Discussion
Won't this impact performance pretty severely to add a count for each filter? Especially for a high volume pipeline
The counts being performed are on boolean columns, and only counting the false values at that. Even for a large Dataset these counts should scale and perform well.
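If even that is a concern, one possible refinement (a sketch, not something I have benchmarked here) is to compute all of the counts in a single pass rather than one aggregation per filter; the removedBy... column names are just examples:

import org.apache.spark.sql.functions.{col, count, not, when}

// count(when(cond, 1)) only counts the rows where cond is true,
// so each aggregate counts the rows the corresponding filter would remove
filteredDF.agg(
  count(when(not(col("firstFilter")), 1)).as("removedByFirstFilter"),
  count(when(not(col("secondFilter")), 1)).as("removedBySecondFilter"),
  count(when(not(col("thirdFilter")), 1)).as("removedByThirdFilter")
).show()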
There naturally will be some overhead when tracking the number of rows filtered, which is something Spark doesn't expose out of the box.
What if I create a custom set accumulator and add to it the row IDs in the transformation (like in the answer I linked), then get the set size in the driver?
The answer you linked is one way of solving the problem, yes; however, it involves dropping down to RDDs and custom functions. My solution uses the built-in Spark functions, which can be better optimized by Spark.
I would recommend running both against your dataset, if possible, and comparing the results.
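For reference, here is a rough sketch of the accumulator idea (the names are illustrative; it assumes a SparkSession in scope as spark and spark.implicits._ imported, as the spark-shell provides and as .toDF above already requires). Note that because the accumulator is updated inside a transformation, its value can over-count if Spark re-executes tasks:

val removedByFirst = spark.sparkContext.longAccumulator("removedByFirstFilter")

// typed filter on a Dataset[Int], updating the accumulator as a side effect
val afterFirst = df.as[Int].filter { n =>
  val keep = n % 2 == 0
  if (!keep) removedByFirst.add(1L)
  keep
}

afterFirst.count() // the accumulator is only populated once an action runs
println(removedByFirst.value)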
...won't this solve Spark's re-running of transformations problem?
I don't see a re-running problem here, as my solution doesn't repeat anything; instead it splits the filtering process into three steps:
- Calculate a true/false value as to whether each row satisfies each filter. (Can be fully parallelised)
- Aggregate the counts of false values per filter.
- Perform the final filtering without re-running the filter logic, using the true/false values already calculated. (Can be fully parallelised)
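One optional refinement, since the counts and the final filter are two separate actions: you could persist the intermediate DataFrame so that the boolean columns are only computed once (a sketch; whether this pays off depends on your data size and available memory):

import org.apache.spark.storage.StorageLevel

// compute the filter columns once and reuse them across both actions
val cachedDF = filteredDF.persist(StorageLevel.MEMORY_AND_DISK)

countFilters(cachedDF, "firstFilter", "secondFilter", "thirdFilter").show()

cachedDF
.filter(col("firstFilter") && col("secondFilter") && col("thirdFilter"))
.select("n")
.show(5)

cachedDF.unpersist()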