python, matplotlib, python-unittest

How can I write unit tests against code that uses matplotlib?


I'm working on a Python (2.7) program that produces a lot of different matplotlib figures (the data are not random). I'd like to implement some tests (using unittest) to be sure that the generated figures are correct. For instance, I could store the expected figure (data or image) somewhere, run my function, and compare the result with the reference. Is there a way to do this?


Solution

  • In my experience, image comparison tests end up bringing more trouble than they are worth. This is especially the case if you want to run continuous integration across multiple systems (for example on Travis CI) that may have slightly different fonts or available drawing backends. It can be a lot of work to keep the tests passing even when the functions work correctly. Furthermore, testing this way requires keeping images in your git repository, which can quickly lead to repository bloat if you're changing the code often.

    A better approach in my opinion is to (1) assume matplotlib is going to actually draw the figure correctly, and (2) run numerical tests against the data returned by the plotting functions. (You can also always find this data inside the Axes object if you know where to look.)

    For example, say you want to test a simple function like this:

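    import numpy as np
    import matplotlib.pyplot as plt

    def plot_square(x, y):
        # Square the y values, then plot them against x.
        y_squared = np.square(y)
        # plt.plot returns a list of Line2D objects, one per plotted line.
        return plt.plot(x, y_squared)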
    

    Your unit test might then look like

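    def test_plot_square1():
        x, y = [0, 1, 2], [0, 1, 2]
        # Unpack the single Line2D from the list returned by plt.plot.
        line, = plot_square(x, y)
        # get_xydata() is an (N, 2) array; transpose to split x and y.
        x_plot, y_plot = line.get_xydata().T
        np.testing.assert_array_equal(y_plot, np.square(y))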
    

    Or, equivalently,

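    def test_plot_square2():
        # The new Axes becomes the current one, so plt.plot draws on it.
        f, ax = plt.subplots()
        x, y = [0, 1, 2], [0, 1, 2]
        plot_square(x, y)
        # Read the plotted data back from the Axes instead of the return value.
        x_plot, y_plot = ax.lines[0].get_xydata().T
        np.testing.assert_array_equal(y_plot, np.square(y))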
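
Since the question mentions unittest specifically, the same checks could be wrapped in a `unittest.TestCase`. The following is only a sketch under a few assumptions that are not part of the answer above: the class name `PlotSquareTest` is made up, the `Agg` backend is selected so the tests can run on a machine without a display, and each test gets a fresh figure that is closed afterwards so state does not leak between tests.

```python
import unittest

import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend for headless test runs
import matplotlib.pyplot as plt
import numpy as np


def plot_square(x, y):
    """Same function as in the answer: plot x against y squared."""
    return plt.plot(x, np.square(y))


class PlotSquareTest(unittest.TestCase):  # hypothetical test class name
    def setUp(self):
        # A fresh figure per test; it becomes the current figure for plt.plot.
        self.fig, self.ax = plt.subplots()

    def tearDown(self):
        # Close the figure so tests do not accumulate open figures.
        plt.close(self.fig)

    def test_line_data(self):
        x, y = [0, 1, 2], [0, 1, 2]
        plot_square(x, y)
        # Exactly one line should have been drawn on the current Axes.
        self.assertEqual(len(self.ax.lines), 1)
        x_plot, y_plot = self.ax.lines[0].get_xydata().T
        np.testing.assert_array_equal(x_plot, x)
        np.testing.assert_allclose(y_plot, np.square(y))

    def test_negative_values(self):
        # Check the data carried by the returned Line2D for negative input.
        line, = plot_square([0, 1], [-1, -2])
        np.testing.assert_allclose(line.get_ydata(), [1, 4])
```

Saved in a test module, this should be discoverable with `python -m unittest`. Using `assert_allclose` rather than exact equality is a design choice that matters once the plotted values come from floating-point computations.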