Search code examples
Tags: python, sympy, equation, differential-equations, symbolic-math

Sympy system of differential equations


Problem

I'm making a symbolic solver for mechanical linkages (see my previous question for details).

Right now I can get sympy.solve to solve big systems of linear equations, but I'm having a hard time making the solver resolve the equations that contain partial derivatives. The solver can handle them, but it gets confused about when and what it should solve, and doesn't output anything useful.

Minimal Code:

#Try to solve Y=Z X=dY(Z)^3/dZ
import sympy as lib_sympy

def bad_derivative_wrong( in_x : lib_sympy.Symbol, in_y : lib_sympy.Symbol, in_z : lib_sympy.Symbol ):
    """Solve ``Y = Z`` and ``X = d(Y**3)/dZ`` with the derivative evaluated eagerly.

    The derivative is built with ``evaluate = True`` while ``Y`` and ``Z`` are
    still unrelated symbols, so it is already collapsed before ``solve`` sees
    the ``Y = Z`` equation -- the demo output shows ``X: 0``.

    Returns whatever ``lib_sympy.solve`` returns for the unknowns ``(X, Y)``.
    """
    cube = in_y*in_y*in_y
    derivative = lib_sympy.Derivative(cube, in_z, evaluate = True)
    equations = [
        lib_sympy.Eq( in_y, in_z ),
        lib_sympy.Eq( in_x, derivative ),
    ]
    unknowns = (in_x, in_y)
    return lib_sympy.solve( equations, unknowns, exclude = () )

def bad_derivative_unhelpful( in_x : lib_sympy.Symbol, in_y : lib_sympy.Symbol, in_z : lib_sympy.Symbol ):
    """Solve ``Y = Z`` and ``X = d(Y**3)/dZ`` with the derivative left unevaluated.

    Here ``evaluate = False`` keeps the ``Derivative`` object intact, but
    ``solve`` neither substitutes ``Y = Z`` into it nor evaluates it -- the demo
    output shows ``X: Derivative(Y**3, Z)``.

    Returns whatever ``lib_sympy.solve`` returns for the unknowns ``(X, Y)``.
    """
    cube = in_y*in_y*in_y
    derivative = lib_sympy.Derivative(cube, in_z, evaluate = False)
    equations = [
        lib_sympy.Eq( in_y, in_z ),
        lib_sympy.Eq( in_x, derivative ),
    ]
    unknowns = (in_x, in_y)
    return lib_sympy.solve( equations, unknowns, exclude = () )

def good_derivative( in_x : lib_sympy.Symbol, in_y : lib_sympy.Symbol, in_z : lib_sympy.Symbol ):
    """Solve ``Y = Z`` and ``X = d(Z**3)/dZ`` with the derivative written in ``Z``.

    Because the derivative is expressed directly in ``Z`` and built with
    ``evaluate = True``, ``Derivative`` has already produced a plain expression
    (not a derivative symbol) before ``solve`` runs -- the demo output shows
    ``X: 3*Z**2``. A ``lib_sympy.dsolve`` approach was considered but not used.

    Returns whatever ``lib_sympy.solve`` returns for the unknowns ``(X, Y)``.
    """
    cube = in_z*in_z*in_z
    derivative = lib_sympy.Derivative(cube, in_z, evaluate = True)
    equations = [
        lib_sympy.Eq( in_y, in_z ),
        lib_sympy.Eq( in_x, derivative ),
    ]
    unknowns = (in_x, in_y)
    return lib_sympy.solve( equations, unknowns, exclude = () )

if __name__ == '__main__':
    # X may alternatively be declared as an undefined function via
    # lib_sympy.symbols('X', cls=lib_sympy.Function).
    n_x = lib_sympy.symbols('X')
    n_y, n_z = lib_sympy.Symbol('Y'), lib_sympy.Symbol('Z')
    demos = (
        ("Wrong Derivative: ", bad_derivative_wrong),
        ("Unhelpful Derivative: ", bad_derivative_unhelpful),
        ("Good Derivative: ", good_derivative),
    )
    for label, demo in demos:
        print(label, demo( n_x, n_y, n_z ))

Output:

Wrong Derivative:  {Y: Z, X: 0}
Unhelpful Derivative:  {Y: Z, X: Derivative(Y**3, Z)}
Good Derivative:  {Y: Z, X: 3*Z**2}

Question:

I need a way to add partial-derivative terms to my equations in a form that the solver can actually solve.

For example, speed is the derivative of position with respect to time, and the sensitivity of the position with respect to an angle is related to precision and force.


Solution

  • It took a while to find the answer, but I found a solution that helped. Essentially, there is an open (at the time of writing) GitHub issue that shares `seek` and `focus` functions. Solving your equations with the `focus` function gives you the `x` and `y` solutions as functions of `z` only. You aren't importing the entire sympy library (i.e. you aren't doing `from sympy import *`), which is also how I run sympy. So we need to add import lines for `Eq` and `Tuple` inside the functions, which I did below.

    I also made a change that lets you optionally pass `simplify` to `focus`. If you set `simplify=True`, it simplifies the results, which is what you want here. Otherwise it returns `{y: z, x: Derivative(z**3, z)}`.

    seek:

    def seek(eqs, do, sol=None, strict=True):
        """Solve equations in ``eqs`` one at a time for symbols in ``do``.

        Whenever some ``x`` in ``do`` can be isolated from an ``Eq`` in ``eqs``
        (using SymPy's private ``_invert``), the pair ``(x, expr)`` is appended
        to ``sol``. That equation is removed from ``eqs``, ``x`` is removed
        from ``do``, and ``x`` is replaced by ``expr`` in every remaining
        equation (after clearing denominators).

        ``eqs`` and ``sol`` are mutated in place. In strict mode ``do`` is
        emptied in place as symbols are solved, and ``focus`` checks this.
        In non-strict mode the working queue is rebuilt as a new list after
        each solve.

        Parameters
        ----------
        eqs : list
            Equations to consume; entries that are not ``Eq`` are skipped.
        do : set or list
            Symbols to solve for.
        sol : list, optional
            Accumulator for ``(symbol, expression)`` pairs. A fresh list is
            created when omitted, so separate calls never share results.
        strict : bool
            If True, every symbol tried must be solvable. If False, the free
            symbols of each new solution are queued too, and the search stops
            quietly once no queued symbol can be solved.

        Returns
        -------
        list
            ``sol``, extended with the solved pairs in solving order.

        Raises
        ------
        ValueError
            If ``strict`` and no equation can be solved for a symbol, or if a
            substitution turns an equation into False.
        """
        from sympy.solvers.solvers import _invert
        from sympy import Eq
        try:
            from sympy.core.sorting import ordered  # SymPy >= 1.10
        except ImportError:  # older SymPy releases
            from sympy.core.compatibility import ordered

        if sol is None:
            sol = []
        while do and eqs:
            found = None
            for x in do:
                for e in eqs:
                    if not isinstance(e, Eq):
                        continue
                    i, d = _invert(e.lhs - e.rhs, x)
                    if d == x:
                        found = (x, i, e)
                        break
                if found is not None:
                    break
                if strict:
                    # explicit raise: an ``assert`` would vanish under -O
                    raise ValueError(f'no equation could be solved for {x}')
            if found is None:
                # No queued symbol is solvable, and ``eqs`` only changes after
                # a successful solve, so further passes cannot succeed. (The
                # old code fell through here and substituted a stale,
                # unrelated expression for ``x``.)
                break
            x, i, e = found
            sol.append((x, i))
            eqs.remove(e)
            do.remove(x)
            if not strict:
                # queue the symbols the new solution depends on (deduplicated)
                do = list(ordered(set(do) | i.free_symbols))
            for k, eq in enumerate(eqs):
                if not isinstance(eq, Eq):
                    continue
                # multiply out denominators to avoid dividing by zero
                ln, ld = eq.lhs.as_numer_denom()
                rn, rd = eq.rhs.as_numer_denom()
                eqs[k] = Eq(ln*rd, rn*ld).xreplace({x: i})
                if eqs[k] == False:
                    raise ValueError('inconsistency detected')
        return sol
    

    focus:

    def focus(eqs, *syms, **kwargs):
        """Given Equality instances in ``eqs``, solve for symbols in
        ``syms`` and resolve as many of the free symbols in the solutions
        as possible. When ``evaluate=True`` a dictionary with keys being
        ``syms`` is returned, otherwise a list of identified symbols
        leading to the desired symbols is given.

        Requested symbols that do not occur in ``eqs`` are ignored, and
        duplicates in ``syms`` are treated once. Relies on the sibling
        helper ``seek``.

        Keyword arguments
        =================
        evaluate : bool, default True
            Return ``{symbol: expression}`` for the requested symbols, with
            every other solved symbol back-substituted. If False, return the
            raw ``(symbol, expression)`` steps, last-solved first.
        simplify : bool, default False
            Call ``.simplify()`` on every resulting expression. This is
            needed, e.g., to turn ``Derivative(z**3, z)`` into ``3*z**2``.

        Raises
        ======
        TypeError
            If an element of ``eqs`` is not an ``Eq``.
        ValueError
            If fewer requested symbols could be solved than were asked for.
            Exceptions raised by ``seek`` (e.g. when a requested symbol cannot
            be isolated) propagate unchanged.

        Examples
        ========
        >>> focus((Eq(a, b), Eq(b + 2, c)), a)
        {a: c - 2}
        >>> focus((Eq(a, b), Eq(b + 2, c)), a, evaluate=False)
        [(b, c - 2), (a, b)]
        """
        from sympy.solvers.solvers import _invert as f
        from sympy import Eq, Tuple
        try:
            from sympy.core.sorting import ordered  # SymPy >= 1.10
        except ImportError:  # older SymPy releases
            from sympy.core.compatibility import ordered
        evaluate = kwargs.get('evaluate', True)
        simplify = kwargs.get('simplify', False)
        # Materialize first: a generator would otherwise be exhausted by the
        # type check before the equations are ever used.
        eqs = list(eqs)
        if not all(isinstance(e, Eq) for e in eqs):
            raise TypeError('focus expects only Eq instances in eqs')
        sol = []
        free = Tuple(*eqs).free_symbols
        do = set(syms) & free
        if not do:
            # keep the documented return type for evaluate=True
            return {} if evaluate else sol
        seek(eqs, do, sol)
        if do:
            raise ValueError(
                'could not solve for: %s' % ', '.join(map(str, ordered(do))))
        # Number of requested symbols actually solved. ``len(syms)`` is wrong
        # when syms repeats a symbol or names one absent from ``eqs``.
        n_target = len(sol)
        for x, i in sol:
            do |= i.free_symbols
        do = list(ordered(do))  # make it canonical
        seek(eqs, do, sol, strict=False)
        if evaluate:
            # Fold the auxiliary solutions from the non-strict pass back into
            # the requested ones, last-solved first.
            while len(sol) > n_target:
                x, s = sol.pop()
                sol = [(y, t.xreplace({x: s})) for y, t in sol]
            # Earlier target solutions may still mention later targets.
            # Substitute those in and re-isolate each earlier symbol.
            for i in reversed(range(1, n_target)):
                x, s = sol[i]
                for j in range(i):
                    y, t = sol[j]
                    sol[j] = y, f(y - t.xreplace({x: s}), y)[0]
        if simplify:
            sol = [(y, t.simplify()) for y, t in sol]
        if evaluate:
            return dict(sol)
        return list(reversed(sol))
    

    Your example:

    import sympy as sp

    x, y, z = sp.symbols("x y z")

    # Y = Z, and X = d(Y**3)/dZ with the derivative left unevaluated, so that
    # focus can substitute Y = Z into it before simplifying.
    l_equation = [
        sp.Eq(y, z),
        sp.Eq(x, sp.Derivative(y**3, z)),
    ]

    solution = focus(l_equation, x, y, simplify=True)
    print(solution)    # {y: z, x: 3*z**2}
    

    As indicated in my explanation above, credit goes to those who originally developed this solution; I've only made some minor changes to suit the needs of this problem.