Search code examples
python matplotlib regression loglog

How do I extend a linear regression plot


I've encountered a problem trying to fit a straight line to the linear part of my plot. To finish my plot I have to extend the red line as a straight line, so that its intersection with at least the x axis can be observed.

My code is:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Measurements, hard-coded here. The original version loaded them with
# pd.read_csv("LPPII_cw_2_1.csv") from the columns "f [kHz]" and "h21e [A/A]".
f = (1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 500)
h21e = (218., 215., 210., 200., 189., 175., 165., 150., 140., 129., 120., 69., 30.)

# Only the last three points are treated as the straight part of the curve.
linearf, linearh = f[-3:], h21e[-3:]

# Straight line in log-log space: log(h) = m*log(f) + c, weighted by sqrt(h).
logA, logB = np.log(linearf), np.log(linearh)
m, c = np.polyfit(logA, logB, 1, w=np.sqrt(linearh))
y_fit = np.exp(m*logA + c)  # fitted values, evaluated at linearf only

fig, ax = plt.subplots()
ax.set(xscale='log', yscale='log', xlabel='f [kHz]', ylabel='h$_{21e}$ [A/A]')

# Measured points in black, fitted segment in red.
ax.scatter(f, h21e, marker='.', color='k')
ax.plot(linearf, y_fit, color='r', linestyle='-')

plt.show()

and my plot looks like this:

plot image


Solution

  • You could take the maximum of the x-axis and append it to the end of linearf. Then calculate the curve at these points and draw it. The old y-limits need to be saved and restored afterwards, to prevent matplotlib from automatically extending them. Note that the x-limits can only be extracted after plotting the scatter plot.

    import matplotlib.pyplot as plt
    import numpy as np
    
    f = (1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 500)
    h21e = (218., 215., 210., 200., 189., 175., 165., 150., 140., 129., 120., 69., 30.)
    
    # The last three points are treated as the straight part of the log-log plot.
    linearf = f[-3:]
    linearh = h21e[-3:]
    
    # Fit log(h) = m*log(f) + c, i.e. h = exp(c) * f**m, weighted by sqrt(h).
    logA = np.log(linearf)
    logB = np.log(linearh)
    
    m, c = np.polyfit(logA, logB, 1, w=np.sqrt(linearh))
    
    fig, ax = plt.subplots()
    
    ax.set_xscale('log')
    ax.set_yscale('log')
    
    ax.set_xlabel('f [kHz]')
    ax.set_ylabel('h$_{21e}$ [A/A]')
    
    ax.scatter(f, h21e, marker='.', color='k')
    
    # Remember the limits autoscaling chose for the data. They must be read
    # after the scatter call so that all points are already accounted for.
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    
    # Evaluate the fitted line from the linear region out to the right edge of
    # the axes, so that its crossing with the bottom of the plot is visible.
    linearf_ext = list(linearf) + [xmax]
    y_fit = np.exp(m * np.log(linearf_ext) + c)
    ax.plot(linearf_ext, y_fit, color='r', linestyle='-')
    
    # Plotting the extended line triggers a new autoscale. Without this reset,
    # the y-range would grow down to the line's lowest point. The x-range would
    # also gain an extra margin, so the line would stop short of the right edge.
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    plt.tight_layout()
    plt.show()
    

    extended linear regression