Search code examples
python matplotlib plot figure figures

matplotlib very slow in plotting


I have multiple functions, each of which takes an array or dict plus a path as arguments and saves a figure to that path.

Trying to keep example as minimal as possible, but here are two functions:

def valueChartPatterns(dict, path):
    """Save a bar chart of how often each value occurs in the dataset.

    Every value of every inner mapping in ``dict`` is tallied. One bar is
    drawn per distinct value, ordered from most to least frequent, and the
    figure is written to ``path``.

    Args:
        dict: Mapping whose values are themselves mappings; only the values
            of the inner mappings are counted. (The name shadows the
            builtin; it is kept so existing callers are unaffected.)
        path: Destination handed to ``plt.savefig``.
    """
    seen_values = Counter()

    # update() adds to the tally in place. Counter addition instead walks the
    # whole running tally for every record, so its cost grows with the number
    # of distinct values already seen.
    for data in dict.values():
        seen_values.update(data.values())

    ranked = seen_values.most_common()

    # Real lists (not lazy map objects) so plotting works on Python 2 and 3.
    labels = [value for value, _ in ranked]
    counts = [count for _, count in ranked]
    positions = range(len(ranked))

    fig = plt.figure()

    plt.bar(positions, counts, width=0.9, align='center')
    plt.xticks(positions, labels)

    plt.title('Values in Pattern Dataset')
    plt.xlabel('Values in Data')
    plt.ylabel('Occurrences')

    plt.tick_params(axis='both', which='major', labelsize=6)
    plt.tick_params(axis='both', which='minor', labelsize=6)
    plt.tight_layout()

    plt.savefig(path)
    # Close rather than clear: a cleared figure stays registered with pyplot,
    # so repeated calls would otherwise accumulate open figures.
    plt.close(fig)

def countryChartPatterns(dict, path):
    """Save a bar chart of how often each country key occurs in the dataset.

    Every key of every inner mapping in ``dict`` is tallied. One bar is drawn
    per distinct key, ordered from most to least frequent, and the figure is
    written to ``path``.

    Args:
        dict: Mapping whose values are themselves mappings; only the keys of
            the inner mappings are counted. (The name shadows the builtin;
            it is kept so existing callers are unaffected.)
        path: Destination handed to ``plt.savefig``.
    """
    seen_countries = Counter()

    # update() adds to the tally in place. Counter addition instead walks the
    # whole running tally for every record, which dominates the run time on
    # large inputs.
    for data in dict.values():
        seen_countries.update(data.keys())

    ranked = seen_countries.most_common()

    # Real lists (not lazy map objects) so plotting works on Python 2 and 3.
    xvals = [country for country, _ in ranked]
    yvals = [count for _, count in ranked]
    positions = range(len(ranked))

    fig = plt.figure()

    plt.bar(positions, yvals, width=0.9, align='center')
    plt.xticks(positions, xvals)

    plt.title('Countries in Pattern Dataset')
    plt.xlabel('Countries in Data')
    plt.ylabel('Occurrences')

    plt.tick_params(axis='both', which='major', labelsize=6)
    plt.tick_params(axis='both', which='minor', labelsize=6)
    plt.tight_layout()

    plt.savefig(path)
    # Close rather than clear: a cleared figure stays registered with pyplot,
    # so repeated calls would otherwise accumulate open figures.
    plt.close(fig)

A very minimal example dict is shown below; the actual dict contains 56000 values:

# Minimal sample input: outer keys are record ids, inner mappings go from
# country name to a numeric value.
dict = {
    "a": {
        "Germany": 20006.0,
        "United Kingdom": 20016.571428571428,
    },
    "b": {
        "Chad": 13000.0,
        "South Africa": 3000000.0,
    },
    "c": {
        "Chad": 200061.0,
        "South Africa": 3000000.0,
    },
}

And in my script, I call:

if __name__ == "__main__":

    # Close every figure pyplot is currently tracking before charting.
    plt.close('all')

    # Parenthesised single-argument form: prints the same text whether print
    # is the Python 2 statement or the Python 3 function.
    print("Starting pattern charting...\n")

    countryChartPatterns(dict, 'newPatternCountries.png')

    valueChartPatterns(dict, 'newPatternValues.png')

Note: I import matplotlib with import matplotlib.pyplot as plt.

When running this script in PyCharm, I get Starting pattern charting... in my console but the functions take super long to plot.

What am I doing wrong? Should I be using a histogram instead of a bar plot as this should achieve the same aim of giving the number of occurrences of countries/values? Can I change my GUI backend somehow? Any advice welcome.


Solution

  • This is the test that I mentioned in the comments above, resulting in:

    Elapsed pre-processing = 13.79 s
    Elapsed plotting = 0.17 s
    Pre-processing / plotting = 83.3654562565
    

    Test script:

    import matplotlib.pylab as plt
    from collections import Counter
    from operator import itemgetter
    import time
    
    def countryChartPatterns(dict,path):
        """Chart country-key frequencies and print how long each phase took.

        Phase one tallies the keys of every inner mapping in ``dict``; phase
        two draws the bar chart and saves it to ``path``. The elapsed time of
        each phase and their ratio are printed.
        """
        # pre-processing -------------------
        prep_started = time.time()
    
        tally = Counter()
        # Repeated Counter addition is kept deliberately: this accumulation is
        # the work the pre-processing timer is measuring.
        for record in dict.itervalues():
            tally += Counter(record.keys())
    
        ranked = tally.most_common()
    
        yvals = map(itemgetter(1), ranked)
        xvals = map(itemgetter(0), ranked)
    
        dt1 = time.time() - prep_started
        print("Elapsed pre-processing = {0:.2f} s".format(dt1))
    
        plot_started = time.time()
    
        # plotting -------------------
        plt.figure()
    
        plt.bar(range(len(ranked)), yvals, width=0.9,align='center')
        plt.xticks(range(len(ranked)), xvals)
    
        plt.title('Countries in Pattern Dataset')
        plt.xlabel('Countries in Data')
        plt.ylabel('Occurrences')
    
        # Same small label size for both major and minor ticks.
        for tick_kind in ('major', 'minor'):
            plt.tick_params(axis='both', which=tick_kind, labelsize=6)
        plt.tight_layout()
    
        plt.savefig(path)
        plt.clf()
    
        dt2 = time.time() - plot_started
        print("Elapsed plotting = {0:.2f} s".format(dt2))
        print("Pre-processing / plotting = {}".format(dt1/dt2))
    
    if __name__ == "__main__":
        import random as rd
        import numpy as np
    
        countries = ["United States of America", "Afghanistan", "Albania", "Algeria", "Andorra", "Angola", "Antigua & Deps", "Argentina", "Armenia", "Australia", "Austria", "Azerbaijan"]
    
        def item():
            """Return one synthetic record: random countries mapped to random ints.

            The two keys are drawn independently, so when both draws pick the
            same country the record has a single entry.
            """
            # Integer bound: randint documents its bounds as ints, not floats.
            return {rd.choice(countries): np.random.randint(1000), rd.choice(countries): np.random.randint(1000)}
    
        # Named ``records`` rather than ``dict`` so the builtin is not rebound
        # at module scope.
        records = {i: item() for i in range(1000000)}
    
        print("Starting pattern charting...")
    
        countryChartPatterns(records,'newPatternCountries.png')