python math derivative chain calculus

Chain Rule in plain Python


I see that this question has already been answered for sympy, but I am trying to write an implementation of the chain rule for educational purposes, in a toy project with no third-party libraries.

Basically, the chain rule is k'(x) = f'(g(x)) * g'(x), where k(x) = f(g(x)).

I have the following functions:

def g(x):
    return x**3 + 2

def f(x):
    return x**2 + 7

def de(fn, x, step):
    t1 = fn(x)
    t2 = fn(x+step)
    return (t2 - t1) / step

def chain(x):
    return f(g(x))

def de_chain(x, step):
    d_g = de(g, x, step)
    gres = g(x)
    d_f_g = de(f, gres, step)
    return d_g * d_f_g

The problem is that when I evaluate de_chain and de(chain) for x=1.2 and step=2.6, I get de_chain = 205.5446... and de(chain) = 1238.6639....

Something is wrong here, because when I applied the same approach to addition and subtraction, as in k'(x) = g'(x) + f'(x) where k(x) = g(x) + f(x), the results were very close. What am I doing wrong?

Thanks


Solution

  • Your code looks right. The problem is that estimating a derivative from a single difference isn't, in general, very accurate, and here your step size is quite large. Remember that your de function only approximates the derivative; the derivative itself is the limit of that quotient as the step goes to 0.

    Consider just your g(x). Its actual derivative at x=1 is 3*x^2 = 3 * 1^2 = 3. But with your step size of 2.6 you'd get an estimate of (g(3.6) - g(1)) / 2.6 = (48.656 - 3) / 2.6 = 17.56, which is pretty far off the mark.

    You can read about more accurate methods for estimating derivatives here: https://en.wikipedia.org/wiki/Numerical_differentiation
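
As a rough sketch of the point about step size, the snippet below reuses g, f, chain, de and de_chain from the question and evaluates both estimates at x=1.2 for progressively smaller steps. Worked out by hand, the exact derivative is 2*g(x) * 3*x^2, about 32.21 at x=1.2, and both estimates should approach it as the step shrinks. The de_central helper is not part of the question; it is a hypothetical name for the symmetric (central) difference described on the Wikipedia page above, included only for comparison.

```python
def g(x):
    return x**3 + 2

def f(x):
    return x**2 + 7

def chain(x):
    return f(g(x))

def de(fn, x, step):
    # One-sided (forward) difference, as in the question.
    return (fn(x + step) - fn(x)) / step

def de_chain(x, step):
    # Chain rule assembled from two forward differences, as in the question.
    return de(g, x, step) * de(f, g(x), step)

def de_central(fn, x, step):
    # Hypothetical helper (not in the question): central difference,
    # which samples the function on both sides of x.
    return (fn(x + step) - fn(x - step)) / (2 * step)

x = 1.2
# Exact value from the chain rule done by hand: f'(g(x)) * g'(x).
exact = 2 * g(x) * 3 * x**2

for step in (2.6, 0.1, 1e-3, 1e-6):
    print(
        f"step={step:<8g}",
        f"de(chain)={de(chain, x, step):.6f}",
        f"de_chain={de_chain(x, step):.6f}",
        f"central={de_central(chain, x, step):.6f}",
        f"exact={exact:.6f}",
    )
```

With step=2.6 the two forward-difference estimates disagree badly with each other and with the exact value; with a small step they should agree closely. Making the step extremely small eventually hurts too, because floating-point rounding starts to dominate the subtraction.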