Tags: pyspark, nlp, apache-spark-mllib, countvectorizer, keyword-extraction

Get topN keywords with PySpark CountVectorizer


I want to extract keywords using pyspark.ml.feature.CountVectorizer.
My input Spark dataframe looks as follows:

id | text
---+--------------------------------------------------------------------------------------
 1 | sun, mars, solar system, solar system, mars, solar system, venus, solar system, mars
 2 | planet, moon, milky way, milky way, moon, milky way, sun, milky way, mars, star
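
For reproducibility, here is one way this example frame could be built (a sketch; spark is assumed to be an active SparkSession):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

input_df = spark.createDataFrame(
    [
        (1, "sun, mars, solar system, solar system, mars, solar system, venus, solar system, mars"),
        (2, "planet, moon, milky way, milky way, moon, milky way, sun, milky way, mars, star"),
    ],
    ["id", "text"],
)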

I applied the following pipeline:

from pyspark.ml.feature import CountVectorizer
from pyspark.sql.functions import split

# Convert the comma-separated string to an array.
# Note: split's second argument is a regex, and splitting on ',' alone keeps
# the leading space on every token after the first (visible in the output below).
input_df = input_df.withColumn("text_array", split("text", ','))

cv_text = CountVectorizer() \
    .setInputCol("text_array") \
    .setOutputCol("cv_text")

cv_model = cv_text.fit(input_df)
cv_result = cv_model.transform(input_df)

cv_result.show()

Output:

id | text                        | text_array                     | cv_text
---+-----------------------------+--------------------------------+---------------------------------------------
 1 | sun, mars, solar system, .. | [sun, mars, solar system, ..]  | (9,[1,2,4,7],[4.0,3.0,1.0,1.0])
 2 | planet, moon, milky way, .. | [planet, moon, milky way, ..]  | (9,[0,2,3,5,6,8],[4.0,1.0,2.0,1.0,1.0,1.0])

How can I now get the top n keywords (top 2, for example) for each id, i.e. for each row?
Expected output:

id | text                        | text_array                     | cv_text                                      | keywords
---+-----------------------------+--------------------------------+----------------------------------------------+-------------------
 1 | sun, mars, solar system, .. | [sun, mars, solar system, ..]  | (9,[1,2,4,7],[4.0,3.0,1.0,1.0])              | solar system, mars
 2 | planet, moon, milky way, .. | [planet, moon, milky way, ..]  | (9,[0,2,3,5,6,8],[4.0,1.0,2.0,1.0,1.0,1.0])  | milky way, moon

I will be very grateful for any advice, docs, examples!


Solution

  • I haven't found a way to work with sparse vectors beyond the few operations available in the pyspark.ml.feature module, so for something like taking the top n values I would say a UDF is the way to go.

    The function below uses np.argpartition to find the positions of the top n entries in the vector's values array; those positions can then be used to index into the vector's indices array, which yields the corresponding vocabulary indices.

    import numpy as np
    from pyspark.sql.functions import udf, lit
    
    @udf("array<integer>")
    def get_top_n(v, n):
        # Positions of the n largest counts. argpartition does not sort the
        # top n among themselves, and this assumes at least n non-zero entries.
        top_n_indices = np.argpartition(v.values, -n)[-n:]
        # Translate positions in the values array into vocabulary indices.
        return [int(x) for x in v.indices[top_n_indices]]
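
    As a quick standalone illustration of what that does, using the values and indices of row 1's cv_text vector from above:

    import numpy as np
    
    values = np.array([4.0, 3.0, 1.0, 1.0])   # counts stored in the sparse vector
    indices = np.array([1, 2, 4, 7])          # vocabulary indices they belong to
    
    top2_pos = np.argpartition(values, -2)[-2:]  # positions of the 2 largest counts
    print(indices[top2_pos])                     # [2 1] or [1 2]: mars and solar system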
    

    The values returned are vocabulary indices, not the actual words. If the vocabulary is not too big, we can put it in an array column of its own and map each index to the actual word.

    from pyspark.sql.functions import col, lit, transform
    
    # Single-row dataframe holding the whole vocabulary as one array column,
    # cross-joined onto every result row.
    voc = spark.createDataFrame([(cv_model.vocabulary,)], ["voc"])
    
    cv_result \
    .withColumn("top_2", get_top_n("cv_text", lit(2))) \
    .crossJoin(voc) \
    .withColumn("top_2_parsed", transform("top_2", lambda v: col("voc")[v])) \
    .show()
    
    +---+--------------------+--------------------+--------------------+------+--------------------+--------------------+
    | id|                text|          text_array|             cv_text| top_2|                 voc|        top_2_parsed|
    +---+--------------------+--------------------+--------------------+------+--------------------+--------------------+
    |  1|sun, mars, solar ...|[sun,  mars,  sol...|(9,[1,2,4,7],[4.0...|[2, 1]|[ milky way,  sol...|[ mars,  solar sy...|
    |  2|planet, moon, mil...|[planet,  moon,  ...|(9,[0,2,3,5,6,8],...|[3, 0]|[ milky way,  sol...| [ moon,  milky way]|
    +---+--------------------+--------------------+--------------------+------+--------------------+--------------------+
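
    If the cross join feels heavy, a variant (a sketch under the same assumptions, plus the assumption that the vocabulary fits in driver memory) is to capture cv_model.vocabulary in the UDF's closure and return the words directly, ordered by count:

    import numpy as np
    from pyspark.sql.functions import udf, lit
    
    voc_list = cv_model.vocabulary  # plain Python list, shipped with the UDF
    
    @udf("array<string>")
    def get_top_n_words(v, n):
        n = min(n, len(v.values))  # guard rows with fewer than n distinct terms
        if n == 0:
            return []
        pos = np.argpartition(v.values, -n)[-n:]
        pos = pos[np.argsort(-v.values[pos])]  # order by count, descending
        return [voc_list[int(i)] for i in v.indices[pos]]
    
    cv_result.withColumn("keywords", get_top_n_words("cv_text", lit(2))).show()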
    

    I'm not sure I feel that good about the cross-join approach; it is probably not scalable. That being said, if you don't actually need the CountVectorizer, there is a combination of standard functions we can apply to input_df to simply get the top n words of every row.

    from pyspark.sql.functions import explode, row_number, desc, col
    from pyspark.sql.window import Window
    
    # Explode to one row per word, count occurrences per (id, word), rank the
    # words within each id by count, and keep the top 2.
    input_df \
    .select("id", explode("text_array").alias("word")) \
    .groupBy("id", "word") \
    .count() \
    .withColumn("rn", row_number().over(Window.partitionBy("id").orderBy(desc("count")))) \
    .filter(col("rn") <= 2) \
    .show()
    
    +---+-------------+-----+---+
    | id|         word|count| rn|
    +---+-------------+-----+---+
    |  1| solar system|    4|  1|
    |  1|         mars|    3|  2|
    |  2|    milky way|    4|  1|
    |  2|         moon|    2|  2|
    +---+-------------+-----+---+
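
    To get back to the expected one-row-per-id shape, those ranked rows can be aggregated into a keywords array. A sketch: collect_list alone does not guarantee ordering, so collecting (rn, word) structs and sorting them keeps the words in count order.

    from pyspark.sql.functions import explode, row_number, desc, col, collect_list, sort_array, struct
    from pyspark.sql.window import Window
    
    w = Window.partitionBy("id").orderBy(desc("count"))
    
    input_df \
    .select("id", explode("text_array").alias("word")) \
    .groupBy("id", "word") \
    .count() \
    .withColumn("rn", row_number().over(w)) \
    .filter(col("rn") <= 2) \
    .groupBy("id") \
    .agg(sort_array(collect_list(struct("rn", "word"))).alias("ranked")) \
    .withColumn("keywords", col("ranked.word")) \
    .drop("ranked") \
    .show(truncate=False)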