Search code examples
Tags: python, function, with-statement, monkeypatching

How do I monkey-patch matplotlib's 3D axis safely (without affecting future calls)?


I have a function that visualizes matrix elements using bar3d. I was trying to remove the margins at the bounding limits of the z-axis. I found this answer (the first one), which uses monkey patching, so my code looks like this:

from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.axis3d import Axis

# function which applies monkey patching
def _remove_margins():
    """
    Pull the coordinate bounds of every mplot3d ``Axis`` in by ``deltas / 4``.

    The replacement is installed on the ``Axis`` class itself, so it applies to
    every 3D axis in the process from now on and is never undone.
    """
    # The saved original doubles as an "already patched" marker, so calling
    # this function again does not wrap the method a second time.
    if hasattr(Axis, "_get_coord_info_old"):
        return

    def _get_coord_info_new(self, renderer):
        (lo, hi, centers, deltas,
         tc, highs) = self._get_coord_info_old(renderer)
        lo += deltas/4
        hi -= deltas/4
        return lo, hi, centers, deltas, tc, highs

    Axis._get_coord_info_old = Axis._get_coord_info
    Axis._get_coord_info = _get_coord_info_new


# function which visualizes the matrix
# ✅ this function should be affected by monkey patching
def visualize_matrix(M, figsize, ... ):
    """Draw matrix M as 3D bars (via ``ax.bar3d``) and return ``(fig, ax)``.

    NOTE: ``_remove_margins()`` patches the ``Axis`` class itself, so the
    trimmed margins also apply to every ``Axes3D`` created afterwards in
    this process, not just the one built here.
    """
    _remove_margins()
    
    fig = plt.figure(figsize=figsize)
    ax  = Axes3D(fig)
    ax.bar3d(...)
    .
    .
    .
    return fig, ax

# another function that uses Axes3D
# ⛔️ this function should not be affected by monkey patching
def visualize_sphere(...):
    """Create a figure with a fresh ``Axes3D`` and return ``(fig, ax)``.

    This function is meant to keep matplotlib's default axis margins, but it
    does not undo a class-level patch: if ``_remove_margins()`` has run
    earlier in the process, the trimmed margins apply here too.

    NOTE(review): ``figsize`` is not among the visible parameters;
    presumably it is part of the elided ``...`` argument list.
    """
    fig = plt.figure(figsize=figsize)
    ax  = Axes3D(fig)
    .
    .
    .
    return fig, ax

Problem:

In later calls to Axes3D (e.g. in the visualize_sphere function), the changes made by the monkey patch still remain.

Question:

How can I apply the monkey patch safely, so that it affects only the axes I intend it to?


Solution

  • I changed the monkey patch so that it modifies only the axis instances, not the Axis class. Call patch_axis on ax.xaxis, ax.yaxis and ax.zaxis after creating ax.

    import matplotlib.pyplot as plt
    import numpy as np
    
    
    def patch_axis(axis, *, fraction=0.25):
        """Trim the padding around a single 3D axis, leaving other axes alone.

        The replacement method is stored on *axis* itself (an instance
        attribute), so it shadows the class method only for this object.
        Other axes, including ones created later, keep whatever
        ``_get_coord_info`` their class provides.

        Args:
            axis: An mplot3d axis object such as ``ax.xaxis``, ``ax.yaxis``
                or ``ax.zaxis``.
            fraction: How many ``deltas`` to move each bound inward. The
                default 0.25 reproduces the original ``deltas / 4``.

        Calling this again on the same axis replaces the earlier wrapper
        instead of stacking a second one on top of it, so the trim is never
        applied twice.

        NOTE(review): ``_get_coord_info`` is private matplotlib API. This
        assumes it returns the six-element tuple
        ``(mins, maxs, centers, deltas, tc, highs)``; confirm this against
        the installed matplotlib version.
        """
        # Capture the genuine method only the first time this axis is patched.
        original = getattr(axis, "_get_coord_info_unpatched", None)
        if original is None:
            original = axis._get_coord_info
            axis._get_coord_info_unpatched = original

        def _get_coord_info_new(renderer):
            mins, maxs, centers, deltas, tc, highs = original(renderer)
            # Build new arrays instead of using +=/-= so that arrays handed
            # back by the original method are never modified in place.
            mins = mins + deltas * fraction
            maxs = maxs - deltas * fraction
            return mins, maxs, centers, deltas, tc, highs

        # An instance attribute takes precedence over the class's method, so
        # only this axis sees the wrapper.
        axis._get_coord_info = _get_coord_info_new
    
    
    def test(seed=None):
        """Build a demo 3D bar chart and return ``(fig, ax)``.

        Four sets of 20 random bars are drawn, placed at zs = 30, 20, 10 and 0
        along the y direction (``zdir='y'``). Each set has its own colour,
        except that its first bar is coloured cyan.

        Args:
            seed: Optional integer seed. When given, the bar heights come from
                a private ``np.random.RandomState(seed)``, so repeated calls
                draw identical data. This is handy for comparing a patched
                figure with an unpatched one. When ``None`` (the default),
                the global ``np.random`` state is used, as before.
        """
        rng = np.random if seed is None else np.random.RandomState(seed)
        xs = np.arange(20)  # the same for every set of bars, so build it once

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.margins(0)
        for c, z in zip(['r', 'g', 'b', 'y'], [30, 20, 10, 0]):
            ys = rng.rand(20)

            # You can provide either a single color or an array. To demonstrate
            # this, the first bar of each set will be colored cyan.
            cs = [c] * len(xs)
            cs[0] = 'c'
            ax.bar(xs, ys, zs=z, zdir='y', color=cs, alpha=0.8)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')

        return fig, ax
    
    # Figure 1: patch this figure's three axes so that their margins are trimmed.
    fig, ax = test()
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        patch_axis(axis)
    fig.savefig("test1.png")

    # Figure 2: a brand-new Axes3D is unaffected by the instance-level patch
    # above, so it keeps matplotlib's default margins.
    fig, ax = test()
    fig.savefig("test2.png")
    

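    Methods are just functions defined in a class's namespace, and accessing a method through an instance implicitly fills in the self parameter (see "Class instances" in https://docs.python.org/3/reference/datamodel.html). You can therefore override the bound method on a single instance by assigning a function without a self parameter as an attribute of that instance. Attribute lookup checks the instance's own namespace (its __dict__) before falling back to the class, so only that instance sees the new function; other instances, and the class itself, are unaffected.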

    >>> class C:
    ...     def f(self): return 1
    ... 
    >>> C.f
    <function C.f at 0x7f36a7eb53a0>
    >>> c = C()
    >>> c.f
    <bound method C.f of <__main__.C object at 0x7f36a7f48eb0>>
    >>> c.f()
    1
    >>> c.f = lambda: 2
    >>> c.f()
    2