Tags: scala, apache-spark, apache-spark-sql, spark-structured-streaming

How to compute statistics on a streaming DataFrame for different column types in a single query?


I have a streaming DataFrame with three columns: time, col1, and col2.

+-----------------------+-------------------+--------------------+
|time                   |col1               |col2                |
+-----------------------+-------------------+--------------------+
|2018-01-10 15:27:21.289|0.4988615628926717 |0.1926744113882285  |
|2018-01-10 15:27:22.289|0.5430687338123434 |0.17084552928040175 |
|2018-01-10 15:27:23.289|0.20527770821641478|0.2221980020202523  |
|2018-01-10 15:27:24.289|0.130852802747647  |0.5213147910202641  |
+-----------------------+-------------------+--------------------+

The data types of col1 and col2 are not fixed: each could be a string or a numeric type. So I have to calculate statistics for each column according to its type. For a string column, calculate only the valid count and the invalid count. For a timestamp column, calculate only min and max. For a numeric column, calculate min, max, average and mean. I have to compute all statistics in a single query. Right now, I compute them with three separate queries, one per column type.


Solution

  • Enumerate the cases you want and build the select from them. For example, if the stream is defined as:

    import org.apache.spark.sql.types._
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.Column
    
    val schema = StructType(Seq(
      StructField("v", TimestampType),
      StructField("x", IntegerType),
      StructField("y", StringType),
      StructField("z", DecimalType(10, 2))
    ))
    
    val df = spark.readStream.schema(schema).format("csv").load("/tmp/foo")
    

    then a single query can compute all of the statistics:

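    // Type names as returned by df.dtypes
    val numericTypes = Seq("ByteType", "ShortType", "IntegerType", "LongType", "FloatType", "DoubleType")

    val stats = df.select(df.dtypes.flatMap {
      // Strings: non-null values count as valid, nulls as invalid
      case (c, "StringType") =>
        Seq(count(c) as s"valid_${c}", count("*") - count(c) as s"invalid_${c}")
      // Timestamps and dates: only min and max
      case (c, t) if Seq("TimestampType", "DateType") contains t =>
        Seq(min(c), max(c))
      // Numeric types; decimal names include precision and scale, e.g. "DecimalType(10,2)"
      case (c, t) if (numericTypes contains t) || t.startsWith("DecimalType") =>
        Seq(min(c), max(c), avg(c), stddev(c))
      // Any other type is skipped
      case _ => Seq.empty[Column]
    }: _*)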
    
    stats.printSchema()
    // root
    //  |-- min(v): timestamp (nullable = true)
    //  |-- max(v): timestamp (nullable = true)
    //  |-- min(x): integer (nullable = true)
    //  |-- max(x): integer (nullable = true)
    //  |-- avg(x): double (nullable = true)
    //  |-- stddev_samp(x): double (nullable = true)
    //  |-- valid_y: long (nullable = false)
    //  |-- invalid_y: long (nullable = false)
    //  |-- min(z): decimal(10,2) (nullable = true)
    //  |-- max(z): decimal(10,2) (nullable = true)
    //  |-- avg(z): decimal(14,6) (nullable = true)
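    //  |-- stddev_samp(z): double (nullable = true)

    // A streaming aggregation without a watermark must be started in "complete"
    // (or "update") output mode, e.g. with the console sink:
    // stats.writeStream.outputMode("complete").format("console").start()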
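The type dispatch inside the flatMap depends only on the type-name strings that df.dtypes returns, so it can be sketched and checked without a Spark session. The snippet below is a minimal sketch, assuming those same type-name strings; ColumnStats, statsFor and plan are hypothetical helpers (not Spark APIs) that return the names of the statistics each column would get. Like the numericTypes list above, it treats ByteType, ShortType and LongType as numeric.

```scala
object ColumnStats {
  // Hypothetical helper, not part of Spark. Type names follow what
  // DataFrame.dtypes returns, e.g. "StringType" or "DecimalType(10,2)".
  private val numericTypes =
    Set("ByteType", "ShortType", "IntegerType", "LongType", "FloatType", "DoubleType")

  // Which statistics the query computes for a column of the given type.
  def statsFor(typeName: String): Seq[String] = typeName match {
    case "StringType"                 => Seq("valid", "invalid")
    case "TimestampType" | "DateType" => Seq("min", "max")
    case t if numericTypes.contains(t) || t.startsWith("DecimalType") =>
      Seq("min", "max", "avg", "stddev")
    case _ => Seq.empty // other types (booleans, arrays, ...) are skipped
  }

  // Expands (column, type) pairs into output names, mirroring the flatMap over df.dtypes.
  def plan(dtypes: Seq[(String, String)]): Seq[String] =
    dtypes.flatMap { case (c, t) => statsFor(t).map(s => s"${s}_$c") }

  def main(args: Array[String]): Unit = {
    // Same column/type pairs as the example schema in the answer
    val dtypes = Seq(
      "v" -> "TimestampType",
      "x" -> "IntegerType",
      "y" -> "StringType",
      "z" -> "DecimalType(10,2)"
    )
    plan(dtypes).foreach(println)
  }
}
```

Keeping the type-to-statistics mapping in one function makes it easy to see which types are covered and to add new ones before wiring the result into the select.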