  • Spark 2.2.0

I have the following code, converted from a SQL script. It has been running for two hours and it's still running, even slower than SQL Server. Is anything not done correctly?

The following is the plan:

  1. Push table2 to all executors.
  2. Partition table1 and distribute the partitions to the executors.
  3. Cross join each row in table2/t2 with each partition of table1.

So the calculation on the result of the cross-join can be run in a distributed, parallel way. (For example, suppose I have 16 executors. I wanted to keep a copy of t2 on all 16 executors, divide table1 into 16 partitions, one for each executor, and have each executor do the calculation on one partition of table1 and t2.)

import java.sql.Date
import org.apache.spark.sql.Dataset

case class Cols (Id: Int, F2: String, F3: BigDecimal, F4: Date, F5: String,
                 F6: String, F7: BigDecimal, F8: String, F9: String, F10: String )
case class Result (Id1: Int, ID2: Int, Point: Int)

def getDataFromDB(source: String) = {
    import sqlContext.sparkSession.implicits._

    sqlContext.read.format("jdbc").options(Map(
      "driver" -> "com.microsoft.sqlserver.jdbc.SQLServerDriver",
      "url" -> jdbcSqlConn,
      "dbtable" -> s"$source"
    )).load()
      .select("Id", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10")
      .as[Cols]
  }

val sc = new SparkContext(conf)
val table1: Dataset[Cols] = getDataFromDB("table1").repartition(32).cache()
println(table1.count()) // about 300K rows

val table2: Dataset[Cols] = getDataFromDB("table2") // ~20K rows
table2.take(1)
println(table2.count())
val t2 = sc.broadcast(table2)

import org.apache.spark.sql.{functions => func}
import sqlContext.sparkSession.implicits._ // encoders needed by the map below
val j = table1.joinWith(t2.value, func.lit(true))

val result = j.map(x => {
  val (l, r) = x
  Result(l.Id, r.Id,
    (if (l.F1 != null && r.F1 != null && l.F1 == r.F1) 3 else 0)
    + (if (l.F2 != null && r.F2 != null && l.F2 == r.F2) 2 else 0)
    + ..... // All kinds of similar expressions
    + (if (l.F8 != null && r.F8 != null && l.F8 == r.F8) 1 else 0)
  )
}).filter(x => x.Point >= 10)
println(s"Total count ${result.count()}") // This takes forever; the count will be about 100

How can I rewrite it in an idiomatic Spark way?

Ref: https://forums.databricks.com/questions/6747/how-do-i-get-a-cartesian-product-of-a-huge-dataset.html

Jacek Laskowski
ca9163d9
  • I don't get why you broadcast t2 and then join on t2.value. Is there some optimisation going on when you do that on a small dataset (I had not even thought of broadcasting datasets in their uncollected forms)? I don't see the point of this broadcast at all... Care to explain? – GPI Jul 19 '17 at 07:57
  • For example, if I have 16 executors, I wanted to keep a copy of t2 on all 16 executors, then divide table1 into 16 partitions, one for each executor, and then have each executor do the calculation on one partition of table1 and t2. – ca9163d9 Jul 19 '17 at 08:05
  • That does not work this way. If you want to copy t2 to all executors, you have to collect it (you'll get an Array) and then broadcast it, but then you won't be able to join (you join Datasets, not plain arrays). Your intended implementation should rather look like `val t2 = sc.broadcast(getDataFromDB("table2").collect)`, and then you'd replace your cross join with a `mapPartitions` on the `table1` dataset. – GPI Jul 19 '17 at 08:21
  • Will my current code distribute the workload to all the executors, or does it run in one executor using one CPU core? – ca9163d9 Jul 19 '17 at 13:58
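GPI's suggestion (collect table2, broadcast the array, then use `mapPartitions` on table1) can be sketched as follows. This is a minimal sketch, assuming Spark 2.2's `spark-sql` on the classpath; `Cols` here is a hypothetical, trimmed-down version of the question's case class with only three comparable fields, and the weights 3/2/1 and the helper names `score`, `matchPartition` and `run` are illustrative, not taken from the question. Collecting is reasonable only because table2 is small (~20K rows).

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Hypothetical, trimmed-down stand-in for the question's Cols (three comparable fields only).
case class Cols(Id: Int, F2: String, F5: String, F8: String)
case class Result(Id1: Int, ID2: Int, Point: Int)

object BroadcastMatch {
  // Null-safe weighted match score; the weights 3/2/1 are illustrative.
  def score(l: Cols, r: Cols): Int = {
    def w(a: String, b: String, weight: Int): Int =
      if (a != null && a == b) weight else 0
    w(l.F2, r.F2, 3) + w(l.F5, r.F5, 2) + w(l.F8, r.F8, 1)
  }

  // Pairs every row of one table1 partition with every row of the broadcast array,
  // keeping only pairs that reach minPoint (the iterator is consumed lazily).
  def matchPartition(rows: Iterator[Cols], small: Array[Cols], minPoint: Int): Iterator[Result] =
    rows.flatMap { l =>
      small.iterator
        .map(r => Result(l.Id, r.Id, score(l, r)))
        .filter(_.Point >= minPoint)
    }

  def run(spark: SparkSession, table1: Dataset[Cols], table2: Dataset[Cols],
          minPoint: Int): Dataset[Result] = {
    import spark.implicits._
    // Collect the small table to the driver, then ship it once to every executor.
    val t2 = spark.sparkContext.broadcast(table2.collect())
    // One task per table1 partition; each task scans the whole broadcast array.
    table1.mapPartitions(rows => matchPartition(rows, t2.value, minPoint))
  }
}
```

Because the filter runs inside `matchPartition`, the roughly 6 billion (300K × 20K) candidate pairs are scored as they stream through the iterator and are never materialized; only matches reaching `minPoint` leave the task.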

1 Answer

(Somehow I feel as if I have seen the code already)

The code is slow because you use just a single task to load the entire dataset from the database over JDBC, and despite cache it does not benefit from caching.

Start by checking the physical plan and the Executors tab in the web UI; you should find a single executor running a single task that does all the work.
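Besides the web UI, the partition count (one task per partition in the stage that reads the data) and the physical plan can be checked from code. A small sketch, assuming a Spark 2.2 `Dataset`; the object and method names are hypothetical. A JDBC read without any partitioning settings should report a single partition; call it on the result of `getDataFromDB` before `repartition(32)`, since `repartition` only spreads the data after the single-task read.

```scala
import org.apache.spark.sql.Dataset

object ParallelismCheck {
  // Each partition is processed by one task in the stage that reads this Dataset.
  def numPartitions(ds: Dataset[_]): Int = ds.rdd.getNumPartitions

  // Prints the partition count, then the physical plan (look for the JDBC relation scan).
  def report(name: String, ds: Dataset[_]): Unit = {
    println(s"$name: ${numPartitions(ds)} partition(s)")
    ds.explain()
  }
}
```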

You should use one of the following to fine-tune the number of tasks for loading:

  1. Use the partitionColumn, lowerBound and upperBound options (together with numPartitions) of the JDBC data source
  2. Use the predicates parameter of DataFrameReader.jdbc

See JDBC To Other Databases in Spark's official documentation.
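Both loading options can be sketched like this. It is a minimal sketch, assuming the SQL Server driver from the question and a numeric `Id` column; the bounds, the partition count and the helper `idRangePredicates` are assumptions for illustration. As the Spark documentation notes, lowerBound and upperBound only decide the partition stride; they do not filter rows.

```scala
import java.util.Properties
import org.apache.spark.sql.{DataFrame, SparkSession}

object PartitionedJdbcRead {
  // Option 1: Spark splits [lower, upper] on a numeric column into `parts` strides.
  // Rows outside the range are still read (they land in the first/last partition).
  def byColumn(spark: SparkSession, url: String, table: String,
               lower: Long, upper: Long, parts: Int): DataFrame =
    spark.read.format("jdbc")
      .option("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
      .option("url", url)
      .option("dbtable", table)
      .option("partitionColumn", "Id") // assumed numeric, as in the question's Cols
      .option("lowerBound", lower.toString)
      .option("upperBound", upper.toString)
      .option("numPartitions", parts.toString)
      .load()

  // Option 2: one partition per WHERE predicate. Predicates must not overlap and must
  // together cover every row (e.g. rows with a NULL Id would be missed here).
  def idRangePredicates(lo: Long, hi: Long, n: Int, col: String = "Id"): Array[String] = {
    val step = math.max(1L, (hi - lo + n) / n) // ceil((hi - lo + 1) / n)
    (0 until n).map(i => lo + i * step)
      .takeWhile(_ <= hi)
      .map(start => s"$col >= $start AND $col <= ${math.min(start + step - 1, hi)}")
      .toArray
  }

  def byPredicates(spark: SparkSession, url: String, table: String,
                   predicates: Array[String]): DataFrame = {
    val props = new Properties()
    props.setProperty("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
    spark.read.jdbc(url, table, predicates, props)
  }
}
```

For example, `idRangePredicates(1, 300000, 16)` would yield 16 non-overlapping `Id` ranges, so the load runs as 16 tasks instead of one.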

Once you're fine with the loading, work on improving the last count action and add... another count action right after the following line:

val table1: Dataset[Cols] = getDataFromDB("table1").repartition(32).cache()
// trigger caching as it's lazy in Dataset API
table1.count

The reason the entire query is slow is that table1 is only marked to be cached, and the caching happens when an action gets executed, which is exactly at the end (!). In other words, cache does nothing useful and, more importantly, makes the query performance even worse.

Performance should also improve after you do table2.cache.count.

If you want to do a cross join, use the crossJoin operator.

crossJoin(right: Dataset[_]): DataFrame
Explicit cartesian join with another DataFrame.

Please note the note from the scaladoc of crossJoin (no pun intended).

Cartesian joins are very expensive without an extra filter that can be pushed down.
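The question's typed `map`/`filter` could be rewritten with crossJoin and Column expressions along these lines. A sketch under assumptions: `Cols` is a hypothetical, trimmed-down version of the question's case class and the weights are illustrative. Since `===` returns NULL when either side is NULL and `when` treats NULL as false, the explicit null checks are not needed; keeping the scoring in Column expressions also lets Catalyst optimize it and avoids deserializing every pair into JVM objects.

```scala
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, when}

// Hypothetical, trimmed-down stand-in for the question's Cols.
case class Cols(Id: Int, F2: String, F5: String, F8: String)

object CrossJoinScore {
  // (column, weight) pairs; illustrative, not the question's real weights.
  val weights: Seq[(String, Int)] = Seq("F2" -> 3, "F5" -> 2, "F8" -> 1)

  def scored(table1: Dataset[Cols], table2: Dataset[Cols], minPoint: Int): DataFrame = {
    val l = table1.alias("l")
    val r = table2.alias("r")
    // NULL === anything yields NULL, so when(...) falls through to otherwise(0).
    val point = weights
      .map { case (c, w) => when(col(s"l.$c") === col(s"r.$c"), w).otherwise(0) }
      .reduce(_ + _)
    l.crossJoin(r)
      .select(col("l.Id").as("Id1"), col("r.Id").as("ID2"), point.as("Point"))
      .where(col("Point") >= minPoint)
  }
}
```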

The following requirement is already handled by Spark given all the optimizations available.

So the calculation on the result of the cross-join can be run in a distributed, parallel way.

That's Spark's job (again, no pun intended).

The following requirement begs for broadcast.

For example, suppose I have 16 executors. I wanted to keep a copy of t2 on all 16 executors, divide table1 into 16 partitions, one for each executor, and have each executor do the calculation on one partition of table1 and t2.

Use the broadcast function to hint Spark SQL's engine to use table2 in broadcast mode.

broadcast[T](df: Dataset[T]): Dataset[T]
Marks a DataFrame as small enough for use in broadcast joins.
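Combining crossJoin with the broadcast hint could look like this. A minimal sketch, assuming Spark 2.2; `crossWithBroadcast` is a hypothetical helper name. With the hint, Spark should plan a broadcast nested loop join: the small side is shipped to every executor and each partition of the large side is joined locally, which matches the plan described in the question.

```scala
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.broadcast

object BroadcastHint {
  // Hint that `small` (e.g. table2, ~20K rows) fits in memory on every executor,
  // so each partition of `large` is paired with it locally instead of shuffling `large`.
  def crossWithBroadcast[A, B](large: Dataset[A], small: Dataset[B]): DataFrame =
    large.crossJoin(broadcast(small))
}
```

Calling `explain()` on the result should show a broadcast join in the physical plan; if you see `CartesianProduct` instead, the small side is not being broadcast.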

Jacek Laskowski
  • I added `table2.take(1) println(table2.count())` and it's still very slow. – ca9163d9 Jul 19 '17 at 16:54
  • @dc7a9163d9 Where's `table2.cache`? Remove `val t2 = sc.broadcast(table2)` and stop using `SparkContext.broadcast`. You can use the `broadcast` function instead to hint Spark SQL to use a broadcast join. – Jacek Laskowski Jul 19 '17 at 17:29
  • I removed `val t2 = sc.broadcast(table2)` and called cache() for table2: `val table2: Dataset[Cols] = getDataFromDB("table2").cache()`. It reduced the run time to 1.5 hours. Is there any way to make it faster? Should I try `mapPartitions` with `broadcast` as suggested in the comments on the main question? – ca9163d9 Jul 19 '17 at 22:17
  • I'm using Spark 2.2. What would go in the `.where()`? I'm doing a cross join and don't have any join condition. The `.filter()` cannot be called before the `.map()` because it uses the result of the mapping. However, since evaluation is lazy, does the order even matter? – ca9163d9 Jul 20 '17 at 14:14
  • Hopefully the additions, i.e. `crossJoin` and `broadcast`, help you make the query faster. – Jacek Laskowski Jul 20 '17 at 16:34