
Side-by-side word clouds in matplotlib


I'm using the WordCloud package to display the words generated by scikit-learn's LDA (Latent Dirichlet Allocation). Each topic generated by LDA gets its own chart, and I want to plot all of the charts in a grid so they can be compared side by side. So far I have a function that takes an LDA model and the topic I want to visualize, and plots a word cloud:

from wordcloud import WordCloud
import matplotlib.pyplot as plt
SEED = 0

# tf_feature_names: vocabulary of the vectorizer used to fit the LDA model (defined elsewhere)
def topicWordCloud(model, topicNumber, WCmaxWords, WCwidth, WCheight):
    topic = model.components_[topicNumber]
    # word -> weight scaled to an integer; recent WordCloud versions require a dict here
    frequencies = {tf_feature_names[i]: int(topic[i] / topic.sum() * 10000) for i in range(len(topic))}
    wordcloud = WordCloud(width=WCwidth, height=WCheight, max_words=WCmaxWords,
                          random_state=SEED).generate_from_frequencies(frequencies)
    plt.figure(figsize=(20, 10))
    plt.imshow(wordcloud)
    plt.axis("off")

topicWordCloud(model=lda, topicNumber=2, WCmaxWords=100, WCwidth=800, WCheight=600)

How do I loop through all my topics (n_topics) to visualize all the charts in a grid? I was thinking something along the lines of:

fig = plt.figure()
for i in range(n_topics):
    plt.subplot(2,1,i+1) 
    #something here

Solution

  • Return the wordcloud from your function instead of plotting it, call topicWordCloud inside your for loop, and draw each result with imshow on the Axes you create with fig.add_subplot. For example, something like this:

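    import math

    def topicWordCloud(model, topicNumber, WCmaxWords, WCwidth, WCheight):
        topic = model.components_[topicNumber]
        # word -> scaled weight; recent WordCloud versions require a dict here
        frequencies = {tf_feature_names[i]: int(topic[i] / topic.sum() * 10000) for i in range(len(topic))}
        wordcloud = WordCloud(width=WCwidth, height=WCheight, max_words=WCmaxWords,
                              random_state=SEED).generate_from_frequencies(frequencies)
        return wordcloud  # return instead of plotting, so the caller chooses the Axes

    ncols = 2
    nrows = math.ceil(n_topics / ncols)  # a fixed 2x1 grid only fits two topics
    fig = plt.figure(figsize=(10 * ncols, 5 * nrows))
    for i in range(n_topics):
        ax = fig.add_subplot(nrows, ncols, i + 1)
        wordcloud = topicWordCloud(...)

        ax.imshow(wordcloud)
        ax.axis('off')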
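A self-contained variant of the same loop, sketched with `plt.subplots`, which creates the whole grid at once. `plot_grid` is a hypothetical helper name. It accepts the WordCloud objects returned by topicWordCloud (imshow takes them directly, as the question's own `plt.imshow(wordcloud)` call already relies on) or plain image arrays, and it blanks out any spare panels when the number of topics does not fill the last row.

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


def plot_grid(images, ncols=2, panel_size=(5, 3), titles=None):
    """Draw each image in its own panel of a grid and return (fig, axes)."""
    n = len(images)
    nrows = max(1, math.ceil(n / ncols))  # enough rows for every image
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(panel_size[0] * ncols, panel_size[1] * nrows),
        squeeze=False,  # always a 2-D array of Axes, even for a single row
    )
    for i, ax in enumerate(axes.flat):
        if i < n:
            ax.imshow(images[i], interpolation="bilinear")
            if titles is not None:
                ax.set_title(titles[i])
        ax.axis("off")  # hide ticks on used panels; leaves spare panels blank
    fig.tight_layout()
    return fig, axes
```

With the topicWordCloud above, a call might look like `plot_grid([topicWordCloud(lda, i, 100, 800, 600) for i in range(n_topics)], titles=[f"Topic {i}" for i in range(n_topics)])`. Passing `squeeze=False` keeps the indexing uniform whether the grid has one row or several.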