Tags: python, matplotlib, linear-regression, seaborn, logarithm

How to scale the x and y axis equally by log in Seaborn?


I want to create a regplot with a linear regression in Seaborn and scale both axes equally by log, such that the regression stays a straight line.

An example:

import matplotlib.pyplot as plt
import seaborn as sns

some_x=[0,1,2,3,4,5,6,7]
some_y=[3,5,4,7,7,9,9,10]

ax = sns.regplot(x=some_x, y=some_y, order=1)
plt.ylim(0, 12)
plt.xlim(0, 12)
plt.show()

What I get:

[Image: linear regression]

If I scale the x and y axis by log, I would expect the regression to stay a straight line. What I tried:

import matplotlib.pyplot as plt
import seaborn as sns

some_x=[0,1,2,3,4,5,6,7]
some_y=[3,5,4,7,7,9,9,10]

ax = sns.regplot(x=some_x, y=some_y, order=1)
ax.set_yscale('log')
ax.set_xscale('log')
plt.ylim(0, 12)
plt.xlim(0, 12)
plt.show()

How it looks:

[Image: the linear regression becomes a curve]


Solution

  • The problem is that you fit the data on a linear scale and only afterwards switch the axes to log scale. A straight line y = a + b*x only stays straight on log-log axes in special cases (for example when a = 0), so the fit appears curved.

    What you need instead is to transform the data to log scale (base 10) first and then perform the linear regression. Your data is currently a list; converting it to a NumPy array makes the log transform easy, because you can then use vectorised operations.

    Caution: one of your x entries is 0, for which the log is undefined. NumPy returns -inf for that entry and emits a RuntimeWarning.
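
To make the caution concrete, here is a minimal sketch of what `np.log10` does with a zero and one way to filter such points out before fitting. The names `mask`, `log_x` and `log_y` are introduced here for illustration only; whether dropping the point is appropriate depends on your data.

```python
import numpy as np

some_x = np.array([0, 1, 2, 3, 4, 5, 6, 7])
some_y = np.array([3, 5, 4, 7, 7, 9, 9, 10])

# log10(0) evaluates to -inf; errstate only silences the RuntimeWarning here
with np.errstate(divide='ignore'):
    print(np.log10(0.0))  # -inf

# Keep only the points where both coordinates are strictly positive
mask = (some_x > 0) & (some_y > 0)
log_x = np.log10(some_x[mask])
log_y = np.log10(some_y[mask])

print(log_x.size)  # 7 of the 8 points remain
```

A boolean mask works for zeros anywhere in the data, not only in the first position.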

    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    
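    some_x = np.array([0, 1, 2, 3, 4, 5, 6, 7])
    some_y = np.array([3, 5, 4, 7, 7, 9, 9, 10])
    
    # Skip the first point (x = 0): its log10 is -inf, which would break the fit
    ax = sns.regplot(x=np.log10(some_x[1:]), y=np.log10(some_y[1:]), order=1)
    plt.show()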
    

    Alternatively, use NumPy's polyfit and exclude the x=0 data point from the fit:

    import matplotlib.pyplot as plt
    import numpy as np
    
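    some_x = np.array([0, 1, 2, 3, 4, 5, 6, 7])
    some_y = np.array([3, 5, 4, 7, 7, 9, 9, 10])
    
    # Drop the x = 0 point before taking logs, since log10(0) is -inf
    log_x = np.log10(some_x[1:])
    log_y = np.log10(some_y[1:])
    
    # First-order polynomial (straight line) fitted in log-log space
    fit = np.poly1d(np.polyfit(log_x, log_y, 1))
    
    plt.plot(log_x, log_y, 'ko')
    plt.plot(log_x, fit(log_x), '-k')
    plt.show()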
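
If you would rather keep the original values on log-scaled axes (so the ticks read 1, 10, ... instead of 0, 1, ...), one possible variant is to fit in log space and transform the fitted line back before plotting. This is only a sketch, not part of the answer above: it assumes the x = 0 point is simply dropped, and `x_line` and `y_line` are names made up for the example.

```python
import matplotlib.pyplot as plt
import numpy as np

# Original data with the x = 0 point left out (log10(0) is undefined)
x = np.array([1, 2, 3, 4, 5, 6, 7], dtype=float)
y = np.array([5, 4, 7, 7, 9, 9, 10], dtype=float)

# Fit log10(y) = intercept + slope * log10(x)
slope, intercept = np.polyfit(np.log10(x), np.log10(y), 1)

# Back-transform the fitted line: y = 10**intercept * x**slope (a power law)
x_line = np.logspace(np.log10(x.min()), np.log10(x.max()), 50)
y_line = 10**intercept * x_line**slope

fig, ax = plt.subplots()
ax.plot(x, y, 'ko')            # data in original units
ax.plot(x_line, y_line, '-k')  # power law, drawn straight on log-log axes
ax.set_xscale('log')
ax.set_yscale('log')
plt.show()
```

Because a power law is linear in log-log coordinates, the fitted curve is drawn as a straight line once both axes use a log scale.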
    
