Tags: scala, apache-spark, fold, apache-spark-dataset

How does Spark interpret the type of a column in reduce?


I have the following table

DEST_COUNTRY_NAME   ORIGIN_COUNTRY_NAME count
United States       Romania             15
United States       Croatia             1
United States       Ireland             344
Egypt               United States       15  

The table is represented as a Dataset.

scala> dataDS
res187: org.apache.spark.sql.Dataset[FlightData] = [DEST_COUNTRY_NAME: string, ORIGIN_COUNTRY_NAME: string ... 1 more field]

The schema of dataDS is

scala> dataDS.printSchema;
root
 |-- DEST_COUNTRY_NAME: string (nullable = true)
 |-- ORIGIN_COUNTRY_NAME: string (nullable = true)
 |-- count: integer (nullable = true)

I want to sum all the values of the count column. I suppose I can do it using the reduce method of Dataset.

I thought I could do the following, but I got an error:

scala> (dataDS.select(col("count"))).reduce((acc,n)=>acc+n);
<console>:38: error: type mismatch;
 found   : org.apache.spark.sql.Row
 required: String
       (dataDS.select(col("count"))).reduce((acc,n)=>acc+n);
                                                         ^

To make the code work, I had to explicitly specify that count is an Int, even though the schema already says it is an Int:

scala> (dataDS.select(col("count").as[Int])).reduce((acc,n)=>acc+n);

Why did I have to explicitly specify the type of count? Why didn't Scala's type inference work? In fact, the schema of the intermediate Dataset also reports count as an integer.

scala> dataDS.select(col("count")).printSchema;
root
 |-- count: integer (nullable = true)
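
For reference, an equivalent Dataset can be built directly in spark-shell. The FlightData definition below is reconstructed from the schema above and the rows come from the sample table (note that built this way, count is non-nullable, unlike the schema shown above):

scala> case class FlightData(DEST_COUNTRY_NAME: String, ORIGIN_COUNTRY_NAME: String, count: Int)
defined class FlightData

scala> val dataDS = Seq(
     |   FlightData("United States", "Romania", 15),
     |   FlightData("United States", "Croatia", 1),
     |   FlightData("United States", "Ireland", 344),
     |   FlightData("Egypt", "United States", 15)
     | ).toDS
dataDS: org.apache.spark.sql.Dataset[FlightData] = [DEST_COUNTRY_NAME: string, ORIGIN_COUNTRY_NAME: string ... 1 more field]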

Solution

  • Just follow the types or look at the compiler messages.

    • You start with Dataset[FlightData].

    • You call its select with col("count") as an argument. col(_) returns a Column.

    • The only variant of Dataset.select which takes a Column returns a DataFrame, which is an alias for Dataset[Row].

    • There are two variants of Dataset.reduce: one taking a ReduceFunction[T] and the other (T, T) => T, where T is the type parameter of the Dataset, i.e. Dataset[T]. The (acc, n) => acc + n function is a Scala anonymous function, hence the second variant applies.

    • Expanded:

      (dataDS.select(col("count")): Dataset[Row]).reduce((acc: Row, n: Row) => acc + n): Row
      

      which sets the constraints: the function takes a Row and a Row and returns a Row.

    • Row has no + method, so the only way to satisfy

      (acc: ???, n: Row) => acc + n
      

      is to make acc a String (you can + anything to a String).

      However, this doesn't satisfy the complete expression, which must be (Row, Row) => Row, hence the error.

    • You've already figured out that you can use

      dataDS.select(col("count").as[Int]).reduce((acc, n) => acc + n)
      

      where col("count").as[Int] is a TypedColumn[Any, Int] and the corresponding select variant returns a Dataset[Int].

      Similarly, you could use

      dataDS.select(col("count")).as[Int].reduce((acc, n) => acc + n)
      

      and

      dataDS.toDF.map(_.getAs[Int]("count")).reduce((acc, n) => acc + n)
      

      In all cases

      .reduce((acc, n) => acc + n)
      

      is (Int, Int) => Int. (A self-contained sketch combining all three variants follows this list.)
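
Putting it all together, here is a minimal, self-contained sketch of the three working variants. It is a sketch, not your exact code: the FlightData definition is reconstructed from the question's printSchema, the rows are inlined from the sample table (so the expected sum is 15 + 1 + 344 + 15 = 375), and names like SumCount are illustrative.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Reconstructed from the question's printSchema output.
case class FlightData(DEST_COUNTRY_NAME: String, ORIGIN_COUNTRY_NAME: String, count: Int)

object SumCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sum-count").getOrCreate()
    import spark.implicits._

    // Sample rows inlined from the question's table.
    val dataDS = Seq(
      FlightData("United States", "Romania", 15),
      FlightData("United States", "Croatia", 1),
      FlightData("United States", "Ireland", 344),
      FlightData("Egypt", "United States", 15)
    ).toDS()

    // 1. Typed column: select returns Dataset[Int], so reduce is (Int, Int) => Int.
    val viaTypedColumn = dataDS.select(col("count").as[Int]).reduce((acc, n) => acc + n)

    // 2. Untyped select gives Dataset[Row]; cast the whole Dataset afterwards.
    val viaDatasetAs = dataDS.select(col("count")).as[Int].reduce((acc, n) => acc + n)

    // 3. Extract the Int from each Row explicitly, then reduce.
    val viaMap = dataDS.toDF.map(_.getAs[Int]("count")).reduce((acc, n) => acc + n)

    println(s"$viaTypedColumn $viaDatasetAs $viaMap")  // 375 375 375
    spark.stop()
  }
}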