
I have some code written in PySpark that I'm converting to Scala. It's been going well, except that I'm now struggling with user-defined functions in Scala.

Python

from pyspark.sql import SparkSession
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.sql import functions as F

spark = SparkSession.builder.master('local[*]').getOrCreate()

a = spark.sparkContext.parallelize([(1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,)]).toDF(["index"]).withColumn("a1", F.lit(1)).withColumn("a2", F.lit(2)).withColumn("a3", F.lit(3))

a = a.select("index", F.struct(*('a' + str(c) for c in range(1, 4))).alias('a'))

a.show()

def a_to_b(a):
    # square each struct field and collect the results into a dict keyed b1..b3
    b = {}
    for i in range(1, 4):
        b.update({'b' + str(i): a[i - 1] ** 2})
    return b

a_to_b_udf = F.udf(lambda x: a_to_b(x), StructType(list(StructField("b" + str(x), IntegerType()) for x in range(1, 4))))

b = a.select("index", "a", a_to_b_udf(a.a).alias("b"))

b.show()

This yields:

+-----+-------+
|index|      a|
+-----+-------+
|    1|[1,2,3]|
|    2|[1,2,3]|
|    3|[1,2,3]|
|    4|[1,2,3]|
|    5|[1,2,3]|
|    6|[1,2,3]|
|    7|[1,2,3]|
|    8|[1,2,3]|
|    9|[1,2,3]|
|   10|[1,2,3]|
+-----+-------+

and

+-----+-------+-------+
|index|      a|      b|
+-----+-------+-------+
|    1|[1,2,3]|[1,4,9]|
|    2|[1,2,3]|[1,4,9]|
|    3|[1,2,3]|[1,4,9]|
|    4|[1,2,3]|[1,4,9]|
|    5|[1,2,3]|[1,4,9]|
|    6|[1,2,3]|[1,4,9]|
|    7|[1,2,3]|[1,4,9]|
|    8|[1,2,3]|[1,4,9]|
|    9|[1,2,3]|[1,4,9]|
|   10|[1,2,3]|[1,4,9]|
+-----+-------+-------+

Scala

import org.apache.spark.sql._
import org.apache.spark.sql.functions._

// can be ignored if running in spark-shell
val spark: SparkSession = SparkSession.builder()
  .master("local[*]")
  .getOrCreate()

import spark.implicits._

var a = spark.sparkContext.parallelize(1 to 10).toDF("index").withColumn("a1", lit(1)).withColumn("a2", lit(2)).withColumn("a3", lit(3))

// convert a{x} to struct column
a = a.select($"index", struct((1 to 3).map {x => col("a" + x)}.toList:_*).alias("a"))

a.show()

// this is where I am struggling, I have tried supplying a schema, but still get errors
val f = udf((a: Column) => {
  Seq(Math.pow(a(0).asInstanceOf[Double], 2), Math.pow(a(1).asInstanceOf[Double], 2), Math.pow(a(2).asInstanceOf[Double], 2))
})

val b = a.select($"index", $"a", f($"a").alias("b"))

// throws the below error
b.show()

I can show() the first DataFrame, but I get a casting error when trying to show b.

The error is:

java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to org.apache.spark.sql.Column
  at $line23.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:31)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
  at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
  at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:370)
  at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:246)
  at org.apache.spark.sql.execution.SparkPlan$$anonfun$4.apply(SparkPlan.scala:240)
  at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:784)
  at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:784)
  at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
  at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:319)
  at org.apache.spark.rdd.RDD.iterator(RDD.scala:283)
  at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)
  at org.apache.spark.scheduler.Task.run(Task.scala:85)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)
  at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
  at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
  at java.lang.Thread.run(Thread.java:745)

I have tried setting a schema for my UDF as I've done in Python, but I still get the same error.
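For reference, the schema attempt looked roughly like the sketch below (a reconstruction, not my exact code): passing a return schema to udf only sets the output type, so the Column parameter still fails in the same way.

import org.apache.spark.sql.types._

// supplying a return schema, similar to the Python StructType above
// (hypothetical sketch; the input is still typed as Column, so this
// throws the same ClassCastException when the UDF runs)
val bSchema = StructType((1 to 3).map(i => StructField("b" + i, IntegerType)))

val fWithSchema = udf((a: Column) => {
  Seq(Math.pow(a(0).asInstanceOf[Double], 2),
    Math.pow(a(1).asInstanceOf[Double], 2),
    Math.pow(a(2).asInstanceOf[Double], 2))
}, bSchema)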

Does anyone know how I can get around this issue? My example is simple, but in my real UDF I need to do quite a lot of transformations before returning the struct.

1 Answer


I feel very silly, because I've been struggling since Friday afternoon.

From Spark Sql UDF with complex input parameter,

struct types are converted to org.apache.spark.sql.Row

My problem was with the Column type that I was supplying to my function.

val f = udf((a: Column) => {
  Seq(Math.pow(a(0).asInstanceOf[Double], 2), Math.pow(a(1).asInstanceOf[Double], 2), Math.pow(a(2).asInstanceOf[Double], 2))
})

I was supposed to use Row instead.

val f = udf((a: Row) => {
  // read the three struct fields as Ints, square them, and return a Seq
  Seq(math.pow(a.getInt(0), 2).toInt,
    math.pow(a.getInt(1), 2).toInt,
    math.pow(a.getInt(2), 2).toInt)
})
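
Applying the corrected UDF is the same select as before, and the squared values match the Python output. Note that returning a Seq[Int] gives an array<int> column; to get a named b1/b2/b3 struct like the Python version, the UDF can return a case class instead (a minimal sketch, not the code I ended up with):

val b = a.select($"index", $"a", f($"a").alias("b"))
b.show() // each "b" value should be [1, 4, 9], matching the Python output above

// sketch of a struct-returning variant: Spark encodes a case class
// result as a struct column with fields b1, b2, b3
case class B(b1: Int, b2: Int, b3: Int)

val fStruct = udf((a: Row) => B(
  math.pow(a.getInt(0), 2).toInt,
  math.pow(a.getInt(1), 2).toInt,
  math.pow(a.getInt(2), 2).toInt))

val b2 = a.select($"index", $"a", fStruct($"a").alias("b"))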