scala, apache-spark, apache-spark-sql, user-defined-functions, apache-spark-ml

How to create a custom Transformer from a UDF?


I was trying to create and save a Pipeline with custom stages. I need to add a column to my DataFrame by using a UDF. Therefore, I was wondering whether it is possible to convert a UDF or a similar operation into a Transformer.

My custom UDF looks like this, and I'd like to learn how to use it as a custom Transformer.

def getFeatures(n: String) = {
    // Suffixes of length 1..NUMBER_FEATURES of the first whitespace-separated token.
    val NUMBER_FEATURES = 4
    val name = n.split(" +")(0).toLowerCase
    (1 to NUMBER_FEATURES)
        .filter(size => size <= name.length)
        .map(size => name.substring(name.length - size))
}

val tokenizeUDF = sqlContext.udf.register("tokenize", (name: String) => getFeatures(name))
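
For reference, this is roughly how I apply it today (the DataFrame and column names are only illustrative):

import org.apache.spark.sql.functions.col

val df = Seq((1L, "John Smith"), (2L, "Jane Doe")).toDF("id", "name")
df.withColumn("features", tokenizeUDF(col("name"))).show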

Solution

  • It is not a fully featured solution, but you can start with something like this:

    import org.apache.spark.ml.UnaryTransformer
    import org.apache.spark.ml.util.Identifiable
    import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
    
    class NGramTokenizer(override val uid: String)
      extends UnaryTransformer[String, Seq[String], NGramTokenizer]  {
    
      def this() = this(Identifiable.randomUID("ngramtokenizer"))
    
      override protected def createTransformFunc: String => Seq[String] = {
        getFeatures _
      }
    
      override protected def validateInputType(inputType: DataType): Unit = {
        require(inputType == StringType)
      }
    
      override protected def outputDataType: DataType = {
        new ArrayType(StringType, true)
      }
    }
    

    Quick check:

    val df = Seq((1L, "abcdef"), (2L, "foobar")).toDF("k", "v")
    val transformer = new NGramTokenizer().setInputCol("v").setOutputCol("vs")
    
    transformer.transform(df).show
    // +---+------+------------------+
    // |  k|     v|                vs|
    // +---+------+------------------+
    // |  1|abcdef|[f, ef, def, cdef]|
    // |  2|foobar|[r, ar, bar, obar]|
    // +---+------+------------------+
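
    Since the original goal was a Pipeline with custom stages, the transformer can be dropped straight into one. A minimal sketch (note that, as far as I can tell, persisting such a pipeline would additionally require the stage to mix in MLWritable, which this class doesn't do):

    import org.apache.spark.ml.Pipeline

    val pipeline = new Pipeline().setStages(Array(transformer))

    // Pipeline.fit passes Transformer stages through untouched.
    val model = pipeline.fit(df)
    model.transform(df).show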
    

    You can even try to generalize it to something like this:

    import org.apache.spark.sql.catalyst.ScalaReflection.schemaFor
    import scala.reflect.runtime.universe._
    
    class UnaryUDFTransformer[T : TypeTag, U : TypeTag](
      override val uid: String,
      f: T => U
    ) extends UnaryTransformer[T, U, UnaryUDFTransformer[T, U]]  {
    
      override protected def createTransformFunc: T => U = f
    
      override protected def validateInputType(inputType: DataType): Unit = 
        require(inputType == schemaFor[T].dataType)
    
      override protected def outputDataType: DataType = schemaFor[U].dataType
    }
    
    val transformer = new UnaryUDFTransformer("featurize", getFeatures)
      .setInputCol("v")
      .setOutputCol("vs")
    

    If you want to use a UDF rather than the wrapped function, you'll have to extend Transformer directly and override the transform method. Unfortunately, the majority of the useful classes are private, so it can be rather tricky.
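
    A rough sketch of that approach, assuming Spark 2.x (the class name UDFTransformer and the column handling are just illustrative; since shared param traits such as HasInputCol are private, the params are declared by hand):

    import org.apache.spark.ml.Transformer
    import org.apache.spark.ml.param.{Param, ParamMap}
    import org.apache.spark.ml.util.Identifiable
    import org.apache.spark.sql.{DataFrame, Dataset}
    import org.apache.spark.sql.functions.udf
    import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

    class UDFTransformer(override val uid: String) extends Transformer {

      def this() = this(Identifiable.randomUID("udftransformer"))

      // Declared manually because HasInputCol / HasOutputCol are private[ml].
      final val inputCol  = new Param[String](this, "inputCol", "input column name")
      final val outputCol = new Param[String](this, "outputCol", "output column name")

      def setInputCol(value: String): this.type  = set(inputCol, value)
      def setOutputCol(value: String): this.type = set(outputCol, value)

      // The UDF that does the actual work.
      private val tokenize = udf(getFeatures _)

      override def transform(dataset: Dataset[_]): DataFrame =
        dataset.withColumn($(outputCol), tokenize(dataset($(inputCol))))

      override def transformSchema(schema: StructType): StructType =
        StructType(schema.fields :+ StructField($(outputCol), ArrayType(StringType), nullable = true))

      override def copy(extra: ParamMap): UDFTransformer = defaultCopy(extra)
    }

    new UDFTransformer().setInputCol("v").setOutputCol("vs").transform(df).show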

    Alternatively, you can register the UDF:

    spark.udf.register("getFeatures", getFeatures _)
    

    and use SQLTransformer:

    import org.apache.spark.ml.feature.SQLTransformer
    
    val transformer = new SQLTransformer()
      .setStatement("SELECT *, getFeatures(v) AS vs FROM __THIS__")
    
    transformer.transform(df).show
    // +---+------+------------------+
    // |  k|     v|                vs|
    // +---+------+------------------+
    // |  1|abcdef|[f, ef, def, cdef]|
    // |  2|foobar|[r, ar, bar, obar]|
    // +---+------+------------------+
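
    A nice side effect of this approach, as far as I can tell, is that SQLTransformer is writable, so a Pipeline built around it can be saved (the path below is just an example). Keep in mind that the UDF registration itself is session-scoped, so getFeatures has to be registered again before a loaded pipeline is used:

    import org.apache.spark.ml.Pipeline

    val pipeline = new Pipeline().setStages(Array(transformer))
    val model = pipeline.fit(df)

    model.write.overwrite().save("/tmp/featurize-pipeline")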