python function differentiation automatic-differentiation

How can I differentiate a function of a function?


I am trying to differentiate z(x) with respect to x using the ad library. I know y(x) and z(y), but not z(x) itself. If I cannot find z(x) analytically, how can I perform the differentiation? In other words, I want to avoid doing the chain-rule calculation by hand, as shown below:

from ad import gh

def y(x):
    return 2*x

def z(y):
    return 3*y

dzdy, hz = gh(z)    # gradient and Hessian functions of z(y)
dydx, hy = gh(y)    # gradient and Hessian functions of y(x)

x0 = 0 # does not matter for this example
dydx_x0 = dydx(x0)

y0 = y(x0)
dzdy_y0 = dzdy(y0)

dzdx_x0 = dzdy_y0[0] * dydx_x0[0]

print(dzdx_x0)    # dz/dx = dz/dy*dy/dx = 3*2 = 6
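
For comparison, the same derivative can be approximated numerically by differencing the composite function directly, which also needs no analytic z(x). This is a swapped-in technique (a central finite difference), not part of the ad library, and the helper name `central_difference` and the step size `h` are assumptions chosen for illustration:

```python
def y(x):
    return 2*x

def z(y):
    return 3*y

def central_difference(f, x0, h=1e-6):
    """Approximate f'(x0); truncation error is O(h**2), plus rounding error."""
    return (f(x0 + h) - f(x0 - h)) / (2*h)

# Differentiate the composite z(y(x)) directly; no closed form for z(x) is needed.
dzdx_approx = central_difference(lambda x: z(y(x)), 0.0)
print(dzdx_approx)  # approximately 6
```

Unlike automatic differentiation, this result is only approximate and depends on the choice of `h`.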

Solution

def z_of_x(x):
    return z(y(x))

gradient, hessian = gh(z_of_x)
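
Define a function that computes z in terms of x by composing the two known functions, then apply automatic differentiation to it as usual. The library propagates derivatives through z(y(x)) itself, so the chain rule is applied for you and z(x) never has to be written out analytically.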
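
To illustrate why composing the functions is enough, here is a minimal sketch of forward-mode automatic differentiation using dual numbers. It is not the ad library's implementation. The `Dual` class and the `derivative` helper are hypothetical names introduced here, and only `+` and `*` are supported:

```python
from dataclasses import dataclass

@dataclass
class Dual:
    """A value paired with its derivative with respect to the input x."""
    val: float
    der: float

    def __add__(self, other):
        # Treat plain numbers as constants (derivative 0).
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __rmul__ = __mul__

def derivative(f, x0):
    # Seed dx/dx = 1. Every operation inside f updates the derivative,
    # so nesting functions applies the chain rule automatically.
    return f(Dual(x0, 1.0)).der

def y(x):
    return 2*x

def z(y):
    return 3*y

def z_of_x(x):
    return z(y(x))

print(derivative(z_of_x, 0.0))  # 6.0
```

No manual chain-rule step appears anywhere: y passes a value/derivative pair to z, and z scales both.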