
I have the following table

DEST_COUNTRY_NAME   ORIGIN_COUNTRY_NAME count
United States       Romania             15
United States       Croatia             1
United States       Ireland             344
Egypt               United States       15  

The table is represented as a Dataset.

scala> dataDS
res187: org.apache.spark.sql.Dataset[FlightData] = [DEST_COUNTRY_NAME: string, ORIGIN_COUNTRY_NAME: string ... 1 more field]

The schema of dataDS is

scala> dataDS.printSchema;
root
 |-- DEST_COUNTRY_NAME: string (nullable = true)
 |-- ORIGIN_COUNTRY_NAME: string (nullable = true)
 |-- count: integer (nullable = true)

I want to sum all the values of the count column. I suppose I can do it using the reduce method of Dataset.

I thought I could do the following, but got an error:

scala> (dataDS.select(col("count"))).reduce((acc,n)=>acc+n);
<console>:38: error: type mismatch;
 found   : org.apache.spark.sql.Row
 required: String
       (dataDS.select(col("count"))).reduce((acc,n)=>acc+n);
                                                         ^

To make the code work, I had to explicitly specify that count is an Int, even though the schema already says it is an integer:

scala> (dataDS.select(col("count").as[Int])).reduce((acc,n)=>acc+n);

Why did I have to explicitly specify the type of count? Why didn't Scala's type inference work? In fact, the schema of the intermediate Dataset also shows count as an integer:

dataDS.select(col("count")).printSchema;
root
 |-- count: integer (nullable = true)
Manu Chadha
  • Maybe the problem is that it's nullable. So, int[nullable] is not the same as int and that makes sense as it's not clear how to sum int with null. – Mikhail Berlinkov Feb 09 '19 at 16:11

2 Answers


I think you need to do it another way. Assuming FlightData is a case class with the above schema, the solution is to use map and reduce, as below:

val totalSum = dataDS.map(_.count).reduce(_ + _) // map to the typed Int field; select(col("count")) would return an untyped DataFrame

Update: the inference issue is not specific to your data. When you call select with a plain Column (the same happens after a join), you get back a DataFrame, i.e. Dataset[Row], which has no statically typed schema, so you lose the type information of your case class. Because the result of select is a DataFrame rather than a Dataset[Int], the compiler cannot infer that the elements are Ints.

val counts: DataFrame = dataDS.select('count)    // untyped: Dataset[Row]
val countsDS: Dataset[Int] = dataDS.map(_.count) // typed: Dataset[Int]
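To see why the compiler insists on Row, here is a minimal sketch. It assumes a local Spark 2.x/3.x session and a top-level FlightData case class filled with the question's sample rows; the object and method names are hypothetical. The DataFrame returned by select(col("count")) can still be reduced, but only with a function of type (Row, Row) => Row, so you have to unpack and rebuild the Row yourself:

```scala
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.col

// Assumed to mirror the question's case class; defined at top level so Spark can derive its encoder.
case class FlightData(DEST_COUNTRY_NAME: String, ORIGIN_COUNTRY_NAME: String, count: Int)

object RowReduceSketch {
  def run(): Int = {
    val spark = SparkSession.builder().master("local[*]").appName("RowReduceSketch").getOrCreate()
    import spark.implicits._
    try {
      val dataDS: Dataset[FlightData] = Seq(
        FlightData("United States", "Romania", 15),
        FlightData("United States", "Croatia", 1),
        FlightData("United States", "Ireland", 344),
        FlightData("Egypt", "United States", 15)
      ).toDS()

      // select with an untyped Column yields a DataFrame, i.e. Dataset[Row].
      val df: DataFrame = dataDS.select(col("count"))

      // reduce on Dataset[Row] needs (Row, Row) => Row: read the Int out, add, wrap it again.
      val summed: Row = df.reduce((acc: Row, n: Row) => Row(acc.getInt(0) + n.getInt(0)))
      summed.getInt(0)
    } finally {
      spark.stop()
    }
  }

  def main(args: Array[String]): Unit = println(run())
}
```

This works, but it is clumsy compared with getting a Dataset[Int] in the first place.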

Also, as explained in this answer, to get a TypedColumn from a Column you simply use myCol.as[T].
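A minimal sketch of that conversion, assuming a local Spark session and a top-level FlightData case class holding the question's sample rows (the object name is hypothetical). The type ascriptions are only there to make the static types visible; Column.as[U] needs an implicit Encoder[U], which spark.implicits._ supplies for Int:

```scala
import org.apache.spark.sql.{Column, Dataset, SparkSession, TypedColumn}
import org.apache.spark.sql.functions.col

// Assumed shape of the question's data; top-level so Spark can derive an encoder for it.
case class FlightData(DEST_COUNTRY_NAME: String, ORIGIN_COUNTRY_NAME: String, count: Int)

object TypedColumnSketch {
  def run(): Int = {
    val spark = SparkSession.builder().master("local[*]").appName("TypedColumnSketch").getOrCreate()
    import spark.implicits._
    try {
      val dataDS: Dataset[FlightData] = Seq(
        FlightData("United States", "Romania", 15),
        FlightData("United States", "Croatia", 1),
        FlightData("United States", "Ireland", 344),
        FlightData("Egypt", "United States", 15)
      ).toDS()

      val untyped: Column = col("count")                  // carries no element type
      val typed: TypedColumn[Any, Int] = untyped.as[Int]  // element type is now Int

      // The select overload taking a TypedColumn returns Dataset[Int], not a DataFrame.
      val counts: Dataset[Int] = dataDS.select(typed)
      counts.reduce(_ + _) // the lambda is inferred as (Int, Int) => Int
    } finally {
      spark.stop()
    }
  }

  def main(args: Array[String]): Unit = println(run())
}
```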

I put together a simple example that reproduces your code and data:

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}

object EntryMainPoint extends App {

  //val warehouseLocation = "file:${system:user.dir}/spark-warehouse"
  val spark = SparkSession
    .builder()
    .master("local[*]")
    .appName("SparkSessionZipsExample")
    //.config("spark.sql.warehouse.dir", warehouseLocation)
    .getOrCreate()

  val someData = Seq(
    Row("United States", "Romania", 15),
    Row("United States", "Croatia", 1),
    Row("United States", "Ireland", 344),
    Row("Egypt", "United States", 15)
  )


  val flightDataSchema = List(
    StructField("DEST_COUNTRY_NAME", StringType, true),
    StructField("ORIGIN_COUNTRY_NAME", StringType, true),
    StructField("count", IntegerType, true)
  )

  case class FlightData(DEST_COUNTRY_NAME: String, ORIGIN_COUNTRY_NAME: String, count: Int)
  import spark.implicits._

  val dataDS = spark.createDataFrame(
    spark.sparkContext.parallelize(someData),
    StructType(flightDataSchema)
  ).as[FlightData]

  val totalSum = dataDS.map(_.count).reduce(_ + _) // map to the typed Int field; select(col("count")) would return an untyped DataFrame
  println("totalSum = " + totalSum)


  dataDS.printSchema()
  dataDS.show()


}

Output:

totalSum = 375

root
 |-- DEST_COUNTRY_NAME: string (nullable = true)
 |-- ORIGIN_COUNTRY_NAME: string (nullable = true)
 |-- count: integer (nullable = true)

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|    United States|            Romania|   15|
|    United States|            Croatia|    1|
|    United States|            Ireland|  344|
|            Egypt|      United States|   15|
+-----------------+-------------------+-----+

Note: you can also select from the Dataset like this:

val countColumn = dataDS.select('count) // returns a DataFrame; use map(_.count) for a Dataset[Int]

You can also have a look at this question: reduceByKey in Spark Dataset.
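If the goal is only the total, another option (not used in the answer above) is Spark's built-in sum aggregate, which sidesteps reduce and the Row typing entirely. A minimal sketch, assuming a local Spark session and a top-level FlightData case class with the question's sample rows (the object name is hypothetical). Note that sum over an integer column produces a bigint, so the result is read as a Long, and that sum skips null values:

```scala
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.functions.sum

// Assumed shape of the question's data; top-level so Spark can derive an encoder for it.
case class FlightData(DEST_COUNTRY_NAME: String, ORIGIN_COUNTRY_NAME: String, count: Int)

object SumAggregateSketch {
  def run(): Long = {
    val spark = SparkSession.builder().master("local[*]").appName("SumAggregateSketch").getOrCreate()
    import spark.implicits._
    try {
      val dataDS: Dataset[FlightData] = Seq(
        FlightData("United States", "Romania", 15),
        FlightData("United States", "Croatia", 1),
        FlightData("United States", "Ireland", 344),
        FlightData("Egypt", "United States", 15)
      ).toDS()

      // agg returns a one-row DataFrame; sum over IntegerType yields LongType, hence getLong.
      dataDS.agg(sum("count")).first().getLong(0)
    } finally {
      spark.stop()
    }
  }

  def main(args: Array[String]): Unit = println(run())
}
```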

Moustafa Mahmoud
  • You can also have a look at the implementation: https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala – Moustafa Mahmoud Feb 09 '19 at 19:08

Just follow the types or look at the compiler messages.

  • You start with Dataset[FlightData].

  • You call its select with col("count") as an argument. col(_) returns a Column.

  • The only variant of Dataset.select which takes Column returns DataFrame which is an alias for Dataset[Row].

  • There are two variants of Dataset.reduce: one takes a ReduceFunction[T] and the other a (T, T) => T, where T is the type parameter of the Dataset, i.e. Dataset[T]. (acc, n) => acc + n is a Scala anonymous function, so the second version applies.

  • Expanded:

    (dataDS.select(col("count")): Dataset[Row]).reduce((acc: Row, n: Row) => acc + n): Row
    

    which sets the constraints: the function takes a Row and a Row and returns a Row.

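  • Row has no + method, so the only way the compiler can make sense of

    (acc: Row, n: Row) => acc + n
    

    is the implicit any2stringadd conversion from Predef, which lets you + a String onto a value of any type. That + accepts only a String argument, however, and n is a Row.

    So the function cannot satisfy the complete expression, hence the error "found: org.apache.spark.sql.Row, required: String".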

  • You've already figured out that you can use

    dataDS.select(col("count").as[Int]).reduce((acc, n) => acc + n)
    

    where col("count").as[Int] is a TypedColumn[Any, Int] and the corresponding select overload returns Dataset[Int].

    Similarly, you could write

    dataDS.select(col("count")).as[Int].reduce((acc, n) => acc + n)
    

    and

    dataDS.toDF.map(_.getAs[Int]("count")).reduce((acc, n) => acc + n)
    

    In all cases

    .reduce((acc, n) => acc + n)
    

    takes a function of type (Int, Int) => Int.
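The String in the error message comes from Scala 2's Predef.any2stringadd implicit conversion, which gives a value of any type a + method that accepts only a String. A tiny plain-Scala sketch (no Spark; FakeRow is a hypothetical stand-in for Row, which likewise has no + method of its own) shows the conversion at work:

```scala
// FakeRow is hypothetical: like Spark's Row, it defines no + method of its own.
case class FakeRow(value: Int)

object StringAddSketch {
  // Compiles only via Predef.any2stringadd: acc is turned into its String form,
  // and the right-hand operand must itself be a String.
  def concat(acc: FakeRow, suffix: String): String = acc + suffix

  // Uncommenting this should fail with a type mismatch similar to the question's
  // ("found: FakeRow, required: String"), because n is not a String:
  // def add(acc: FakeRow, n: FakeRow) = acc + n

  def main(args: Array[String]): Unit = println(concat(FakeRow(1), "!"))
}
```

Running main prints FakeRow(1)!, i.e. the left operand's toString followed by the String argument.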

Community