
Bigram counting in PySpark


I am trying to piece together a bigram-counting program in PySpark that takes a text file and outputs the frequency of each proper bigram (two consecutive words within a sentence).

from pyspark.ml.feature import NGram

with use_spark_session("Bigrams") as spark:
    text_file = spark.sparkContext.textFile(text_path)
    # Naively split into "sentences" on periods, drop empty fragments,
    # and pair each token list with a dummy id so toDF gets two columns
    sentences = text_file.flatMap(lambda line: line.split(".")) \
                         .filter(lambda line: len(line) > 0) \
                         .map(lambda line: (0, line.strip().split(" ")))
    sentences_df = sentences.toDF(schema=["id", "words"])
    ngram_df = NGram(n=2, inputCol="words", outputCol="bigrams").transform(sentences_df)

ngram_df.select("bigrams") now contains:

+--------------------+
|             bigrams|
+--------------------+
|[April is, is the...|
|[It is, is one, o...|
|[April always, al...|
|[April always, al...|
|[April's flowers,...|
|[Its birthstone, ...|
|[The meaning, mea...|
|[April comes, com...|
|[It also, also co...|
|[April begins, be...|
|[April ends, ends...|
|[In common, commo...|
|[In common, commo...|
|[In common, commo...|
|[In years, years ...|
|[In years, years ...|
+--------------------+

So that gives the list of bigrams for each sentence. Now the distinct bigrams need to be counted. How? Also, the whole code still seems unnecessarily verbose, so I'd be happy to see a more concise solution.


Solution

  • If you are already using the RDD API you can just follow through:

    # Pair each word with its successor via zip(xs, xs[1:]) to form bigrams
    bigrams = text_file.flatMap(lambda line: line.split(".")) \
                       .map(lambda line: line.strip().split(" ")) \
                       .flatMap(lambda xs: zip(xs, xs[1:]))

    # Classic word-count pattern: emit (bigram, 1) and sum per key
    bigrams.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)
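
    The pairing step is just plain Python: `zip(xs, xs[1:])` matches each word with its successor, and the shorter slice ends the iteration one element early, so no out-of-range handling is needed. A minimal illustration (the sample sentence is made up):

    ```python
    # zip(xs, xs[1:]) pairs each word with the word that follows it,
    # which is exactly the list of bigrams for one sentence.
    words = ["April", "is", "the", "cruellest", "month"]
    bigrams = list(zip(words, words[1:]))
    print(bigrams)
    # [('April', 'is'), ('is', 'the'), ('the', 'cruellest'), ('cruellest', 'month')]
    ```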
    

    Otherwise:

    from pyspark.sql.functions import explode

    # explode turns each array of bigrams into one row per bigram,
    # then a standard groupBy/count tallies the distinct bigrams
    ngram_df.select(explode("bigrams").alias("bigram")).groupBy("bigram").count()
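
    For intuition, here is a local, Spark-free sketch of what `explode` + `groupBy` + `count` computes: flatten all the per-sentence bigram lists into one stream and tally each distinct bigram. The sample data below is made up to mimic the `bigrams` column shown in the question:

    ```python
    from collections import Counter

    # Each inner list plays the role of one row's "bigrams" array
    rows = [
        ["April is", "is the", "the cruellest"],
        ["In common", "common years"],
        ["In common", "common years"],
    ]

    # Flattening the lists = explode; Counter = groupBy("bigram").count()
    counts = Counter(bigram for bigrams in rows for bigram in bigrams)
    print(counts["In common"])   # 2
    ```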