How to replace a Matplotlib axis with a curvilinear axis


The Matplotlib documentation has a nice example of a curvilinear axis grid. It assumes you start with an empty figure and then create the curvilinear axis using:

ax1 = Subplot(fig, 1, 2, 1, grid_helper=grid_helper)

However, I am writing a function that is given an existing axis and I want the function to replace this axis (ax) with the new curvilinear axis.

Here is the start of my function so far:

import matplotlib.pyplot as plt
from mpl_toolkits.axisartist import SubplotHost
from mpl_toolkits.axisartist.grid_helper_curvelinear import GridHelperCurveLinear
import mpl_toolkits.axisartist.angle_helper as angle_helper
from matplotlib.projections import PolarAxes
from matplotlib.transforms import Affine2D

def sgrid(ax=None):
    """
    Adds an s-plane grid of constant damping factors and natural
    frequencies to a plot.  If ax is not specified, the current
    figure axis is used.

    Parameters
    ----------
    ax : matplotlib axis object
        If not passed, uses gca() to get current axis.

    Returns
    -------
    ax : matplotlib axis
    """

    grid_helper = GridHelperCurveLinear(
        ... some stuff ...
    )

    if ax is None:
        # Get the current axis or create a new figure with one
        ax = plt.gca()
    fig = ax.figure

    # TODO: How can we change the axis if it is already created?
    ax = SubplotHost(fig, 1, 1, 1, grid_helper=grid_helper)

    ... (code that adds more stuff to ax) ...

    return ax

Also, I'm not sure I understand the arguments to SubplotHost. Are these the initialization arguments of a new Axis or what?

UPDATE

The goal here is to emulate the way pandas.Series.plot works. The desired use cases look like this:

H = tf([2, 5, 1],[1, 2, 3])
rlocus(H)
sgrid()
plt.show()

or

>>> fig, axes = plt.subplots(2, 1)
>>> rlocus(H1, ax=axes[0])
>>> rlocus(H2, ax=axes[1])
>>> for ax in axes:
...     sgrid(ax=ax)  # Later we might want to add ax.zgrid()
...
>>> plt.show()

Ideally, sgrid and rlocus should be called in the order shown above, since that matches both the MATLAB functions we are emulating and the plt.grid() method, which these functions resemble.


Solution

  • grid_helper is a keyword for mpl_toolkits.axisartist.axislines.Axes. https://matplotlib.org/api/_as_gen/mpl_toolkits.axisartist.axislines.html

    You might want to check whether the provided axis is an instance of this Axes extension, and recreate the axis if it is not:

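    from mpl_toolkits.axisartist import SubplotHost, axislines

    ...

    def sgrid(ax1=None):
        ...
        if ax1 is None:
            ax1 = plt.gca()
        fig = ax1.figure
        if not isinstance(ax1, axislines.Axes):
            # Remember which subplot slot the old axis occupied, remove it,
            # and create a grid_helper-aware axis in the same slot.
            # (ax1.get_geometry() was used here originally, but it has been
            # removed in recent Matplotlib releases.)
            subspec = ax1.get_subplotspec()
            fig.delaxes(ax1)
            ax2 = SubplotHost(fig, subspec, grid_helper=grid_helper)
            fig.add_subplot(ax2)
        else:
            ax2 = ax1
        ...

        return ax2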
    

    Be aware that the caller of sgrid() still holds a reference to ax1, which is no longer part of the figure. Replacing the axis references may therefore require something more sophisticated than simply deleting and recreating the axis.
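
To make the stale-reference caveat concrete, here is a minimal sketch of the delete-and-recreate approach in which the caller rebinds to the returned axes. The names replace_axes and make_helper are hypothetical and not part of Matplotlib, and the 30-degree rotation is only a placeholder for the real grid transform, which the question elides. The sketch copies the old axes' position rather than its subplot slot, which is an assumption made so that it also works for axes that were not created as subplots.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display

import matplotlib.pyplot as plt
from matplotlib.transforms import Affine2D
from mpl_toolkits.axisartist import axislines
from mpl_toolkits.axisartist.grid_helper_curvelinear import GridHelperCurveLinear


def make_helper():
    """Hypothetical placeholder: a rotated grid stands in for the real one."""
    return GridHelperCurveLinear(Affine2D().rotate_deg(30))


def replace_axes(ax, grid_helper):
    """Hypothetical helper: swap `ax` for an axisartist Axes in the same place.

    Returns the axes the caller should use from now on.
    """
    if isinstance(ax, axislines.Axes):
        return ax  # already grid_helper-capable, nothing to replace
    fig = ax.figure
    bounds = ax.get_position().bounds  # (left, bottom, width, height)
    fig.delaxes(ax)  # the old axes, and anything drawn on it, leaves the figure
    new_ax = axislines.Axes(fig, bounds, grid_helper=grid_helper)
    fig.add_axes(new_ax)
    return new_ax


fig, old_axes = plt.subplots(2, 1)
# Rebind to the returned axes; the objects in old_axes are now detached.
new_axes = [replace_axes(ax, make_helper()) for ax in old_axes]
```

After this runs, the objects in old_axes are no longer in fig.axes, so any caller that keeps using them would draw on axes that never appear. Note also that anything plotted on the old axes before the swap is not carried over in this sketch, which matters if the grid function is called after the plotting function.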