Tags: python, numpy, matplotlib, kriging

Get data array from object in Python


I'm using a library that produces three plots from an object k.

I need to figure out the data points (x, y, z) that produced these plots, but the problem is that the plots come from a method of k.

The library I'm using is pyKriging, and this is their GitHub repository.

A simplified version of their example code is:

import pyKriging  
from pyKriging.krige import kriging  
from pyKriging.samplingplan import samplingplan

sp = samplingplan(2)  
X = sp.optimallhc(20)

testfun = pyKriging.testfunctions().branin  
y = testfun(X)

k = kriging(X, y, testfunction=testfun, name='simple')   
k.train()
k.plot()

The full code, comments and output can be found here.

In summary, I'm trying to get the NumPy arrays that produced these plots so I can create plots that follow my own formatting style.

I'm not familiar with digging into library code in Python, so I'd appreciate any help!


Solution

  • There is no single data array behind the plot. Instead, several arrays (the prediction grid, the prediction-variance grid, the true test-function values and the sample coordinates) are generated inside the kriging plot function and used only there.
    Switching the filled contours to line contours, for instance, is not something a style option can do, so one needs to reuse the code of the original plotting function.

    One option is to subclass kriging and implement a custom plotting method (here called myplot). In it, contour can be used instead of contourf, and the method can of course be adapted further as needed.

    import pyKriging
    from pyKriging.krige import kriging
    from pyKriging.samplingplan import samplingplan
    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401, registers the '3d' projection on older matplotlib
    
    class MyKriging(kriging):
        """kriging subclass with an editable copy of the built-in 2D plot."""
    
        def myplot(self, labels=False, show=True, **kwargs):
            fig = plt.figure(figsize=(8, 6))
            # Create a regular grid over the real (un-normalised) input range
            plotgrid = 61
            x = np.linspace(self.normRange[0][0], self.normRange[0][1], num=plotgrid)
            y = np.linspace(self.normRange[1][0], self.normRange[1][1], num=plotgrid)
            X, Y = np.meshgrid(x, y)
            # Kriging prediction at every grid point
            zs = np.array([self.predict([xi, yi]) for xi, yi in zip(np.ravel(X), np.ravel(Y))])
            Z = zs.reshape(X.shape)
            # Prediction variance (error estimate) at every grid point
            zse = np.array([self.predict_var([xi, yi]) for xi, yi in zip(np.ravel(X), np.ravel(Y))])
            Ze = zse.reshape(X.shape)
    
            # Sample points are stored normalised to [0, 1]; map them back to real coordinates
            spx = (self.X[:, 0] * (self.normRange[0][1] - self.normRange[0][0])) + self.normRange[0][0]
            spy = (self.X[:, 1] * (self.normRange[1][1] - self.normRange[1][0])) + self.normRange[1][0]
    
            contour_levels = kwargs.get("levels", 25)
    
            # Top right: prediction variance as line contours plus the sample points
            ax = fig.add_subplot(222)
            CS = plt.contour(X, Y, Ze, contour_levels)
            plt.colorbar()
            plt.plot(spx, spy, 'or')
    
            # Top left: kriging prediction
            ax = fig.add_subplot(221)
            if self.testfunction:
                # Evaluate the true test function on the grid. zip() returns an iterator
                # in Python 3, so stack the coordinates into an (N, 2) array instead.
                zt = self.testfunction(np.column_stack((np.ravel(X), np.ravel(Y))))
                ZT = zt.reshape(X.shape)
                # alpha=0 makes these black contours invisible; they are drawn only
                # so that their levels can be reused for the prediction plot
                CS = plt.contour(X, Y, ZT, contour_levels, colors='k', zorder=2, alpha=0)
    
                # Reuse the true function's levels, padded by one step on each side
                contour_levels = CS.levels
                delta = np.abs(contour_levels[0] - contour_levels[1])
                contour_levels = np.insert(contour_levels, 0, contour_levels[0] - delta)
                contour_levels = np.append(contour_levels, contour_levels[-1] + delta)
    
            CS = plt.contour(X, Y, Z, contour_levels, zorder=1)
            plt.plot(spx, spy, 'or', zorder=3)
            plt.colorbar()
    
            # Bottom: 3D surface of the prediction, with the true function as a wireframe
            ax = fig.add_subplot(212, projection='3d')
            ax.plot_surface(X, Y, Z, rstride=3, cstride=3, alpha=0.4)
            if self.testfunction:
                ax.plot_wireframe(X, Y, ZT, rstride=3, cstride=3)
            if show:
                plt.show()
    
    
    # Same workflow as in the question, but using the subclass
    sp = samplingplan(2)
    X = sp.optimallhc(20)
    
    testfun = pyKriging.testfunctions().branin
    y = testfun(X)
    
    k = MyKriging(X, y, testfunction=testfun, name='simple')
    k.train()
    k.myplot()
    

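If the goal is only to obtain the arrays and then plot them in one's own style, the grid evaluation can also be pulled out of the plotting code entirely. The following is a minimal sketch, not part of pyKriging: the function name kriging_grid_data is hypothetical, and it assumes, as the subclass above does, that predict and predict_var accept a single [x, y] point in real units and that the sample points are stored normalised to [0, 1] together with a normRange of (min, max) pairs.

```python
import numpy as np


def kriging_grid_data(predict, predict_var, norm_range, X_norm, plotgrid=61):
    """Evaluate a fitted 2D model on a regular grid and return the raw arrays.

    Hypothetical helper (not a pyKriging API). Assumes `predict` and
    `predict_var` take one [x, y] point in real units and return a scalar,
    and that `X_norm` holds the sample points scaled to [0, 1].
    """
    (x0, x1), (y0, y1) = norm_range
    x = np.linspace(x0, x1, num=plotgrid)
    y = np.linspace(y0, y1, num=plotgrid)
    X, Y = np.meshgrid(x, y)
    points = np.column_stack((X.ravel(), Y.ravel()))

    # Prediction and prediction variance at each grid point, reshaped to the grid
    Z = np.array([predict([px, py]) for px, py in points]).reshape(X.shape)
    Ze = np.array([predict_var([px, py]) for px, py in points]).reshape(X.shape)

    # Map the normalised sample points back to real coordinates
    low = np.array([x0, y0], dtype=float)
    high = np.array([x1, y1], dtype=float)
    samples = np.asarray(X_norm, dtype=float) * (high - low) + low

    return {"X": X, "Y": Y, "Z": Z, "Ze": Ze, "samples": samples}
```

With the trained object from the question this would presumably be called as kriging_grid_data(k.predict, k.predict_var, k.normRange, k.X), after which the returned arrays can be passed to plt.contour, plt.pcolormesh or any other plotting routine with one's own formatting.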