Tags: python, ode, any, mpmath

"Use a.any() or a.all()" error when solving coupled ODEs with mpmath.odefun


Context: I am not sure whether this is the right site for this question; please let me know if it isn't. My aim is to solve the coupled differential equations given in the code below for the Alpha Centauri star system.

Code:

#Import scipy, numpy and mpmath
import scipy as sci
import numpy as np
import mpmath as mp
#Import matplotlib and associated modules for 3D and animations
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import animation
#Import decimal for better precision
from decimal import *
getcontext().prec = 10000

#Define universal gravitation constant
G=Decimal(6.67408e-11) #N-m2/kg2
#Reference quantities
m_nd=Decimal(1.989e+30) #kg #mass of the sun
r_nd=Decimal(5.326e+12) #m #distance between stars in Alpha Centauri
v_nd=Decimal(30000) #m/s #relative velocity of earth around the sun
t_nd=Decimal(79.91*365*24*3600*0.51) #s #orbital period of Alpha Centauri
#Net constants
K1=G*t_nd*m_nd/(r_nd**2*v_nd)
K2=v_nd*t_nd/r_nd

#Define masses
m1=Decimal(1.1) #Alpha Centauri A
m2=Decimal(0.907) #Alpha Centauri B 
m3=Decimal(1.0) #Third Star

#Define initial position vectors    
r1=np.array([Decimal(-0.5),Decimal(0),Decimal(0)])
r2=np.array([Decimal(0.5),Decimal(0),Decimal(0)])
r3=np.array([Decimal(0),Decimal(1),Decimal(0)])

#Find Centre of Mass
r_com=(m1*r1+m2*r2+m3*r3)/(m1+m2+m3)
#Define initial velocities
v1=np.array([Decimal(0.01),Decimal(0.01),Decimal(0)])
v2=np.array([Decimal(-0.05),Decimal(0),Decimal(-0.1)])
v3=np.array([Decimal(0),Decimal(-0.01),Decimal(0)])

#Find velocity of COM
v_com=(m1*v1+m2*v2+m3*v3)/(m1+m2+m3)

def ThreeBodyEquations(w,t,G,m1,m2,m3):
    r1=w[:3]
    r2=w[3:6]
    r3=w[6:9]
    v1=w[9:12]
    v2=w[12:15]
    v3=w[15:18]
    r12=sci.linalg.norm(r2-r1)
    r13=sci.linalg.norm(r3-r1)
    r23=sci.linalg.norm(r3-r2)
    
    dv1bydt=K1*m2*(r2-r1)/r12**3+K1*m3*(r3-r1)/r13**3+(61**2)*r1
    dv2bydt=K1*m1*(r1-r2)/r12**3+K1*m3*(r3-r2)/r23**3+(61**2)*r2
    dv3bydt=K1*m1*(r1-r3)/r13**3+K1*m2*(r2-r3)/r23**3+(61**2)*r3
    dr1bydt=K2*v1
    dr2bydt=K2*v2
    dr3bydt=K2*v3
    r12_derivs=sci.concatenate((dr1bydt,dr2bydt))
    r_derivs=sci.concatenate((r12_derivs,dr3bydt))
    v12_derivs=sci.concatenate((dv1bydt,dv2bydt))
    v_derivs=sci.concatenate((v12_derivs,dv3bydt))
    derivs=sci.concatenate((r_derivs,v_derivs))
    return derivs

#Package initial parameters
init_params=np.array([r1,r2,r3,v1,v2,v3]) #Initial parameters
init_params=init_params.flatten() #Flatten to make 1D array
time_span=sci.linspace(0,20,500) #20 orbital periods and 500 points

#Run the ODE solver
three_body_sol=mp.odefun(ThreeBodyEquations,time_span,init_params,time_span)

r1_sol=three_body_sol[:,:3]
r2_sol=three_body_sol[:,3:6]
r3_sol=three_body_sol[:,6:9]

#Create figure
fig=plt.figure(figsize=(15,15))
#Create 3D axes
ax=fig.add_subplot(111,projection="3d")
#Plot the orbits
ax.plot(r1_sol[:,0],r1_sol[:,1],r1_sol[:,2],color="darkblue")
ax.plot(r2_sol[:,0],r2_sol[:,1],r2_sol[:,2],color="tab:red")
#Plot the final positions of the stars
ax.scatter(r1_sol[-1,0],r1_sol[-1,1],r1_sol[-1,2],color="darkblue",marker="o",s=100,label="Alpha Centauri A")
ax.scatter(r2_sol[-1,0],r2_sol[-1,1],r2_sol[-1,2],color="tab:red",marker="o",s=100,label="Alpha Centauri B")
#Add a few more bells and whistles
ax.set_xlabel("x-coordinate",fontsize=14)
ax.set_ylabel("y-coordinate",fontsize=14)
ax.set_zlabel("z-coordinate",fontsize=14)
ax.set_title("Visualization of orbits of stars in a two-body system\n",fontsize=14)
ax.legend(loc="upper left",fontsize=14)

To my surprise, I get this error:

ValueError                                Traceback (most recent call last)
<ipython-input-11-8ecff918f44e> in <module>
     88 #Run the ODE solver
     89 import scipy.integrate
---> 90 three_body_sol=mp.odefun(ThreeBodyEquations,time_span,init_params,time_span)
     91 
     92 r1_sol=three_body_sol[:,:3]
/usr/local/lib/python3.8/dist-packages/mpmath/calculus/odes.py in odefun(ctx, F, x0, y0, tol, degree, method, verbose)
    228 
    229     """
--> 230     if tol:
    231         tol_prec = int(-ctx.log(tol, 2))+10
    232     else:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

My guess is that Python wants me to use a.any() or a.all() when passing the initial parameters, but passing np.any(time_span) and np.any(init_params) instead also throws an error. Can someone please tell me what is going wrong and how I can fix it? Thank you in advance.
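The message comes from NumPy rather than from mpmath: Python raises it whenever it needs a single True/False from an array with more than one element. The traceback points at "if tol:" because, according to the signature shown there (odefun(ctx, F, x0, y0, tol, degree, method, verbose)), the fourth positional parameter of odefun is tol, and the call passes time_span in that slot. A minimal reproduction, where truth_value_of is a hypothetical helper used only for illustration:

```python
import numpy as np

time_span = np.linspace(0, 20, 500)  # same array as in the question


def truth_value_of(value):
    """Hypothetical helper that mimics the `if tol:` test inside mpmath.odefun."""
    try:
        return "truthy" if value else "falsy"
    except ValueError as exc:
        # NumPy refuses to turn a multi-element array into a single bool.
        return f"ValueError: {exc}"


print(truth_value_of(time_span))  # an array in the tol slot triggers the error
print(truth_value_of(1e-4))       # a scalar tolerance passes the check
# .any() and .all() collapse the array to one bool, which is not a time grid.
print(time_span.any(), time_span.all())
```

Reducing the array with .any() or .all() silences this particular check, but a bool is neither an initial time nor a tolerance, so the fix is to change how odefun is called.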


Solution

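  • You need to read the documentation carefully: mpmath.odefun works fundamentally differently from scipy.integrate.odeint. Its signature is odefun(F, x0, y0, tol, degree, method, verbose), as the traceback shows, so in your call the fourth positional argument, time_span, is taken as the tolerance tol, and the check "if tol:" cannot decide the truth value of a 500-element array. That is the ValueError you see. More fundamentally, mpmath.odefun returns a dynamical solution object that is more similar to the scipy.integrate.ode stepper class: the call itself computes little beyond initializing that object. The actual solution data, in the form of a "dense output", is computed and stored by subsequent calls to the returned object, and the covered time range is extended as necessary.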

    The documentation examples show how this works. In your case it could look like this:

    three_body_fun=mp.odefun(ThreeBodyEquations,time_span[0],init_params, tol=1e-4, degree=5)
    three_body_sol = [ three_body_fun(t) for t in time_span]
    

    For this to start working, you will need to change ThreeBodyEquations so that it takes only the parameters t, w, in that order, because odefun solves y' = F(x, y). There is no need to pass the constants as parameters; they are taken from the global scope. Note also that w and its slices are plain Python lists, so you need to convert them to a vector format before applying vector subtraction.

    There should be no problem using numpy/scipy arrays to store multi-precision data and perform vector operations. However, the norm function will use a floating-point square root, which invalidates all the effort put into the multi-precision implementation, so it is better to code your own Euclidean norm or use the mpmath variant, mp.norm.

    I did not test how well Decimal and mpmath work together; I simply replaced the import of the former with

    Decimal = lambda x: mp.mpf(x)
    

    You may also want to set the mpmath precision to something reasonable (mp.mp.dps = 25, given import mpmath as mp). Around 25 decimal places could be sensible and 80 could be used for reference values, but the 10000 significant digits requested with getcontext().prec = 10000 will increase the computation time unreasonably, especially as your constants and inputs are nowhere near that accurate.
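Putting these suggestions together, a minimal sketch could look like the following. It assumes that mpmath's global context (from mpmath import mp) replaces Decimal entirely, keeps the equations of the question's ThreeBodyEquations unchanged (including the (61**2)*r term), and works on plain lists with small hand-written helpers (sub, scale, add and norm are hypothetical names, not mpmath API). Constants are built from strings, because mp.mpf(6.67408e-11) would inherit the binary rounding error of the float literal. The time span is deliberately very short; the run time over the 20 periods of the original code has not been checked.

```python
from mpmath import mp

mp.dps = 25  # working precision: 25 significant decimal digits

# Constants from strings, so no float rounding error is carried over.
G = mp.mpf("6.67408e-11")                                  # N m^2 / kg^2
m_nd = mp.mpf("1.989e30")                                  # kg
r_nd = mp.mpf("5.326e12")                                  # m
v_nd = mp.mpf("30000")                                     # m/s
t_nd = mp.mpf("79.91") * 365 * 24 * 3600 * mp.mpf("0.51")  # s
K1 = G * t_nd * m_nd / (r_nd**2 * v_nd)
K2 = v_nd * t_nd / r_nd
m1, m2, m3 = mp.mpf("1.1"), mp.mpf("0.907"), mp.mpf("1.0")


# Hypothetical helpers for 3-vectors stored as plain lists of mpf values.
def sub(a, b):
    return [x - y for x, y in zip(a, b)]


def scale(c, a):
    return [c * x for x in a]


def add(*vectors):
    return [sum(parts) for parts in zip(*vectors)]


def norm(a):
    # Euclidean norm with mpmath's multi-precision square root.
    return mp.sqrt(mp.fsum(x * x for x in a))


def three_body_equations(t, w):
    """Right-hand side with the (t, w) argument order that mp.odefun expects."""
    r1, r2, r3 = w[0:3], w[3:6], w[6:9]
    v1, v2, v3 = w[9:12], w[12:15], w[15:18]
    r12, r13, r23 = norm(sub(r2, r1)), norm(sub(r3, r1)), norm(sub(r3, r2))
    # Same terms as in the question's ThreeBodyEquations.
    dv1 = add(scale(K1 * m2 / r12**3, sub(r2, r1)),
              scale(K1 * m3 / r13**3, sub(r3, r1)), scale(61**2, r1))
    dv2 = add(scale(K1 * m1 / r12**3, sub(r1, r2)),
              scale(K1 * m3 / r23**3, sub(r3, r2)), scale(61**2, r2))
    dv3 = add(scale(K1 * m1 / r13**3, sub(r1, r3)),
              scale(K1 * m2 / r23**3, sub(r2, r3)), scale(61**2, r3))
    dr = scale(K2, v1) + scale(K2, v2) + scale(K2, v3)  # list concatenation
    return dr + dv1 + dv2 + dv3                          # 18 derivatives


# Initial state [r1, r2, r3, v1, v2, v3] as one flat list of mpf values.
init_params = [mp.mpf(x) for x in
               ["-0.5", "0", "0", "0.5", "0", "0", "0", "1", "0",
                "0.01", "0.01", "0", "-0.05", "0", "-0.1", "0", "-0.01", "0"]]

# odefun only sets up the solver; values are computed when the result is called.
three_body_fun = mp.odefun(three_body_equations, 0, init_params)
time_span = [mp.mpf("0.001") * k / 4 for k in range(5)]  # short illustrative span
three_body_sol = [three_body_fun(t) for t in time_span]
```

Each call three_body_fun(t) returns a list of 18 mpf values, so for plotting, something like np.array(three_body_sol, dtype=float) is needed before the r1_sol[:, :3]-style slicing from the question can be used.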