Search code examples
scalaapache-sparkapache-spark-mllibcountvectorizer

Spark: FlatMap and CountVectorizer pipeline


I working on the pipeline and try to split the column value before passing it to CountVectorizer.

For this purpose I made a custom Transformer.

class FlatMapTransformer(override val uid: String)
  extends Transformer {
  /**
   * Param for input column name.
   * @group param
   */
  final val inputCol = new Param[String](this, "inputCol", "The input column")
  final def getInputCol: String = $(inputCol)

  /**
   * Param for output column name.
   * @group param
   */
  final val outputCol = new Param[String](this, "outputCol", "The output column")
  final def getOutputCol: String = $(outputCol)

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  def this() = this(Identifiable.randomUID("FlatMapTransformer"))

  private val flatMap: String => Seq[String] = { input: String =>
    input.split(",")
  }

  override def copy(extra: ParamMap): SplitString = defaultCopy(extra)

  override def transform(dataset: Dataset[_]): DataFrame = {
    val flatMapUdf = udf(flatMap)
    dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
  }

  override def transformSchema(schema: StructType): StructType = {
    val dataType = schema($(inputCol)).dataType
    require(
      dataType.isInstanceOf[StringType],
      s"Input column must be of type StringType but got ${dataType}")
    val inputFields = schema.fields
    require(
      !inputFields.exists(_.name == $(outputCol)),
      s"Output column ${$(outputCol)} already exists.")

    DataTypes.createStructType(
      Array(
        DataTypes.createStructField($(outputCol), DataTypes.StringType, false)))
  }
}

The code seems legit, but when I try to chain it with other operation the problem occurs. Here is my pipeline:

val train = reader.readTrainingData()

val cat_features = getFeaturesByType(taskConfig, "categorical")
val num_features = getFeaturesByType(taskConfig, "numeric")
val cat_ohe_features = getFeaturesByType(taskConfig, "categorical", Some("ohe"))
val cat_features_string_index = cat_features.
  filter { feature: String => !cat_ohe_features.contains(feature) }

val catIndexer = cat_features_string_index.map {
  feature =>
    new StringIndexer()
      .setInputCol(feature)
      .setOutputCol(feature + "_index")
      .setHandleInvalid("keep")
}

    val flatMapper = cat_ohe_features.map {
      feature =>
        new FlatMapTransformer()
          .setInputCol(feature)
          .setOutputCol(feature + "_transformed")
    }

    val countVectorizer = cat_ohe_features.map {
      feature =>

        new CountVectorizer()
          .setInputCol(feature + "_transformed")
          .setOutputCol(feature + "_vectorized")
          .setVocabSize(10)
    }


// val countVectorizer = cat_ohe_features.map {
//   feature =>
//
//     val flatMapper = new FlatMapTransformer()
//       .setInputCol(feature)
//       .setOutputCol(feature + "_transformed")
// 
//     new CountVectorizer()
//       .setInputCol(flatMapper.getOutputCol)
//       .setOutputCol(feature + "_vectorized")
//       .setVocabSize(10)
// }

val cat_features_index = cat_features_string_index.map {
  (feature: String) => feature + "_index"
}

val count_vectorized_index = cat_ohe_features.map {
  (feature: String) => feature + "_vectorized"
}

val catFeatureAssembler = new VectorAssembler()
  .setInputCols(cat_features_index)
  .setOutputCol("cat_features")

val oheFeatureAssembler = new VectorAssembler()
  .setInputCols(count_vectorized_index)
  .setOutputCol("cat_ohe_features")

val numFeatureAssembler = new VectorAssembler()
  .setInputCols(num_features)
  .setOutputCol("num_features")

val featureAssembler = new VectorAssembler()
  .setInputCols(Array("cat_features", "num_features", "cat_ohe_features_vectorized"))
  .setOutputCol("features")

val pipelineStages = catIndexer ++ flatMapper ++ countVectorizer ++
  Array(
    catFeatureAssembler,
    oheFeatureAssembler,
    numFeatureAssembler,
    featureAssembler)

val pipeline = new Pipeline().setStages(pipelineStages)
pipeline.fit(dataset = train)

Running this code, I receive an error: java.lang.IllegalArgumentException: Field "my_ohe_field_trasformed" does not exist.

[info]  java.lang.IllegalArgumentException: Field "from_expdelv_areas_transformed" does not exist.

[info]  at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
[info]  at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)

[info]  at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)

[info]  at scala.collection.AbstractMap.getOrElse(Map.scala:59)

[info]  at org.apache.spark.sql.types.StructType.apply(StructType.scala:265)

[info]  at org.apache.spark.ml.util.SchemaUtils$.checkColumnTypes(SchemaUtils.scala:56)

[info]  at org.apache.spark.ml.feature.CountVectorizerParams$class.validateAndTransformSchema(CountVectorizer.scala:75)

[info]  at org.apache.spark.ml.feature.CountVectorizer.validateAndTransformSchema(CountVectorizer.scala:123)

[info]  at org.apache.spark.ml.feature.CountVectorizer.transformSchema(CountVectorizer.scala:188)

When I uncomment the stringSplitter and countVectorizer the error is raised in my Transformer

java.lang.IllegalArgumentException: Field "my_ohe_field" does not exist. at val dataType = schema($(inputCol)).dataType

Result of calling pipeline.getStages:

strIdx_3c2630a738f0

strIdx_0d76d55d4200

FlatMapTransformer_fd8595c2969c

FlatMapTransformer_2e9a7af0b0fa

cntVec_c2ef31f00181

cntVec_68a78eca06c9

vecAssembler_a81dd9f43d56

vecAssembler_b647d348f0a0

vecAssembler_b5065a22d5c8

vecAssembler_d9176b8bb593

I might follow the wrong way. Any comments are appreciated.


Solution

  • Your FlatMapTransformer #transform is incorrect, your kind of dropping/ignoring all other columns when you select only on outputCol

    please modify your method to -

     override def transform(dataset: Dataset[_]): DataFrame = {
         val flatMapUdf = udf(flatMap)
        dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
      }
    

    Also, Modify your transformSchema to check input column first before checking its datatype-

     override def transformSchema(schema: StructType): StructType = {
    require(schema.names.contains($(inputCol)), "inputCOl is not there in the input dataframe")
    //... rest as it is
    }
    

    Update-1 based on comments

    1. PLease modify the copy method (Though it's not the cause for exception you facing)-
    override def copy(extra: ParamMap): FlatMapTransformer = defaultCopy(extra)
    
    1. please note that the CountVectorizer takes the column having columns of type ArrayType(StringType, true/false) and since the FlatMapTransformer output columns becomes the input of CountVectorizer, you need to make sure output column of FlatMapTransformer must be of ArrayType(StringType, true/false). I think, this is not the case, your code today is as following-
      override def transform(dataset: Dataset[_]): DataFrame = {
        val flatMapUdf = udf(flatMap)
        dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
      }
    
    

    The explode functions converts the array<string> to string, so the output of the transformer becomes StringType. you may wanted to change this code to-

      override def transform(dataset: Dataset[_]): DataFrame = {
        val flatMapUdf = udf(flatMap)
        dataset.withColumn($(outputCol), flatMapUdf(col($(inputCol))))
      }
    
    
    1. modify transformSchema method to output ArrayType(StringType)
     override def transformSchema(schema: StructType): StructType = {
          val dataType = schema($(inputCol)).dataType
          require(
            dataType.isInstanceOf[StringType],
            s"Input column must be of type StringType but got ${dataType}")
          val inputFields = schema.fields
          require(
            !inputFields.exists(_.name == $(outputCol)),
            s"Output column ${$(outputCol)} already exists.")
    
          schema.add($(outputCol), ArrayType(StringType))
        }
    
    1. change vector assembler to this-
    val featureAssembler = new VectorAssembler()
          .setInputCols(Array("cat_features", "num_features", "cat_ohe_features"))
          .setOutputCol("features")
    

    I tried to execute your pipeline on dummy dataframe, it worked well. Please refer this gist for full code.