
I would like to write a function that handles data skew when joining two Spark Datasets.

The solution for DataFrames is straightforward:

def saltedJoin(left: DataFrame, right: DataFrame, e: Column, kind: String = "inner", replicas: Int): DataFrame = {
    val saltedLeft = left.
      withColumn("__temporarily__", typedLit((0 until replicas).toArray)).
      withColumn("__skew_left__", explode($"__temporarily__")).
      drop($"__temporarily__").
      repartition($"__skew_left__")

    val saltedRight = right.
      withColumn("__temporarily__", rand).
      withColumn("__skew_right__", ($"__temporarily__" * replicas).cast("bigint")).
      drop("__temporarily__").
      repartition($"__skew_right__")

    saltedLeft.
      join(saltedRight, $"__skew_left__" === $"__skew_right__" && e, kind).
      drop($"__skew_left__").
      drop($"__skew_right__")
}

And you use the function like this:

val joined = saltedJoin(df alias "l", df alias "r", $"l.x" === $"r.x", replicas = 5)
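The mechanics behind saltedJoin can be sketched with plain Scala collections (no Spark; leftRows, rightRows, and the replicas value below are made-up illustration data): every left row is replicated once per salt value, every right row gets one random salt, and the join condition is salt equality plus the original key:

```scala
import scala.util.Random

// Hypothetical illustration data, not taken from the question.
val replicas = 5
val leftRows  = Seq(("a", 1), ("b", 2))            // (key, payload)
val rightRows = Seq(("a", 10), ("a", 20), ("c", 30))

// Replicate every left row once per salt value (mirrors the explode step).
val saltedLeft = for {
  (k, v) <- leftRows
  salt   <- 0 until replicas
} yield (k, salt, v)

// Tag every right row with one random salt (mirrors the rand step).
val rng = new Random()
val saltedRight = rightRows.map { case (k, v) => (k, rng.nextInt(replicas), v) }

// Inner join on key AND salt; the salt only spreads rows across
// partitions, so the result equals a plain join on the key.
val joined = for {
  (lk, ls, lv) <- saltedLeft
  (rk, rs, rv) <- saltedRight
  if lk == rk && ls == rs
} yield (lk, lv, rv)
```

Because the right side's salt always falls in 0 until replicas, the salted join returns exactly the rows a plain key join would; in Spark the extra salt column additionally spreads a hot key across replicas partitions.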

However, I don't know how to write the join function for Dataset instances. So far, I have written the following:

def saltedJoinWith[A: Encoder : TypeTag, B: Encoder : TypeTag](left: Dataset[A],
                                             right: Dataset[B],
                                             e: Column,
                                             kind: String = "inner",
                                             replicas: Int): Dataset[(A, B)] = {
    val spark = left.sparkSession
    val random = new Random()
    import spark.implicits._

    val saltedLeft: Dataset[(A, Int)] = left flatMap (a => 0 until replicas map ((a, _)))
    val saltedRight: Dataset[(B, Int)] = right map ((_, random.nextInt(replicas)))

    saltedLeft.joinWith(saltedRight, saltedLeft("_2") === saltedRight("_2") && e, kind).map(x => (x._1._1, x._2._1))
}

This is obviously not the correct solution, as the join condition e does not point to the columns defined in saltedLeft and saltedRight. It points to the columns in saltedLeft._1 and saltedRight._1. So, for instance, val j = saltedJoinWith(ds alias "l", ds alias "r", $"l.x" === $"r.x", replicas = 5) will fail at runtime with the following exception:

org.apache.spark.sql.AnalysisException: cannot resolve '`l.x`' given input columns: [_1, _2];;

I am using Apache Spark 2.2.

Ashkan
  • What if you convert a Dataset to a DataFrame inside the function and carry on with the usual steps? – skdhfgeq2134 Mar 12 '19 at 11:16
  • I thought about that. How do you convert the resulting Dataframe to a Dataset of tuples? – Ashkan Mar 12 '19 at 22:19
  • You can create a case class beforehand and apply that case class to the resulting DF like `df.as(class)` – skdhfgeq2134 Mar 13 '19 at 06:15
  • @Ashkan. How will this condition work: join(......, $"__skew_left__" === $"__skew_right__")? The left table has 1, 2, 3, 4 and 5, and __skew_right__ is some random number * (1,2,3,4,5,6). How will they match in the join condition? – Learn Hadoop Jan 14 '20 at 14:35
  • @LearnHadoop. From the left table, it will have 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 5, 5, 5, 5, 5, 5, and from the right table it will have 1, 2, 3, 4, 5. We are guaranteed that at least one row from the right table will match the left table. So it will work as expected. – Ashkan Feb 20 '20 at 20:44
  • @Ashkan can you help on this https://stackoverflow.com/questions/73866552/sparkcontext-was-shut-down-while-doing-a-join – Shasu Sep 27 '22 at 11:28
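The coverage argument from the comments above can be checked directly: the replicated side carries every salt in 0 until replicas, and (rand * replicas).cast("bigint") always lands in that same range, so every randomly salted right row has exactly one matching replica per left row. A small, self-contained Scala check (the sample size of 1000 is arbitrary):

```scala
import scala.util.Random

val replicas = 5
// Salts every left row is replicated with (the explode step).
val leftSalts = (0 until replicas).toSet

// Salts the right side can draw, mirroring (rand * replicas).cast("bigint"):
// nextDouble() is in [0, 1), so the product is in [0, replicas).
val rng = new Random()
val rightSalts = Seq.fill(1000)((rng.nextDouble() * replicas).toLong)

// Every right salt falls inside the left replica range, so each right row
// matches exactly one replica of any left row with the same key.
val allCovered = rightSalts.forall(s => leftSalts.contains(s.toInt))
```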

0 Answers