Tags: apache-spark, parallel-processing, pyspark, text-classification

Encode sentence as sequence model with Spark


I am doing text classification and I use pyspark.ml.feature.Tokenizer to tokenize the text. However, CountVectorizer transforms the tokenized list of words into a bag-of-words model, not a sequence model.
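
For reference, the tokenization step looks roughly like this; the raw column name text and the sample sentences are assumptions, since the original input DataFrame is not shown:

    from pyspark.sql import SparkSession
    from pyspark.ml.feature import Tokenizer

    spark_session = SparkSession.builder.getOrCreate()

    # Hypothetical raw input; the real text column is not shown in the question
    raw = spark_session.createDataFrame(
        [(0, "a b c"), (1, "a b b c a")],
        ["id", "text"])

    # Tokenizer lowercases the text and splits on whitespace,
    # producing an Array[String] column
    tokenizer = Tokenizer(inputCol="text", outputCol="texts")
    df = tokenizer.transform(raw).select("id", "texts")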

Assume that we have the following DataFrame with columns id and texts:

 id | texts
----|----------
 0  | Array("a", "b", "c")
 1  | Array("a", "b", "b", "c", "a")

Each row in texts is a document of type Array[String]. Invoking fit of CountVectorizer produces a CountVectorizerModel with vocabulary (a, b, c); the output column “vector” after transformation then contains:

 id | texts                           | vector
----|---------------------------------|---------------
 0  | Array("a", "b", "c")            | (3,[0,1,2],[1.0,1.0,1.0])
 1  | Array("a", "b", "b", "c", "a")  | (3,[0,1,2],[2.0,2.0,1.0])
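
A minimal sketch of that CountVectorizer step, assuming the spark_session from the snippet above (the exact vocabulary index order depends on corpus term frequencies):

    from pyspark.sql import Row
    from pyspark.ml.feature import CountVectorizer

    df = spark_session.createDataFrame([
        Row(id=0, texts=["a", "b", "c"]),
        Row(id=1, texts=["a", "b", "b", "c", "a"])
    ])

    # fit() learns the vocabulary, transform() emits sparse count vectors,
    # i.e. the bag-of-words representation shown in the table above
    cv = CountVectorizer(inputCol="texts", outputCol="vector")
    cv.fit(df).transform(df).show(truncate=False)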

What I want instead is a sequence encoding, e.g. for row 1:

Array("a", "b", "b", "c", "a")  | [0, 1, 1, 2, 0]

So is there any way I can write a custom function to run this encoding in parallel? Or is there another library besides Spark that can do it in parallel?


Solution

  • You could use StringIndexer and explode:

    from pyspark.sql import Row, SparkSession
    from pyspark.sql.functions import explode, collect_list
    from pyspark.ml.feature import StringIndexer

    spark_session = SparkSession.builder.getOrCreate()

    df = spark_session.createDataFrame([
        Row(id=0, texts=["a", "b", "c"]),
        Row(id=1, texts=["a", "b", "b", "c", "a"])
    ])

    # One row per (id, token); StringIndexer then maps each token to a numeric index
    data = df.select("id", explode("texts").alias("texts"))
    indexer = StringIndexer(inputCol="texts", outputCol="indexed",
                            stringOrderType="alphabetAsc")

    # Re-assemble the per-token indices into one sequence per document
    indexer \
        .fit(data) \
        .transform(data) \
        .groupBy("id") \
        .agg(collect_list("texts").alias("texts"),
             collect_list("indexed").alias("vector")) \
        .show(20, False)
    

    Output:

    +---+---------------+-------------------------+
    |id |texts          |vector                   |
    +---+---------------+-------------------------+
    |0  |[a, b, c]      |[0.0, 1.0, 2.0]          |
    |1  |[a, b, b, c, a]|[0.0, 1.0, 1.0, 2.0, 0.0]|
    +---+---------------+-------------------------+
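
    One caveat: collect_list after a groupBy does not formally guarantee that the collected values keep their original order, even though it usually works out on small data. A hedged variant that tracks each token's position with posexplode and restores the order explicitly (assumes a Spark version where sort_array handles arrays of structs) could look like this:

    from pyspark.sql.functions import posexplode, sort_array, struct, col, collect_list

    # Keep each token's position so the sequence order can be restored after groupBy
    data = df.select("id", posexplode("texts").alias("pos", "texts"))

    indexed = StringIndexer(inputCol="texts", outputCol="indexed",
                            stringOrderType="alphabetAsc").fit(data).transform(data)

    (indexed
        .groupBy("id")
        .agg(sort_array(collect_list(struct("pos", "texts", "indexed"))).alias("zipped"))
        .select("id",
                col("zipped.texts").alias("texts"),
                col("zipped.indexed").alias("vector"))
        .show(20, False))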