Context: I am not sure if this is the right site to post this question, please let me know if it isn't. My aim is to solve the coupled differential equations given in the code for the Alpha Centauri star system.
Code:
#Import scipy, numpy and mpmath
import scipy as sci
import numpy as np
import mpmath as mp
#Import matplotlib and associated modules for 3D and animations
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import animation
#Import decimal for better precision
from decimal import *
getcontext().prec = 10000
#Define universal gravitation constant
G=Decimal(6.67408e-11) #N-m2/kg2
#Reference quantities
m_nd=Decimal(1.989e+30) #kg #mass of the sun
r_nd=Decimal(5.326e+12) #m #distance between stars in Alpha Centauri
v_nd=Decimal(30000) #m/s #relative velocity of earth around the sun
t_nd=Decimal(79.91*365*24*3600*0.51) #s #orbital period of Alpha Centauri
#Net constants
K1=G*t_nd*m_nd/(r_nd**2*v_nd)
K2=v_nd*t_nd/r_nd
#Define masses
m1=Decimal(1.1) #Alpha Centauri A
m2=Decimal(0.907) #Alpha Centauri B
m3=Decimal(1.0) #Third Star
#Define initial position vectors
r1=np.array([Decimal(-0.5),Decimal(0),Decimal(0)])
r2=np.array([Decimal(0.5),Decimal(0),Decimal(0)])
r3=np.array([Decimal(0),Decimal(1),Decimal(0)])
#Find Centre of Mass
r_com=(m1*r1+m2*r2+m3*r3)/(m1+m2+m3)
#Define initial velocities
v1=np.array([Decimal(0.01),Decimal(0.01),Decimal(0)])
v2=np.array([Decimal(-0.05),Decimal(0),Decimal(-0.1)])
v3=np.array([Decimal(0),Decimal(-0.01),Decimal(0)])
#Find velocity of COM
v_com=(m1*v1+m2*v2+m3*v3)/(m1+m2+m3)#Define initial velocities
def ThreeBodyEquations(w,t,G,m1,m2,m3):
r1=w[:3]
r2=w[3:6]
r3=w[6:9]
v1=w[9:12]
v2=w[12:15]
v3=w[15:18]
r12=sci.linalg.norm(r2-r1)
r13=sci.linalg.norm(r3-r1)
r23=sci.linalg.norm(r3-r2)
dv1bydt=K1*m2*(r2-r1)/r12**3+K1*m3*(r3-r1)/r13**3+(61**2)*r1
dv2bydt=K1*m1*(r1-r2)/r12**3+K1*m3*(r3-r2)/r23**3+(61**2)*r2
dv3bydt=K1*m1*(r1-r3)/r13**3+K1*m2*(r2-r3)/r23**3+(61**2)*r3
dr1bydt=K2*v1
dr2bydt=K2*v2
dr3bydt=K2*v3
r12_derivs=sci.concatenate((dr1bydt,dr2bydt))
r_derivs=sci.concatenate((r12_derivs,dr3bydt))
v12_derivs=sci.concatenate((dv1bydt,dv2bydt))
v_derivs=sci.concatenate((v12_derivs,dv3bydt))
derivs=sci.concatenate((r_derivs,v_derivs))
return derivs
#Package initial parameters
init_params=np.array([r1,r2,r3,v1,v2,v3]) #Initial parameters
init_params=init_params.flatten() #Flatten to make 1D array
time_span=sci.linspace(0,20,500) #20 orbital periods and 500 points
#Run the ODE solver
three_body_sol=mp.odefun(ThreeBodyEquations,time_span,init_params,time_span)
r1_sol=three_body_sol[:,:3]
r2_sol=three_body_sol[:,3:6]
r3_sol=three_body_sol[:,6:9]
#Create figure
fig=plt.figure(figsize=(15,15))
#Create 3D axes
ax=fig.add_subplot(111,projection="3d")
#Plot the orbits
ax.plot(r1_sol[:,0],r1_sol[:,1],r1_sol[:,2],color="darkblue")
ax.plot(r2_sol[:,0],r2_sol[:,1],r2_sol[:,2],color="tab:red")
#Plot the final positions of the stars
ax.scatter(r1_sol[-1,0],r1_sol[-1,1],r1_sol[-1,2],color="darkblue",marker="o",s=100,label="Alpha Centauri A")
ax.scatter(r2_sol[-1,0],r2_sol[-1,1],r2_sol[-1,2],color="tab:red",marker="o",s=100,label="Alpha Centauri B")
#Add a few more bells and whistles
ax.set_xlabel("x-coordinate",fontsize=14)
ax.set_ylabel("y-coordinate",fontsize=14)
ax.set_zlabel("z-coordinate",fontsize=14)
ax.set_title("Visualization of orbits of stars in a two-body system\n",fontsize=14)
ax.legend(loc="upper left",fontsize=14)
To my surprise, I am getting this error
ValueError Traceback (most recent call last)
<ipython-input-11-8ecff918f44e> in <module>
88 #Run the ODE solver
89 import scipy.integrate
---> 90 three_body_sol=mp.odefun(ThreeBodyEquations,time_span,init_params,time_span)
91
92 r1_sol=three_body_sol[:,:3]
/usr/local/lib/python3.8/dist-packages/mpmath/calculus/odes.py in odefun(ctx, F, x0, y0, tol, degree, method, verbose)
228
229 """
--> 230 if tol:
231 tol_prec = int(-ctx.log(tol, 2))+10
232 else:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Now I am speculating that Python wants me to use a.any()
or a.all()
when entering the initial parameters but np.any(time_span)
and np.any(init_params)
also throws an error. Can someone please tell me what is going wrong and how do I rectify this? Thank you in advance