Tags: scala, apache-spark, user-defined-functions

Fitting LogisticRegression within a User Defined Function (UDF)


I've implemented the following code in Spark Scala:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf
import org.apache.spark.ml.classification._
// Note: $, toDF and display are provided by the Databricks notebook environment
// (spark.implicits._ is pre-imported; display is a notebook helper).

object Hello {
    def main(args: Array[String]) = {

          // UDF: fits a LogisticRegression on this row's labeled entries and
          // returns the predicted probability of label 1 for param1.
          val getLabel1Probability = udf((param1: Double, labeledEntries: Seq[Array[Double]]) => {

            val trainingData = labeledEntries.map(entry => (org.apache.spark.ml.linalg.Vectors.dense(entry(0)), entry(1))).toList.toDF("features", "label")
            val regression = new LogisticRegression()
            val fittingModel = regression.fit(trainingData)

            val prediction = fittingModel.predictProbability(org.apache.spark.ml.linalg.Vectors.dense(param1))
            val probability = prediction.toArray(1)

            probability
          })

          val df = Seq((1.0, Seq(Array(1.0, 0), Array(2.0, 1))), (3.0, Seq(Array(1.0, 0), Array(2.0, 1)))).toDF("Param1", "LabeledEntries")

          val dfWithLabel1Probability = df.withColumn(
                "Label1Probability", getLabel1Probability(
                  $"Param1",
                  $"LabeledEntries"
                )
          )
          display(dfWithLabel1Probability)
    }
}

Hello.main(Array())  

When I run it in a Databricks notebook on a multi-node cluster (DBR 13.2, Spark 3.4.0, Scala 2.12), the display of dfWithLabel1Probability is rendered without any error.

I have the following questions:

  • My understanding is that I should be getting a NullPointerException when creating the trainingData DataFrame, because _sqlContext is null within the UDF. If so, why am I not getting it? Is it related to running it from a Databricks notebook? Is the behaviour non-deterministic?
  • If creating a DataFrame is not allowed within a UDF, how can I fit LogisticRegression with the data from a given DataFrame's column? In the real use case I'm dealing with millions of rows, so I would prefer to avoid using Dataset's collect() to bring all those rows into the driver's memory. Is there any alternative? (See the sketch after this list for the kind of local, in-UDF fit I'm considering.)
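
For illustration, here is the kind of alternative I'm considering instead of building a DataFrame inside the UDF: a minimal sketch that fits a tiny univariate logistic regression with plain Scala gradient descent, so nothing running on the executors needs a SparkSession. The name getLabel1ProbabilityLocal, the learning rate and the iteration count are made up for this example.

import org.apache.spark.sql.functions.udf

val getLabel1ProbabilityLocal = udf((param1: Double, labeledEntries: Seq[Array[Double]]) => {
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  var w = 0.0   // weight of the single feature
  var b = 0.0   // intercept
  val lr = 0.1  // learning rate (illustrative)
  val n = labeledEntries.size.toDouble

  // Plain batch gradient descent on the logistic log-loss.
  for (_ <- 1 to 200) {
    var gw = 0.0
    var gb = 0.0
    labeledEntries.foreach { entry =>
      val x = entry(0)
      val y = entry(1)
      val err = sigmoid(w * x + b) - y
      gw += err * x
      gb += err
    }
    w -= lr * gw / n
    b -= lr * gb / n
  }

  sigmoid(w * param1 + b)  // probability of label 1 for param1
})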

Thanks.


Solution

  • For the first question, if, instead, you run:

    val largedf = spark.range(100000).selectExpr("cast(id as double) Param1", "array(array(1.0, 0), array(2.0, 1)) LabeledEntries")
    
    val largedfWithLabel1Probability = largedf.withColumn(
        "Label1Probability", getLabel1Probability(
          $"Param1",
          $"LabeledEntries"
        )
    )
    
    display(largedfWithLabel1Probability)
    

    it will throw a NullPointerException (NPE), as will a range of 1, but using:

    (1 until 1000).map(a => (a.toDouble, Seq(Array(1.0, 0), Array(2.0, 1)))).toDF("Param1", "LabeledEntries")
    

    it will at least start processing. This is because toDF builds the data as a LocalRelation, which is never shipped to executors: Spark can evaluate the projection, including the UDF, directly on the driver, where the SparkSession is available. spark.range, by contrast, produces a distributed leaf node whose partitions are processed on executors, and there the UDF has no SparkSession with which to create trainingData, hence the exception. You can confirm this by inspecting the plans (see the sketch below).

    Re the second question: that could be worth posting as a separate top-level question.
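
    If it helps to see where each variant is evaluated, here is a quick check, assuming the dfWithLabel1Probability and largedfWithLabel1Probability values defined above (the exact explain output varies by Spark version). The Seq/toDF variant is backed by a LocalRelation that Catalyst can evaluate eagerly on the driver (likely via its ConvertToLocalRelation rule), while the range variant keeps a distributed Range leaf:

    // Seq/toDF variant: the optimized logical plan collapses to a LocalRelation,
    // meaning the UDF has already been evaluated on the driver.
    dfWithLabel1Probability.explain(true)

    // range variant: the plan keeps a Range leaf plus a Project containing the UDF,
    // so the UDF runs on executors, where no SparkSession is available.
    largedfWithLabel1Probability.explain(true)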