apache-spark, apache-spark-sql, bigdecimal

Why does Spark groupBy.agg(min/max) of BigDecimal always return 0?


I'm trying to group by one column of a DataFrame and compute the min and max of a BigDecimal column within each of the resulting groups. The result is always a very small (approximately 0) value.

(Similar min/max calls against a Double column produce the expected, non-zero values.)

As a simple example:

If I create the following DataFrame:

import org.apache.spark.sql.{functions => f}
import spark.implicits._  // needed for rdd.toDF() when not running in spark-shell

case class Foo(group: String, bd_value: BigDecimal, d_value: Double)

val rdd = spark.sparkContext.parallelize(Seq(
  Foo("A", BigDecimal("1.0"), 1.0),
  Foo("B", BigDecimal("10.0"), 10.0),
  Foo("B", BigDecimal("1.0"), 1.0),
  Foo("C", BigDecimal("10.0"), 10.0),
  Foo("C", BigDecimal("10.0"), 10.0),
  Foo("C", BigDecimal("10.0"), 10.0)
))

val df = rdd.toDF()

Selecting max of either the Double or BigDecimal column returns the expected result:

df.select(f.max("d_value")).show()

// +------------+
// |max(d_value)|
// +------------+
// |        10.0|
// +------------+

df.select(f.max("bd_value")).show()

// +--------------------+
// |       max(bd_value)|
// +--------------------+
// |10.00000000000000...|
// +--------------------+

But if I group-by then aggregate, I get a reasonable result for the Double column, but near-zero values for the BigDecimal column:

df.groupBy("group").agg(f.max("d_value")).show()

// +-----+------------+
// |group|max(d_value)|
// +-----+------------+
// |    B|        10.0|
// |    C|        10.0|
// |    A|         1.0|
// +-----+------------+

df.groupBy("group").agg(f.max("bd_value")).show()

// +-----+-------------+
// |group|max(bd_value)|
// +-----+-------------+
// |    B|     1.00E-16|
// |    C|     1.00E-16|
// |    A|      1.0E-17|
// +-----+-------------+

Why does Spark return a near-zero result for these min/max calls?


Solution

    TL;DR

    There seems to be an inconsistency in how Spark treats the scale of BigDecimals that manifests in the particular case shown in the question. The code behaves as though it is converting BigDecimals to unscaled Longs using the scale of the BigDecimal object, but then converting back to BigDecimal using the scale of the schema.

    This can be worked around by either

    • explicitly setting the scale on all BigDecimal values to match the DataFrame's schema using setScale (see the sketch after this list), or
    • manually specifying a schema and creating the DF from an RDD[Row]
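
    For example, applied to the DataFrame from the question, the first workaround might look like the sketch below (it assumes the inferred schema type is decimal(38,18), Spark's default for a BigDecimal case-class field, and fixedRdd/fixedDf are just illustrative names):

    // force each value's scale to match the schema's scale (18) before building the DF
    val fixedRdd = spark.sparkContext.parallelize(Seq(
      Foo("A", BigDecimal("1.0").setScale(18), 1.0),
      Foo("B", BigDecimal("10.0").setScale(18), 10.0),
      Foo("B", BigDecimal("1.0").setScale(18), 1.0),
      Foo("C", BigDecimal("10.0").setScale(18), 10.0)
    ))

    val fixedDf = fixedRdd.toDF()

    // grouped min/max should now return the actual values instead of ~0
    fixedDf.groupBy("group").agg(f.min("bd_value"), f.max("bd_value")).show()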

    Long Version

    Here is what I think is happening on my machine with Spark 2.4.0.

    In the groupBy.max case, Spark is going through UnsafeRow, converting the BigDecimal to an unscaled Long and storing it as a byte array in setDecimal at this line (as verified with print statements). Then, when it later calls getDecimal, it converts the byte array back to a BigDecimal using the scale specified in the schema.

    If the scale in the original value does not match the scale in the schema, this results in an incorrect value. For example,

    val foo = BigDecimal(123456)
    foo.scale
    // 0

    val bytes = foo.underlying().unscaledValue().toByteArray()

    // convert the bytes into BigDecimal using the original scale -- correct value
    val sameValue = BigDecimal(new java.math.BigInteger(bytes), 0)
    // sameValue: scala.math.BigDecimal = 123456

    // convert the bytes into BigDecimal using scale 18 -- wrong value
    val smaller = BigDecimal(new java.math.BigInteger(bytes), 18)
    // smaller: scala.math.BigDecimal = 1.23456E-13
    
    

    If I just select the max of the bd_value column, Spark doesn't seem to go through setDecimal. I haven't verified why, or where it goes instead.

    But, this would explain the values observed in the question. Using the same case class Foo:

    // This BigDecimal has scale 0
    val rdd = spark.sparkContext.parallelize(Seq(Foo("C", BigDecimal(123456), 123456.0)))
    
    // And shows with scale 0 in the DF
    rdd.toDF.show
    +-----+--------+--------+
    |group|bd_value| d_value|
    +-----+--------+--------+
    |    C|  123456|123456.0|
    +-----+--------+--------+
    
    // But the schema has scale 18
    rdd.toDF.printSchema
    root
     |-- group: string (nullable = true)
     |-- bd_value: decimal(38,18) (nullable = true)
     |-- d_value: double (nullable = false)
    
    
    // groupBy + max corrupts in the same way as converting to bytes via unscaled, then to BigDecimal with scale 18
    rdd.toDF.groupBy("group").max("bd_value").show
    +-----+-------------+
    |group|max(bd_value)|
    +-----+-------------+
    |    C|  1.23456E-13|
    +-----+-------------+
    
    // This BigDecimal is forced to have the same scale as the inferred schema
    val rdd1 = spark.sparkContext.parallelize(Seq(Foo("C", BigDecimal(123456).setScale(18), 123456.0)))
    
    // verified the scale is 18 in the DF
    rdd1.toDF.show
    +-----+--------------------+--------+
    |group|            bd_value| d_value|
    +-----+--------------------+--------+
    |    C|123456.0000000000...|123456.0|
    +-----+--------------------+--------+
    
    
    // And it works as expected
    rdd1.toDF.groupBy("group").max("bd_value").show
    
    +-----+--------------------+
    |group|       max(bd_value)|
    +-----+--------------------+
    |    C|123456.0000000000...|
    +-----+--------------------+
    
    

    This would also explain why, as observed in the comment, it works fine when converted from an RDD[Row] with an explicit schema.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._

    val rdd2 = spark.sparkContext.parallelize(Seq(Row("C", BigDecimal(123456), 123456.0)))
    
    // schema has BigDecimal scale 18
    val schema = StructType(Seq(
      StructField("group", StringType, true),
      StructField("bd_value", DecimalType(38, 18), true),
      StructField("d_value", DoubleType, false)
    ))
    
    // createDataFrame interprets the value into the schema's scale
    val df = spark.createDataFrame(rdd2, schema)
    
    df.show
    
    +-----+--------------------+--------+
    |group|            bd_value| d_value|
    +-----+--------------------+--------+
    |    C|123456.0000000000...|123456.0|
    +-----+--------------------+--------+
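
    As a quick check (a sketch building on the code just above), the same grouped aggregation on this explicitly-schema'd DataFrame should return the true maximum rather than a near-zero value, consistent with the observation from the comment:

    // expected to show 123456.000000000000000000 for group C, matching the setScale example earlier
    df.groupBy("group").max("bd_value").show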