
Fitting pipeline and processing the data


I've got a file that contains text. What I want to do is to use a pipeline for tokenising the text, removing the stop-words and producing 2-grams.

What I've done so far:

Step 1: Read the file

val data = sparkSession.read.text("data.txt").toDF("text")

Step 2: Build the pipeline

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{NGram, StopWordsRemover, Tokenizer}

val pipe1 = new Tokenizer().setInputCol("text").setOutputCol("words")
val pipe2 = new StopWordsRemover().setInputCol("words").setOutputCol("filtered")
val pipe3 = new NGram().setN(2).setInputCol("filtered").setOutputCol("ngrams")

val pipeline = new Pipeline().setStages(Array(pipe1, pipe2, pipe3))
val model = pipeline.fit(data)

I know that pipeline.fit(data) produces a PipelineModel, but I don't know how to use a PipelineModel.

Any help would be much appreciated.


Solution

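  • When you run val model = pipeline.fit(data), every Estimator stage (i.e., a learning algorithm such as a classifier, a regressor or a clustering algorithm) is fit to the data and replaced by the Transformer it produces, while stages that are already Transformers are passed through unchanged. Your pipeline contains only Transformer stages (Tokenizer, StopWordsRemover and NGram), since you're doing feature creation, so fit() has nothing to train: it simply wraps the three stages in a PipelineModel.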
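A minimal sketch to check this yourself, assuming a local SparkSession (master "local[*]") and a one-line in-memory DataFrame standing in for data.txt. The object name InspectPipelineModel and the sample sentence are hypothetical. It fits the same three stages and prints the class names of the stages held by the resulting PipelineModel.

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.feature.{NGram, StopWordsRemover, Tokenizer}
import org.apache.spark.sql.SparkSession

object InspectPipelineModel {
  // Fits the three-stage pipeline on a tiny DataFrame and returns the
  // simple class names of the stages inside the fitted PipelineModel.
  def stageNames(spark: SparkSession): Seq[String] = {
    import spark.implicits._
    // Hypothetical one-line dataset standing in for data.txt
    val data = Seq("the quangle wangle sat on the crumpetty tree").toDF("text")

    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    val remover = new StopWordsRemover().setInputCol("words").setOutputCol("filtered")
    val ngram = new NGram().setN(2).setInputCol("filtered").setOutputCol("ngrams")

    val model: PipelineModel =
      new Pipeline().setStages(Array(tokenizer, remover, ngram)).fit(data)

    // No Estimators in the pipeline, so the model just holds the original Transformers
    model.stages.map(_.getClass.getSimpleName).toSeq
  }

  def main(args: Array[String]): Unit = {
    // Assumption: running locally; drop .master(...) when using spark-submit on a cluster
    val spark = SparkSession.builder.appName("Inspect PipelineModel").master("local[*]").getOrCreate()
    try println(stageNames(spark).mkString(", ")) // Tokenizer, StopWordsRemover, NGram
    finally spark.stop()
  }
}
```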

    To run your model, which now consists only of Transformer stages, call val results = model.transform(data). This applies each Transformer stage to your DataFrame in order (lazily: nothing is actually computed until an action such as show runs). The resulting DataFrame therefore contains the original lines (text), the Tokenizer output (words), the StopWordsRemover output (filtered), and finally the NGram results (ngrams).
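A PipelineModel is itself a Transformer, so you can apply it to any DataFrame that has a text column, including lines it was not fitted on, and you can save it and load it back later. A minimal sketch, assuming a local SparkSession; the object name UsePipelineModel, the sample sentences and the modelPath argument are hypothetical placeholders.

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.feature.{NGram, StopWordsRemover, Tokenizer}
import org.apache.spark.sql.{DataFrame, SparkSession}

object UsePipelineModel {
  // Fits the pipeline, saves and reloads the model, applies it to new lines,
  // and returns the column names of the transformed DataFrame.
  def run(spark: SparkSession, modelPath: String): Array[String] = {
    import spark.implicits._
    val train = Seq("the quangle wangle sat on the crumpetty tree").toDF("text")

    val pipeline = new Pipeline().setStages(Array(
      new Tokenizer().setInputCol("text").setOutputCol("words"),
      new StopWordsRemover().setInputCol("words").setOutputCol("filtered"),
      new NGram().setN(2).setInputCol("filtered").setOutputCol("ngrams")))
    val model: PipelineModel = pipeline.fit(train)

    // Persist the fitted model and load it back (modelPath is a placeholder)
    model.write.overwrite().save(modelPath)
    val reloaded = PipelineModel.load(modelPath)

    // transform() appends the words, filtered and ngrams columns
    val newLines = Seq("the blue baboon played the flute").toDF("text")
    val results: DataFrame = reloaded.transform(newLines)
    results.select("ngrams").show(truncate = false) // action: triggers the computation
    results.columns
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("Use PipelineModel").master("local[*]").getOrCreate()
    // Hypothetical default location for the saved model
    val path = args.headOption.getOrElse("/tmp/ngram-pipeline-model")
    try println(run(spark, path).mkString(", "))
    finally spark.stop()
  }
}
```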

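    Finding the top 5 n-grams once feature creation is complete can be done with the DataFrame API or with an equivalent Spark SQL query. First explode the ngrams column so each bigram gets its own row, then group by the bigram and count, order by the count in descending order, and finally call show(5). Alternatively, use limit(5) (LIMIT 5 in SQL) instead of show(5) if you want to keep the result as a DataFrame. The code below uses show(10, false) to print the top 10 without truncating the values.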
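A sketch of both variants, assuming results is the DataFrame returned by model.transform(data). The object, method and temporary-view names are hypothetical. Note that groupBy(...).count() names its output column count.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, desc, explode}

object TopNGrams {
  // DataFrame API: one row per bigram, count each bigram, keep the n most frequent
  def topWithDataFrameApi(results: DataFrame, n: Int): DataFrame =
    results
      .select(explode(col("ngrams")).as("ngram"))
      .groupBy("ngram")
      .count()                // adds a Long column named "count"
      .orderBy(desc("count"))
      .limit(n)

  // Equivalent Spark SQL query over a temporary view
  def topWithSql(spark: SparkSession, results: DataFrame, n: Int): DataFrame = {
    results.createOrReplaceTempView("results")
    spark.sql(
      s"""SELECT ngram, COUNT(*) AS ngramCount
         |FROM (SELECT explode(ngrams) AS ngram FROM results) t
         |GROUP BY ngram
         |ORDER BY ngramCount DESC
         |LIMIT $n""".stripMargin)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("Top NGrams").master("local[*]").getOrCreate()
    import spark.implicits._
    // Toy stand-in for the ngrams column produced by the pipeline
    val results = Seq(Seq("quangle wangle", "wangle said"), Seq("quangle wangle")).toDF("ngrams")
    try {
      topWithDataFrameApi(results, 5).show(truncate = false)
      topWithSql(spark, results, 5).show(truncate = false)
    } finally spark.stop()
  }
}
```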

    As an aside, you should probably rename your object to something that isn't a standard class name (such as one of the Spark classes you import). Otherwise you're going to get an ambiguous reference error.

    CODE:

    import org.apache.spark.ml.{Pipeline, PipelineModel}
    import org.apache.spark.ml.feature.{NGram, StopWordsRemover, Tokenizer}
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions._
    
    object NGramPipeline {
        def main(args: Array[String]): Unit = {
            val sparkSession = SparkSession.builder.appName("NGram Pipeline").getOrCreate()
            // Needed for the $"colName" syntax used below
            import sparkSession.implicits._
    
            // read.text yields one row per line in a column named "value"; rename it to "text"
            val data = sparkSession.read.text("quangle.txt").toDF("text")
    
            // Split into words, drop stop words, then build bigrams
            val pipe1 = new Tokenizer().setInputCol("text").setOutputCol("words")
            val pipe2 = new StopWordsRemover().setInputCol("words").setOutputCol("filtered")
            val pipe3 = new NGram().setN(2).setInputCol("filtered").setOutputCol("ngrams")
    
            val pipeline = new Pipeline().setStages(Array(pipe1, pipe2, pipe3))
            // No Estimators here, so fit() just wraps the three Transformers in a PipelineModel
            val model: PipelineModel = pipeline.fit(data)
    
            // Apply every stage in order; adds the words, filtered and ngrams columns
            val results = model.transform(data)
    
            // One row per bigram, then count occurrences and show the most frequent ones
            val explodedNGrams = results.withColumn("explNGrams", explode($"ngrams"))
            explodedNGrams
                .groupBy("explNGrams")
                .agg(count("*").as("ngramCount"))
                .orderBy(desc("ngramCount"))
                .show(10, truncate = false)
        }
    }
    // When pasting into spark-shell, invoke it directly:
    NGramPipeline.main(Array.empty[String])
    



    OUTPUT:

    +-----------------+----------+
    |explNGrams       |ngramCount|
    +-----------------+----------+
    |quangle wangle   |9         |
    |wangle quee.     |4         |
    |'mr. quangle     |3         |
    |said, --         |2         |
    |wangle said      |2         |
    |crumpetty tree   |2         |
    |crumpetty tree,  |2         |
    |quangle wangle,  |2         |
    |crumpetty tree,--|2         |
    |blue babboon,    |2         |
    +-----------------+----------+
    only showing top 10 rows
    

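    Notice that punctuation (commas, dashes, etc.) is causing the same bigram to be counted under several entries, e.g. "crumpetty tree", "crumpetty tree," and "crumpetty tree,--". When generating n-grams, it's often a good idea to filter out punctuation first. You can typically do this with a regex, for example by replacing Tokenizer with RegexTokenizer.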
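A sketch of the regex approach, assuming Spark's RegexTokenizer in place of Tokenizer with the pattern \\W+ (any run of non-word characters) as the separator. This is one possible choice of regex, not the only one: it also splits on apostrophes, so a word like quangle's becomes two tokens. The object name PunctuationFreeNGrams is hypothetical.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{NGram, RegexTokenizer, StopWordsRemover}
import org.apache.spark.sql.SparkSession

object PunctuationFreeNGrams {
  // With gaps = true (the default) the pattern describes the separators, so
  // "tree,--" and "tree." both become the token "tree".
  val tokenizer: RegexTokenizer = new RegexTokenizer()
    .setInputCol("text")
    .setOutputCol("words")
    .setPattern("\\W+")     // assumption: dropping apostrophes is acceptable
    .setToLowercase(true)   // the default, shown for clarity
    .setMinTokenLength(1)   // discards empty tokens left by leading punctuation

  // Same stop-word and bigram stages as before
  val pipeline: Pipeline = new Pipeline().setStages(Array(
    tokenizer,
    new StopWordsRemover().setInputCol("words").setOutputCol("filtered"),
    new NGram().setN(2).setInputCol("filtered").setOutputCol("ngrams")))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("Punctuation-free NGrams").master("local[*]").getOrCreate()
    try {
      val data = spark.read.text("quangle.txt").toDF("text")
      pipeline.fit(data).transform(data).select("ngrams").show(truncate = false)
    } finally spark.stop()
  }
}
```

The rest of the answer's code (explode, groupBy, count) works unchanged on the output of this pipeline.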