Tags: python, matplotlib, machine-learning, grid, distortion

How to plot colah's deformed grid using matplotlib (Python)?


I need to create a visualization in Python like the one colah made on his site. However, I could not find any way in matplotlib to distort a grid exactly the way he did. Please help me if you can.

This is the plot I need to reproduce:

[Image: colah's vector field space distortion]


Solution

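  • I would guess the image is produced by adding a Gaussian function to the grid. The snippet below does this by shifting every grid point along x by a Gaussian bump centred at the origin: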

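    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection
    
    def plot_grid(x, y, ax=None, **kwargs):
        """Draw the grid defined by the 2D coordinate arrays x and y."""
        ax = ax or plt.gca()
        # For meshgrid-style input, each row of (x, y) is one horizontal line...
        segs1 = np.stack((x, y), axis=2)
        # ...and each column is one vertical line.
        segs2 = segs1.transpose(1, 0, 2)
        ax.add_collection(LineCollection(segs1, **kwargs))
        ax.add_collection(LineCollection(segs2, **kwargs))
        # Fit the axis limits to the lines that were just added.
        ax.autoscale()
    
    
    def f(x, y):
        # Shift every point along +x by a Gaussian bump of height 0.8
        # centred at the origin; y is left unchanged.
        return x + 0.8 * np.exp(-x**2 - y**2), y
    
    
    fig, ax = plt.subplots()
    
    # Undistorted 20 x 20 grid on [-3, 3] x [-3, 3], drawn in light grey.
    grid_x, grid_y = np.meshgrid(np.linspace(-3, 3, 20), np.linspace(-3, 3, 20))
    plot_grid(grid_x, grid_y, ax=ax, color="lightgrey")
    
    # The same grid after applying f, drawn in the first default cycle color ("C0").
    distx, disty = f(grid_x, grid_y)
    plot_grid(distx, disty, ax=ax, color="C0")
    
    plt.show()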
    

    [Image: output of the code above]
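One possible refinement: in the snippet above each grid line passes through only the 20 points that define the grid, so near the bump the distorted lines can look slightly angular. A minimal sketch, assuming you want smoother curves, is to sample each line much more finely than the spacing between lines before applying the mapping. The helper names (`gaussian_shift`, `grid_lines`, `distort`) and the `n_samples` parameter are hypothetical, introduced here for illustration and not part of matplotlib; `gaussian_shift` is the same mapping as `f` above.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection


def gaussian_shift(x, y, amplitude=0.8):
    """Hypothetical helper: shift points along +x by a Gaussian bump
    centred at the origin (same mapping as f in the answer)."""
    return x + amplitude * np.exp(-x**2 - y**2), y


def grid_lines(lo=-3.0, hi=3.0, n_lines=20, n_samples=200):
    """Hypothetical helper: return (horizontal, vertical) grid lines, each an
    array of shape (n_lines, n_samples, 2). Every line is sampled much more
    finely than the line spacing so it bends smoothly after distortion."""
    ticks = np.linspace(lo, hi, n_lines)   # positions of the grid lines
    t = np.linspace(lo, hi, n_samples)     # sample points along each line
    # Horizontal lines: row i has y = ticks[i] while x runs along t.
    hx, hy = np.meshgrid(t, ticks)
    # Vertical lines: row i has x = ticks[i] while y runs along t.
    vy, vx = np.meshgrid(t, ticks)
    return np.stack((hx, hy), axis=-1), np.stack((vx, vy), axis=-1)


def distort(lines, mapping):
    """Apply mapping(x, y) -> (x', y') to every sample point of every line."""
    new_x, new_y = mapping(lines[..., 0], lines[..., 1])
    return np.stack((new_x, new_y), axis=-1)


if __name__ == "__main__":
    fig, ax = plt.subplots()
    for lines in grid_lines():
        # Original grid in light grey, distorted grid in the first cycle color.
        ax.add_collection(LineCollection(lines, color="lightgrey"))
        ax.add_collection(LineCollection(distort(lines, gaussian_shift), color="C0"))
    ax.autoscale()
    ax.set_aspect("equal")
    plt.show()
```

Any function that maps `(x, y)` arrays to new `(x, y)` arrays can be passed to `distort` in place of `gaussian_shift` to get other distortions.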