
In trying to make this solution to "Perform a typed join in Scala with Spark Datasets" available as an implicit class, I've run across something I don't understand.

In the test below, the signature of innerJoin is def innerJoin[U, K](ds2: Dataset[U])(f: T => K, g: U => K)(implicit e1: Encoder[(K, T)], e2: Encoder[(K, U)], e3: Encoder[(T, U)]), but I call it with f: Foo => String and g: Bar => Int. Since f and g are supposed to share the key type K, I would expect a compile-time error, but it compiles just fine. Why is that?

What actually happens is that it compiles, and the test then fails at runtime with java.lang.ClassNotFoundException: scala.Any when Spark tries to create a product encoder (for the resulting ((K, Foo), (K, Bar)) tuples, I think). I assume Any appears as the common "parent" of Int and String.
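The widening can be reproduced without Spark. This is a minimal sketch, assuming plain Scala collections stand in for Datasets; WideningDemo and innerJoinSeq are hypothetical names, and the nested-loop join only mimics the shape of innerJoin, not Spark's behavior. It keeps the single shared K and still compiles with String and Int keys:

```scala
// Hypothetical Spark-free stand-in for the question's innerJoin, using Seq instead of Dataset.
object WideningDemo {
  case class Foo(a: String)
  case class Bar(b: Int)

  // One key type K shared by f and g, as in the question's signature.
  def innerJoinSeq[T, U, K](xs: Seq[T], ys: Seq[U])(f: T => K, g: U => K): Seq[(T, U)] =
    for {
      x <- xs
      y <- ys
      if f(x) == g(y) // keys compared with ==, which is available on any type
    } yield (x, y)

  // Compiles: K is widened to a common supertype of String and Int (Any in Scala 2).
  val joined: Seq[(Foo, Bar)] = innerJoinSeq(Seq(Foo("a")), Seq(Bar(123)))(_.a, _.b)
}
```

Because a String key never equals an Int key, joined is simply empty instead of being rejected by the compiler.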

import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
import org.scalatest.Matchers
import org.scalatest.testng.TestNGSuite
import org.testng.annotations.Test

case class Foo(a: String)

case class Bar(b: Int)

class JoinTest extends TestNGSuite with Matchers {
  import JoinTest._

  @Test
  def testJoin(): Unit = {
    val spark = SparkSession.builder()
      .master("local")
      .appName("test").getOrCreate()

    import spark.implicits._

    val ds1 = spark.createDataset(Seq(Foo("a")))
    val ds2 = spark.createDataset(Seq(Bar(123)))

    val jd = ds1.innerJoin(ds2)(_.a, _.b)

    jd.count shouldBe 0
  }
}

object JoinTest {
  implicit class Joins[T](ds1: Dataset[T]) {
    def innerJoin[U, K](ds2: Dataset[U])(f: T => K, g: U => K)
      (implicit e1: Encoder[(K, T)], e2: Encoder[(K, U)], e3: Encoder[(T, U)]): Dataset[(T, U)] = {
      val ds1_ = ds1.map(x => (f(x), x))
      val ds2_ = ds2.map(x => (g(x), x))
      ds1_.joinWith(ds2_, ds1_("_1") === ds2_("_1")).map(x => (x._1._2, x._2._2))
    }
  }
}
hoyland

1 Answer

You're correct that Any is being inferred as the common parent (the least upper bound) of String and Int, and so is used as K. Function types are covariant in their output type, so a Foo => String is a valid subtype of Foo => Any.
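The variance point can be checked on its own with a small sketch (CovarianceDemo is an illustrative name, not from the question):

```scala
object CovarianceDemo {
  case class Foo(a: String)

  // Function1 is declared as Function1[-T1, +R]: covariant in its result type R.
  val f: Foo => String = _.a

  // So a Foo => String is accepted where a Foo => Any is expected, with no cast.
  val widened: Foo => Any = f
}
```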

The common way to fix this kind of thing is to use two type parameters and require an implicit K1 =:= K2, which the compiler can only supply when the two key types are the same. For instance:

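def innerJoin[U, K1, K2](ds2: Dataset[U])(f: T => K1, g: U => K2)
  (implicit eq: K1 =:= K2, e1: Encoder[(K2, T)], e2: Encoder[(K2, U)], e3: Encoder[(T, U)]): Dataset[(T, U)] = {
    // eq converts a K1 key into a K2, so both sides end up keyed by K2
    val ds1_ = ds1.map(x => (eq(f(x)), x))
    // the rest is the same as before
    val ds2_ = ds2.map(x => (g(x), x))
    ds1_.joinWith(ds2_, ds1_("_1") === ds2_("_1")).map(x => (x._1._2, x._2._2))
  }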
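Here is the same =:= technique applied to a hypothetical Spark-free stand-in that joins two Seqs with a nested loop. EvidenceDemo and innerJoinSeq are illustrative names, and the comment about the rejected call assumes Scala 2, which is what Spark uses:

```scala
// Hypothetical Spark-free stand-in showing the two-type-parameter fix with =:= evidence.
object EvidenceDemo {
  case class Foo(a: String)
  case class Bar(b: Int)
  case class Baz(c: String)

  // Independent key types; evidence K1 =:= K2 only exists when they are the same type.
  def innerJoinSeq[T, U, K1, K2](xs: Seq[T], ys: Seq[U])(f: T => K1, g: U => K2)
      (implicit ev: K1 =:= K2): Seq[(T, U)] =
    for {
      x <- xs
      y <- ys
      if ev(f(x)) == g(y) // ev converts a K1 into a K2
    } yield (x, y)

  // Compiles: both keys are String.
  val joined: Seq[(Foo, Baz)] = innerJoinSeq(Seq(Foo("x"), Foo("y")), Seq(Baz("x")))(_.a, _.c)

  // In Scala 2 this does not compile: no implicit String =:= Int is available.
  // innerJoinSeq(Seq(Foo("a")), Seq(Bar(123)))(_.a, _.b)
}
```

Uncommenting the last call should produce an error along the lines of "Cannot prove that String =:= Int.", which is the compile-time failure the question was expecting.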
Joe K