Tags: python, matplotlib, python-typing, pylance

Precisely type-annotating the array (numpy.ndarray) of matplotlib Axes returned by plt.subplots()


I want my code to pass the Pylance type checker in VS Code without errors.

How do I type axs correctly in the following code?

import matplotlib.pyplot as plt
fig, axs = plt.subplots(2, 2)

In the screenshot below, Pylance in VS Code reports an error for this code:

[Screenshot: Pylance error in VS Code]


Solution

  • It turns out that strongly typing the axs variable is not straightforward at all and requires a good understanding of how to type np.ndarray.

    See this question and this question for more details.

    The simplest solution is to wrap the numpy.ndarray annotation in quote (') characters, which turns it into a string that only the type checker interprets. This avoids the infamous TypeError: 'numpy._DTypeMeta' object is not subscriptable, which is raised when Python tries to evaluate the [] in the expression at runtime.

    An example:

    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    import typing
    from typing import cast, Sequence
    
    sns.set() 
    
    # Some example data to display
    x = np.linspace(0, 2 * np.pi, 400)
    y = np.sin(x ** 2)
    
    fig, axs = plt.subplots(
        2, 2, 
        figsize=(12, 10) # set graph size
    )
    
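    # Type-checking step. The alias is quoted, so Python stores it as a plain
    # string and never evaluates the subscripted np.ndarray[...] at runtime.
    NDArrayOfAxes: typing.TypeAlias = 'np.ndarray[Sequence[Sequence[plt.Axes]], np.dtype[np.object_]]'
    # cast() has no runtime effect; it only tells the checker to treat axs as an ndarray,
    # so 2-D indexing such as axs[0, 0] is accepted.
    axs = cast(np.ndarray, axs)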
    
    axs[0, 0].plot(x, y)
    axs[0, 0].set_title("main")
    axs[1, 0].plot(x, y**2)
    axs[1, 0].set_title("shares x with main")
    axs[1, 0].sharex(axs[0, 0])
    axs[0, 1].plot(x + 1, y + 1)
    axs[0, 1].set_title("unrelated")
    axs[1, 1].plot(x + 2, y + 2)
    axs[1, 1].set_title("also unrelated")
    fig.tight_layout()
    

    Pylance accepts this code without reporting errors, and it runs correctly.
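
To illustrate why the quotes matter, here is a minimal sketch that uses only the standard library. Grid is a hypothetical stand-in (not part of NumPy or matplotlib) for any class that cannot be subscripted at runtime. A real type checker would reject Grid[int] because Grid is not generic, so the sketch only demonstrates runtime behavior: an unquoted subscript is evaluated immediately and fails, while a quoted annotation is stored as a string, and cast() returns its argument unchanged.

```python
from typing import cast


class Grid:
    """Hypothetical stand-in for a class that cannot be subscripted at runtime."""


def unquoted_fails() -> bool:
    """Return True if evaluating Grid[int] at runtime raises TypeError."""
    try:
        Grid[int]  # evaluated immediately, so Python tries to subscript the class
    except TypeError:
        return True
    return False


# Quoted: the annotation is kept as the string "Grid[int]" and is never evaluated.
def first_cell(grid: "Grid[int]") -> object:
    return grid


value = Grid()

# cast() with a quoted type does nothing at runtime: it returns the same object.
same = cast("Grid[int]", value)

print(unquoted_fails())
print(first_cell.__annotations__["grid"])
print(same is value)
```

The same reasoning applies to the quoted NDArrayOfAxes alias above: the string is there for the type checker to read, and the interpreter never has to subscript np.ndarray or np.dtype.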