Why does substitution into a SymPy derivative only partly work?


I am trying to use symbolic derivatives and the chain rule in SymPy before substituting explicit functions. For example, I just tried this in https://www.sympy.org/en/shell.html, expecting to get the same result from both print statements:

    from sympy import diff, Function, Symbol
    c = Symbol('c')
    a = Function('a')(c)
    b = Function('b')(c)
    V = a*b*c
    print(diff(V,c))
    V = Function('V')
    print((diff(V(a,b,c),c).subs(V(a,b,c), a*b*c)).doit())

Output:

    c*a(c)*Derivative(b(c), c) + c*b(c)*Derivative(a(c), c) + a(c)*b(c)
    c*a(c)*Derivative(b(c), c) + c*b(c)*Derivative(a(c), c) + Subs(Derivative(V(a(c), b(c), _xi_3), _xi_3), _xi_3, c)

In the last line, the substitution was not done for the last term. Is this a bug or did I misunderstand something?
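The leftover term can be pulled out of the derivative and inspected on its own. The following is a minimal sketch, assuming a recent SymPy release (older releases may wrap the other two terms in Subs objects as well, in which case the single-element unpacking below would fail); the names expr and s are only illustrative:

```python
from sympy import Function, Subs, Symbol, diff

c = Symbol('c')
a = Function('a')(c)
b = Function('b')(c)
V = Function('V')

expr = diff(V(a, b, c), c)

# Exactly one term of the derivative is a Subs object; unpack it
(s,) = expr.atoms(Subs)

print(s.expr)       # Derivative(V(a(c), b(c), _xi_3), _xi_3)
print(s.variables)  # (_xi_3,)
print(s.point)      # (c,)

# Inside the Subs the third argument is a Dummy, not c,
# so the pattern V(a, b, c) that subs() looks for is absent
print(s.expr.has(V(a, b, c)))  # False
```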


Solution

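  • The last print statement shows that SymPy's subs cannot perform the requested substitution inside the Subs object (yes, it's a bit confusing). That's because subs is looking for V(a(c), b(c), c), whereas inside the Subs object the function appears as V(a(c), b(c), _xi_3). SymPy introduces the dummy symbol _xi_3 because c also occurs inside a(c) and b(c), so the derivative with respect to the third argument alone cannot be written as a derivative with respect to c; instead it is expressed as a derivative with respect to _xi_3, evaluated at _xi_3 = c.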

    So we need a more powerful tool to perform the substitution. We can use the replace method together with a wild symbol (Wild), which SymPy uses for pattern-matching operations like the one we need here:

    from sympy import Function, Symbol, Wild
    c = Symbol('c')
    a = Function('a')(c)
    b = Function('b')(c)
    V = Function('V')
    # create a wild symbol: in its basic form, it matches any expression
    w = Wild("w")
    print(V(a,b,c).diff(c).replace(V(a,b,w), a*b*w).doit())
    # out: c*a(c)*Derivative(b(c), c) + c*b(c)*Derivative(a(c), c) + a(c)*b(c)

    Let's dive into the last command:

    • V(a,b,c).diff(c) performs the symbolic differentiation. For the term coming from the third argument, it introduces a Subs object in which V(a, b, c) appears as V(a, b, _xi_3).
    • The replace method performs the substitution: it searches for subexpressions of the form V(a, b, w), where w can match anything, and replaces each one with a * b * w, using whatever w matched. In this case SymPy finds two matches:
      1. V(a, b, c) which will be replaced by a * b * c.
      2. V(a, b, _xi_3) which will be replaced by a * b * _xi_3.
    • Finally, the doit method evaluates the remaining Derivative and Subs objects.
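The two matches listed above can be checked directly with find, which returns the set of subexpressions matching a pattern. A minimal sketch, assuming a recent SymPy release; expr, matches, result and expected are illustrative names, and comparing against diff(a*b*c, c) is just one convenient way to confirm that the replacement worked:

```python
from sympy import Function, Symbol, Wild, diff, simplify

c = Symbol('c')
a = Function('a')(c)
b = Function('b')(c)
V = Function('V')
w = Wild("w")

expr = V(a, b, c).diff(c)

# find() collects every subexpression that matches the pattern V(a, b, w):
# here V(a(c), b(c), c) with w -> c, and V(a(c), b(c), _xi_3) with w -> _xi_3
matches = expr.find(V(a, b, w))
for m in matches:
    # print each matched term together with what w was bound to
    print(m, m.match(V(a, b, w)))

# Apply the pattern-based replacement, then evaluate Derivative/Subs objects
result = expr.replace(V(a, b, w), a*b*w).doit()

# Compare with differentiating a*b*c directly (the first print in the question)
expected = diff(a*b*c, c)
print(simplify(result - expected) == 0)  # True
```

Since find returns a set, the order in which the two matches are printed is not guaranteed.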