python sympy symbolic-math

Taking the cross product of an unknown matrix with an unknown function in sympy


I would like sympy to confirm that, given the equation:

$$ r'(t) = A \times r(t) $$

(i.e. "the derivative of an unknown function r is the cross product of an unknown matrix A with r"), it follows that:

$$ r''(t) = A \times r'(t) $$

(i.e. "the second derivative of r is the cross product of A with the first derivative of r").

From the documentation it seems like I want to use MatrixSymbol for A, but MatrixSymbol doesn't define cross:

    from sympy import *
    from sympy.abc import *
    r = Function('r')(t)
    A = MatrixSymbol('A', 4, 4)  # dummy dimensions

    Derivative(A.cross(r))

gives me:

    AttributeError                            Traceback (most recent call last)
    <ipython-input-52-4c8dc7c142cf> in <module>
          4 A = MatrixSymbol('A', 4, 4)
          5 
    ----> 6 Derivative(A.cross(r))

    AttributeError: 'MatrixSymbol' object has no attribute 'cross'

What's the right way to do this?
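The `AttributeError` arises because `MatrixSymbol` is an opaque matrix expression: it supports matrix algebra but defines no `cross` method. If A is really a 3-component vector (the cross product is only defined in three dimensions), one alternative to `sympy.vector` is an explicit `Matrix` of symbols, which does provide `cross`. A minimal sketch, assuming A is constant in t; the component names `a1`..`a3` and `r1`..`r3` are illustrative:

```python
from sympy import Function, Matrix, MatrixSymbol, simplify, symbols, zeros

t = symbols('t')

# MatrixSymbol is an opaque matrix expression: it supports products,
# transposes, inverses, etc., but defines no cross method.
A_sym = MatrixSymbol('A', 3, 1)
print(hasattr(A_sym, 'cross'))  # False

# An explicit 3x1 Matrix built from symbols does define cross
# (only for 3-element vectors).
a1, a2, a3 = symbols('a1:4')  # constant components of A (no t-dependence)
r1, r2, r3 = [f(t) for f in symbols('r1:4', cls=Function)]
A = Matrix([a1, a2, a3])
r = Matrix([r1, r2, r3])

# Given r' = A x r, differentiate once more to get r'',
# then compare it with A x r'.
difference = A.cross(r).diff(t) - A.cross(r.diff(t))
print(difference.applyfunc(simplify) == zeros(3, 1))  # True
```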


Solution

  • SymPy's vector module (`sympy.vector`) is completely separate from the `Matrix` class, which can be confusing if you're used to thinking of a vector as a particular kind of matrix: https://docs.sympy.org/latest/modules/vector/index.html

    I'll demonstrate how to do this with the vector module. This can be done more compactly, but I'm spelling it out in detail:

    In [23]: from sympy import symbols, Function
        ...: from sympy.vector import CoordSys3D
        ...: t = symbols('t')
        ...: N = CoordSys3D('N')
        ...: i, j, k = N.base_vectors()  # unit vectors N.i, N.j, N.k

    In [24]: a1, a2, a3 = symbols('a1:4')  # constant components of A

    In [25]: r1, r2, r3 = [ri(t) for ri in symbols('r1:4', cls=Function)]

    In [26]: A = a1*i + a2*j + a3*k

    In [27]: r = r1*i + r2*j + r3*k

    In [28]: A
    Out[28]: a1*N.i + a2*N.j + a3*N.k

    In [29]: r
    Out[29]: (r1(t))*N.i + (r2(t))*N.j + (r3(t))*N.k

    In [30]: (A.cross(r.diff(t))).diff(t) == A.cross(r.diff(t, 2))
    Out[30]: True
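The check in `In [30]` shows that differentiating A × r' gives A × r'', which is the same identity one derivative later. A minimal sketch of a more direct check, assuming the same `CoordSys3D` setup: differentiate the hypothesis r' = A × r to obtain r'', then compare it with A × r'. The helpers `is_zero_vector` and `second_derivative_claim_holds`, and the time-dependent components `b1`..`b3`, are hypothetical additions for illustration. The second case shows that the result relies on A being constant, since otherwise the product rule adds an extra A' × r term.

```python
from sympy import Function, simplify, symbols
from sympy.vector import CoordSys3D

t = symbols('t')
N = CoordSys3D('N')
i, j, k = N.base_vectors()

# Unknown vector function r(t) = r1(t) i + r2(t) j + r3(t) k
r1, r2, r3 = [f(t) for f in symbols('r1:4', cls=Function)]
r = r1*i + r2*j + r3*k


def is_zero_vector(v):
    """True if every component of the sympy.vector expression v simplifies to 0."""
    return all(simplify(c) == 0 for c in v.components.values())


def second_derivative_claim_holds(A):
    """Assuming r'(t) = A x r(t), check whether r''(t) = A x r'(t)."""
    r_dd = A.cross(r).diff(t)  # r'' obtained by differentiating the hypothesis
    return is_zero_vector(r_dd - A.cross(r.diff(t)))


# Constant A: components are plain symbols independent of t
a1, a2, a3 = symbols('a1:4')
A_const = a1*i + a2*j + a3*k

# Time-dependent A (hypothetical): components are unknown functions of t
b1, b2, b3 = [f(t) for f in symbols('b1:4', cls=Function)]
A_var = b1*i + b2*j + b3*k

print(second_derivative_claim_holds(A_const))  # True
print(second_derivative_claim_holds(A_var))    # False: extra A'(t) x r(t) term
```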