I'd like to take the term indices from a Spark LDA model's describeTopics() output and match them to the corresponding terms in the CountVectorizer's vocabulary. Here is the point of friction:
terms = ['fuzzy', 'wuzzy', 'bear', 'seashells', 'chuck', 'wood', 'big', 'black', 'woodchuck', 'sell', 'hair', 'rug', 'sat', 'seashore', 'much', 'sells', 'many']
+-----+--------------+------------------------------------------------------------------------------------+
|topic|termIndices |termWeights |
+-----+--------------+------------------------------------------------------------------------------------+
|0 |[6, 16, 4, 13]|[0.07759153026889895, 0.07456018590515792, 0.06590443764744822, 0.06529979905841589]|
|1 |[0, 8, 1, 12] |[0.08460924078697102, 0.06935412981755526, 0.06803462316387827, 0.06505150660960128]|
|2 |[8, 14, 5, 3] |[0.07473487407268035, 0.06999332154185754, 0.06923579179113146, 0.06673236538997057]|
|3 |[2, 15, 10, 9]|[0.07990489772171691, 0.07352818255574894, 0.0725564301639141, 0.0705481456715216] |
+-----+--------------+------------------------------------------------------------------------------------+
My desired output is the DataFrame above with an additional array column, terms, containing the vocabulary terms that correspond to termIndices.
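The lookup itself is plain list indexing: each entry of termIndices is a position in the CountVectorizer vocabulary. As a pure-Python sketch (no Spark involved), using the vocabulary and termIndices values printed above, copied by hand rather than computed, the terms column should contain:

```python
# Vocabulary and termIndices copied by hand from the describeTopics() output shown above
vocabulary = ['fuzzy', 'wuzzy', 'bear', 'seashells', 'chuck', 'wood', 'big', 'black',
              'woodchuck', 'sell', 'hair', 'rug', 'sat', 'seashore', 'much', 'sells', 'many']
term_indices = {
    0: [6, 16, 4, 13],
    1: [0, 8, 1, 12],
    2: [8, 14, 5, 3],
    3: [2, 15, 10, 9],
}

# Each index is a position in the vocabulary list, so the mapping is a list lookup
topic_terms = {topic: [vocabulary[i] for i in idx] for topic, idx in term_indices.items()}

for topic, words in topic_terms.items():
    print(topic, words)
# 0 ['big', 'many', 'chuck', 'seashore']
# 1 ['fuzzy', 'woodchuck', 'wuzzy', 'sat']
# 2 ['woodchuck', 'much', 'wood', 'seashells']
# 3 ['bear', 'sells', 'hair', 'sell']
```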
Here is the code to set the problem up:
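from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.functions import udf
from pyspark.ml.feature import CountVectorizer, StopWordsRemover, RegexTokenizer, NGram
from pyspark.ml import Pipeline
from pyspark.ml.clustering import LDA
import numpy as np
spark = SparkSession.builder.getOrCreate()  # reuse the active session if one already exists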
text_df = spark.createDataFrame([
(0, "How much WOOD could a woodchuck chuck if a woodchuck could chuck wood?"),
(1, "She sells SEASHELLS by the seashore. How many seashells did she sell?"),
(2, "Fuzzy Wuzzy was a bear. Fuzzy Wuzzy had no hair. Fuzzy Wuzzy wasn't very fuzzy, was he?"),
(3, "A BIG BLACK bear sat on a big black rug.")
], ['id','text'])
# Arguments to be passed to functions
input_data = text_df
text_col = "text"
n_topics = 4
n_gram = 1
min_doc_freq = 0.1
max_doc_freq = 0.9
# Model pipeline; RegexTokenizer allows us to tokenize on the regex, forgoing an explicit symbol removal step
tokenizer = RegexTokenizer(inputCol=text_col, outputCol="tokens", pattern=r"[\s{2,}&;!\.\(\)-<>/,\?]+")
stopwords = StopWordsRemover(inputCol="tokens", outputCol="tokens_clean")
ngrams = NGram(inputCol="tokens_clean", n=n_gram, outputCol="ngram")
count_vec = CountVectorizer(inputCol="ngram", outputCol="features", minDF=min_doc_freq, maxDF=max_doc_freq)
lda = LDA(k=n_topics, seed=477)
pipeline = Pipeline(stages=[tokenizer, stopwords, ngrams, count_vec, lda])
# Fitting and transforming
model = pipeline.fit(input_data)
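To make the tokenizer stage concrete, here is a rough pure-Python approximation of RegexTokenizer with its defaults (gaps=True, toLowercase=True, minTokenLength=1): lowercase the text, split on the pattern, and drop empty tokens. It uses Python's re module rather than Spark's Java regex engine, so it is a sketch, not Spark's implementation; the tokenize function is hypothetical. Note that inside a character class, {2,} matches the literal characters {, 2, comma and }, and \)-< is a range that also covers the digits 0 to 9.

```python
import re

# Same pattern as the RegexTokenizer stage above
PATTERN = r"[\s{2,}&;!\.\(\)-<>/,\?]+"

def tokenize(text, pattern=PATTERN, min_token_length=1):
    """Hypothetical approximation of RegexTokenizer(gaps=True, toLowercase=True):
    lowercase, split on runs of delimiter characters, drop tokens that are too short."""
    return [tok for tok in re.split(pattern, text.lower()) if len(tok) >= min_token_length]

doc0 = tokenize("How much WOOD could a woodchuck chuck if a woodchuck could chuck wood?")
doc2 = tokenize("Fuzzy Wuzzy was a bear. Fuzzy Wuzzy had no hair. Fuzzy Wuzzy wasn't very fuzzy, was he?")
print(doc0)
print(doc2)
# Digits fall inside the \)-< range, so they act as delimiters too
print(tokenize("room 101 is here"))  # ['room', 'is', 'here']
```

The apostrophe is not in the class, so "wasn't" survives as a single token and is left for StopWordsRemover to handle.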
Here is what I have tried:
# This section experiments with matching vocabulary indices to words in terms
topics = model.stages[-1].describeTopics(4) # This yields the desired topic table above
terms = model.stages[-2].vocabulary # This yields the vocabulary above
# Defining a UDF to try and match indices to terms
@udf
def indices_to_terms(indices):
    terms_subset = [terms[index] for index in indices]
    return(np.array(terms_subset))
# Attempting to use UDF to add a terms column
topics = (
topics
.withColumn("terms", indices_to_terms(F.col("termIndices")))
)
topics.show()
I don't get an error when I run this code, but it doesn't show me anything. Perhaps a UDF isn't the right approach. How might I match each index in termIndices to the model vocabulary in terms and make the result an array column I can work with?
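The solution was defining the UDF more carefully: declaring its return type as ArrayType(StringType()) (a bare @udf defaults to StringType) and returning a plain Python list instead of a NumPy array. The following code solved my problem.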
import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, StringType
# This section matches vocabulary indices to words in terms
topics = model.stages[-1].describeTopics(4)
terms = model.stages[-2].vocabulary
# Plain Python function for matching indices to terms in the vocabulary;
# the vocabulary is bound through the default argument
def indices_to_terms(indices, terms=terms):
    terms_subset = [terms[index] for index in indices]
    return terms_subset
# Defining a Spark UDF from the function above, with an explicit array<string> return type
udf_indices_to_terms = F.udf(indices_to_terms, ArrayType(StringType()))
topics = (
    topics
    .withColumn("terms", udf_indices_to_terms(F.col("termIndices")))
)
topics.show(truncate=False)
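A small pure-Python check (no Spark needed) of what the two versions of the function body return may help explain the difference. The original body wraps its result in np.array, producing a NumPy ndarray, which PySpark's classic pickle-based UDF path generally does not map back to a Spark array; the working body returns an ordinary list of str, which matches ArrayType(StringType()). The vocabulary is the one printed in the question, the _attempt suffix is just a renaming for side-by-side comparison, and to_plain_list is a hypothetical helper for cases where a NumPy result cannot be avoided.

```python
import numpy as np

# Vocabulary as printed in the question
terms = ['fuzzy', 'wuzzy', 'bear', 'seashells', 'chuck', 'wood', 'big', 'black', 'woodchuck',
         'sell', 'hair', 'rug', 'sat', 'seashore', 'much', 'sells', 'many']

def indices_to_terms_attempt(indices):
    # Body of the original attempt: returns a NumPy array
    return np.array([terms[index] for index in indices])

def indices_to_terms(indices, terms=terms):
    # Body of the working version: returns a plain Python list
    return [terms[index] for index in indices]

def to_plain_list(values):
    # Hypothetical helper: .tolist() converts NumPy elements to built-in Python str
    return np.asarray(values).tolist()

attempt = indices_to_terms_attempt([6, 16, 4, 13])
fixed = indices_to_terms([6, 16, 4, 13])
print(type(attempt).__name__, type(fixed).__name__)  # ndarray list
print(fixed)                                         # ['big', 'many', 'chuck', 'seashore']
print(to_plain_list(attempt) == fixed)               # True
```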